AKS Primality Prover, Part 2
October 5, 2012
In the previous exercise we wrote some auxiliary functions needed for the implementation of the AKS primality prover. Today we will implement the AKS algorithm:
AKS Primality Prover: Given an integer $n > 1$, determine if it is prime or composite.
1. If $n = a^b$ for integers $a > 0$ and $b > 1$, output COMPOSITE.
2. Find the smallest $r$ such that $\mathrm{ord}_r(n) > (\log_2 n)^2$.
3. If $1 < \gcd(a, n) < n$ for some $a \le r$, output COMPOSITE.
4. If $n \le r$, output PRIME.
5. For each $a$ from 1 to $\lfloor \sqrt{\varphi(r)} \cdot \log_2 n \rfloor$, if $(x + a)^n \neq x^n + a \pmod{x^r - 1,\ n}$, output COMPOSITE.
6. Output PRIME.
Here $\mathrm{ord}_r(n)$ is the multiplicative order of $n$ modulo $r$, both logarithms are to the base 2, $\varphi(r)$ is Euler's totient function, and the polynomial modular exponentiation is done as in the previous exercise.
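The number-theoretic pieces of steps 1 and 2 can be sketched in Python as follows. This is a minimal sketch, not the solution from the previous exercise: the helper names (integer_root, is_perfect_power, multiplicative_order, totient, smallest_r) are hypothetical, and the totient uses plain trial division, which is only reasonable for the small r that step 2 produces. Values of r that share a factor with n are skipped in smallest_r, since the multiplicative order is undefined for them.

```python
from math import gcd, log2


def integer_root(n, b):
    """Largest integer x with x**b <= n (binary search, exact for big ints)."""
    lo, hi = 1, 1 << (n.bit_length() // b + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** b <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo


def is_perfect_power(n):
    """Step 1: True if n == a**b for some integers a > 1 and b > 1."""
    for b in range(2, n.bit_length() + 1):
        a = integer_root(n, b)
        if a > 1 and a ** b == n:
            return True
    return False


def multiplicative_order(n, r):
    """Smallest k >= 1 with n**k % r == 1; assumes r > 1 and gcd(n, r) == 1."""
    k, x = 1, n % r
    while x != 1:
        x = (x * n) % r
        k += 1
    return k


def totient(r):
    """Euler's phi(r) via trial-division factorisation."""
    result, m, p = r, r, 2
    while p * p <= m:
        if m % p == 0:
            while m % p == 0:
                m //= p
            result -= result // p
        p += 1
    if m > 1:  # leftover prime factor
        result -= result // m
    return result


def smallest_r(n):
    """Step 2: smallest r coprime to n with ord_r(n) > (log2 n)**2."""
    bound = log2(n) ** 2
    r = 2
    while gcd(n, r) != 1 or multiplicative_order(n, r) <= bound:
        r += 1
    return r
```

The root test uses exact integer arithmetic so that large n are not misjudged by floating-point rounding; only the bound (log2 n)^2 in smallest_r is a float.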
Your task is to write a program to prove primality using the AKS algorithm. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.
A version in Python. Multiplying the polynomials in full and then dividing by X^r - 1 is very slow. Reducing modulo X^r - 1 simply moves every amplitude (coefficient) to position k % r: the amplitude of X^k goes to X^(k % r). I made a faster version that uses the numpy library for the polynomial multiplication; afterwards, all amplitudes for k >= r are moved into the range 0 to r - 1.
With my original version, n = 887 took 51 seconds; this version needs 0.14 seconds.
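The fold described above can also be written without numpy. The following is a plain-Python sketch with hypothetical names (polymul_mod, polypow_mod, congruence_holds); it is not Paul's original version, which isn't shown. Each product term a_i·b_j is added straight into slot (i + j) mod r and reduced mod n, so Python's unbounded integers never overflow, at the cost of O(r^2) work per multiplication.

```python
def polymul_mod(a, b, r, n):
    """Multiply coefficient lists a and b (index = exponent) modulo (X^r - 1, n)."""
    res = [0] * r
    for i, ai in enumerate(a):
        if ai == 0:
            continue
        for j, bj in enumerate(b):
            k = (i + j) % r  # X^(i+j) == X^((i+j) mod r) because X^r == 1
            res[k] = (res[k] + ai * bj) % n
    return res


def polypow_mod(base, e, r, n):
    """base**e modulo (X^r - 1, n) by square-and-multiply."""
    result = [1] + [0] * (r - 1)
    while e:
        if e & 1:
            result = polymul_mod(result, base, r, n)
        e >>= 1
        if e:
            base = polymul_mod(base, base, r, n)
    return result


def congruence_holds(a, n, r):
    """Step 5 check: (X + a)^n == X^n + a modulo (X^r - 1, n)."""
    lhs = polypow_mod([a % n, 1], n, r, n)
    rhs = [0] * r
    rhs[n % r] = 1                 # X^n reduces to X^(n mod r)
    rhs[0] = (rhs[0] + a) % n      # constant term a
    return lhs == rhs
```

For a prime n the congruence holds for every a and r, while a composite such as n = 15 with r = 17 fails at a = 1, because the X^3 coefficient C(15, 3) = 455 is 5 mod 15.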
```python
import math
import itertools as IT
import numpy as NP


def ordr(r, n):
    """Multiplicative order of n modulo r (assumes gcd(n, r) == 1, else loops forever)."""
    for k in IT.count(1):
        if pow(n, k, r) == 1:
            return k


def isqrt(x):
    """Integer square root by Newton's method."""
    if x < 0:
        raise ValueError('square root not defined for negative numbers')
    n = int(x)
    if n == 0:
        return 0
    a, b = divmod(n.bit_length(), 2)
    x = 2 ** (a + b)
    while True:
        y = (x + n // x) // 2
        if y >= x:
            return x
        x = y


def mmultn(a, b, r, n):
    """
    Dividing by X^r - 1 is equivalent to shifting the amplitude from
    position k to k - r. a and b are vectors of length r at most:
    convolve them (equivalent to polynomial multiplication) and add all
    amplitudes with exponent k >= r to exponent k - r.
    After the multiplication all amplitudes are taken % n.
    The unreduced sums are held in int64, so they can overflow once
    r * (n - 1)**2 exceeds 2**63 - 1.
    """
    res = NP.zeros(2 * r, dtype=NP.int64)
    res[:len(a) + len(b) - 1] = NP.convolve(a, b)
    res = res[:-r] + res[-r:]  # fold X^k (k >= r) onto X^(k - r)
    return res % n


def powmodn(pn, n, r, m):
    """Compute pn**n modulo (X^r - 1, m) by square-and-multiply."""
    res = [1]
    while n:
        if n & 1:
            res = mmultn(res, pn, r, m)
        n //= 2
        if n:
            pn = mmultn(pn, pn, r, m)
    return res


def testan(a, n, r):
    """True if (X + a)^n == X^n + a modulo (X^r - 1, n)."""
    pp = powmodn([a, 1], n, r, n)
    pp[n % r] = (pp[n % r] - 1) % n  # subtract X^n
    pp[0] = (pp[0] - a) % n          # subtract a
    return not any(pp)


def phi(n):
    """Euler's totient by brute force."""
    return sum(math.gcd(i, n) == 1 for i in range(1, n))


def aks(n):
    # Step 1: n = a**b with b > 1 means n is composite.
    for a in range(2, isqrt(n) + 1):
        for b in range(2, n):
            t = a ** b
            if t == n:
                return False
            if t > n:
                break
    # Step 2: smallest r with ord_r(n) > (log2 n)**2.
    logn = math.log(n, 2)
    logn2 = logn ** 2
    for r in IT.count(3):
        if math.gcd(r, n) == 1 and ordr(r, n) > logn2:
            break
    # Step 3: a nontrivial common factor means n is composite.
    for a in range(2, r + 1):
        if 1 < math.gcd(a, n) < n:
            return False
    # Step 4
    if n <= r:
        return True
    # Step 5: a runs from 1 to floor(sqrt(phi(r)) * log2 n), inclusive.
    for a in range(1, int(math.sqrt(phi(r)) * logn) + 1):
        if not testan(a, n, r):
            return False
    # Step 6
    return True
```

Paul, great job writing very straightforward code, and truly amazing speed out of Python by using numpy's convolve. However, putting off the modulo until after the convolve and fold does restrict the usable range, because the unreduced sums are held in 64-bit integers. For example, aks(190000003) returns False on my 64-bit machine. As I mentioned before, for smaller numbers it's rocking fast and (more importantly) very readable, complete code.
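A rough check of where the int64 limit bites: every coefficient of a product of two polynomials with coefficients in [0, n) collects at most r terms of size (n - 1)^2. For n = 190000003, (log2 n)^2 is about 756, and since ord_r(n) <= r - 1 the chosen r is at least 758, so r·(n - 1)^2 is at least about 2.7 × 10^19, past 2^63 - 1 ≈ 9.2 × 10^18. That plausibly explains the wrong answer. One exact alternative, swapped in here and not taken from the thread, is Kronecker substitution: pack each coefficient list into one Python integer with wide enough slots, multiply once, and unpack. The names fits_int64 and mulmod_kronecker are hypothetical, and the inputs are assumed to be lists of Python ints in [0, n).

```python
INT64_MAX = 2**63 - 1


def fits_int64(r, n):
    """True if r * (n - 1)**2, the worst unreduced coefficient, fits in int64."""
    return r * (n - 1) ** 2 <= INT64_MAX


def mulmod_kronecker(a, b, r, n):
    """Multiply a, b (Python ints in [0, n), lengths <= r) modulo (X^r - 1, n)
    exactly, using Kronecker substitution on Python's big integers."""
    # Every raw product coefficient is < r * n**2, so this many bytes per
    # slot guarantees that no carry spills into the next slot.
    slot = (r * n * n).bit_length() // 8 + 1

    def pack(p):
        # Coefficient i occupies bytes [i*slot, (i+1)*slot) of one big integer.
        return int.from_bytes(b"".join(c.to_bytes(slot, "little") for c in p), "little")

    prod = pack(a) * pack(b)
    m = len(a) + len(b) - 1           # number of coefficients in the raw product
    raw = prod.to_bytes(m * slot, "little")
    res = [0] * r
    for k in range(m):
        c = int.from_bytes(raw[k * slot:(k + 1) * slot], "little")
        res[k % r] = (res[k % r] + c) % n  # fold X^k onto X^(k mod r)
    return res
```

Because it takes the same arguments as mmultn and returns a plain list, it should be possible to call it from powmodn in place of mmultn for large n, trading some of numpy's speed for exact results; that substitution has not been benchmarked here.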