import numpy as np
import math
import matplotlib.pyplot as plt

def main(plotRows=4, plotCols=8, mForFirstPlot=3):
    """Show a plotRows x plotCols grid of plotModPow panels for consecutive moduli.

    Panels are filled row by row: the top-left panel uses m = mForFirstPlot,
    and each following panel uses the next integer. For example,
    ``main(1, 1, 25)`` draws a single panel for m = 25.

    Args:
        plotRows: number of subplot rows.
        plotCols: number of subplot columns.
        mForFirstPlot: modulus shown in the first (top-left) panel.
    """
    # squeeze=False keeps axs two-dimensional even when plotRows or plotCols
    # is 1, so axs[ploty, plotx] indexing works for every grid shape.
    _, axs = plt.subplots(
        plotRows, plotCols, squeeze=False, gridspec_kw={'hspace': 0.5}
    )
    for ploty in range(plotRows):
        for plotx in range(plotCols):
            m = plotCols * ploty + plotx + mForFirstPlot
            plotModPow(axs[ploty, plotx], m)
    plt.show()

def plotModPow(ax, m):
    """Draw pow(n, r, m) as an image on *ax*, colored by whether n is coprime to m.

    The image has one column per base n (0 <= n < m) and one row per
    exponent r (0 <= r < 2*m), with row 0 at the bottom. Columns whose n is
    coprime to m use the ``bone`` colormap, and all other columns use ``winter``.

    Args:
        ax: matplotlib Axes to draw into.
        m: modulus; presumably a positive integer (callers in this file pass
            m >= 3) -- TODO confirm whether m <= 0 must be supported.
    """
    ysize = 2 * m

    # Residue table: Z[r, n] == pow(n, r, m).
    Z = np.array(
        [[pow(n, r, m) for n in range(m)] for r in range(ysize)],
        dtype=float,
    )

    # One flag per column: True where gcd(n, m) == 1.
    is_coprime = np.array([math.gcd(n, m) == 1 for n in range(m)], dtype=bool)

    # Each array keeps only its own columns and fills the rest with NaN
    # (the mask broadcasts across rows). NaN cells are drawn in the
    # colormap's "bad" color, which is transparent by default, so the two
    # overlaid images never cover each other's columns.
    Z_coprime = np.where(is_coprime, Z, np.nan)
    Z_notCoprime = np.where(is_coprime, np.nan, Z)

    factors = primeFactors(m)
    if len(factors) > 1:
        title = str(m) + ' = ' + '*'.join(str(f) for f in factors)
    elif len(factors) == 1:
        title = str(m) + ' prime'
    else:
        # m == 1 (or m < 1) has no prime factors and is not prime.
        title = str(m)

    # Using a trick from https://stackoverflow.com/a/10800579/925960
    # to show two different color maps depending on coprime vs. not coprime.
    # NOTE(review): each imshow call autoscales its own color range (no
    # vmin/vmax given), so equal residues may get different intensities in
    # the two colormaps.
    pwargs = {'origin': 'lower'}
    ax.imshow(Z_coprime, cmap=plt.cm.bone, **pwargs)
    ax.imshow(Z_notCoprime, cmap=plt.cm.winter, **pwargs)
    ax.set_title(title)
    ax.set_xticks([0, m - 1])
    ax.set_yticks([0, m, ysize - 1])

# From here: https://stackoverflow.com/a/22808285/925960
def primeFactors(m):
    """Return the prime factors of *m* in ascending order, with multiplicity.

    Uses trial division: each divisor is divided out as many times as it
    fits before moving on to the next candidate. When no divisor up to
    sqrt(remaining) is left, the leftover value (if > 1) is itself prime.
    Returns an empty list for m <= 1.

    >>> primeFactors(12)
    [2, 2, 3]
    """
    remaining = m
    divisor = 2
    result = []
    while divisor * divisor <= remaining:
        quotient, leftover = divmod(remaining, divisor)
        if leftover == 0:
            result.append(divisor)
            remaining = quotient
        else:
            divisor += 1
    if remaining > 1:
        result.append(remaining)
    return result

if __name__ == "__main__":
    # Draw the figures only when run as a script, not when imported.
    main()
