AKS Primality Prover, Part 2

October 5, 2012

Because of the work we did in the previous exercise, today’s exercise will be simple. We begin with a function that computes the value of r in Step 2: starting from r = 3, it increments r, skipping any r that is not coprime to n, until the order of n modulo r exceeds the target 4 (log2 n)^2:

(define (compute-r n)
  (let ((target (* 4 (square (log2 n)))))          ; target = 4 (log2 n)^2
    (let loop ((r 3))
      (if (not (= (gcd r n) 1)) (loop (+ r 1))     ; skip r not coprime to n
        (if (< target (ord r n)) r (loop (+ r 1)))))))  ; stop once ord_r(n) > target
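For experimenting outside Scheme, here is a minimal Python sketch of the same search. It assumes, as the Scheme code does, the target 4 (log2 n)^2 with a real-valued base-2 logarithm, and it computes the multiplicative order by repeated multiplication; the names multiplicative_order and compute_r are illustrative only, not functions from the Standard Prelude.

```python
import math

def multiplicative_order(n, r):
    """Smallest k >= 1 with n**k == 1 (mod r); requires r >= 2 and gcd(n, r) == 1."""
    k, x = 1, n % r
    while x != 1:
        x = x * n % r
        k += 1
    return k

def compute_r(n):
    """Smallest r >= 3, coprime to n, such that ord_r(n) > 4 * (log2 n)**2."""
    target = 4 * math.log2(n) ** 2
    r = 3
    # The gcd test runs first, so the order is only computed when it exists.
    while math.gcd(r, n) != 1 or multiplicative_order(n, r) <= target:
        r += 1
    return r
```

With this definition compute_r(89) returns 191, the same value the Scheme session reports for n = 89.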

Now we are ready for the AKS algorithm. Scheme makes the code messier than the direct statement of the algorithm: it has no early return, so every step must have its own path to the end of the function, hence the long chain of nested if expressions:

(define (aks-prime? n)
  (if (prime-power? n) #f                      ; Step 1
    (let* ((r (compute-r n))                   ; Step 2
           (phi-r (phi r))
           (sqrt-phi-r (sqrt phi-r))
           (log2-n (log2 n))
           (sqrt-phi-r-log2-n (* sqrt-phi-r log2-n)))
      (let loop ((a 1))                        ; Step 3: composite if 1 < gcd(a, n) < n for some a <= r
        (if (<= a r)
            (if (< 1 (gcd a n) n) #f (loop (+ a 1)))
            (if (<= n r) #t                    ; Step 4
              (let loop ((a 1))                ; Step 5: binomial tests for a <= sqrt(phi(r)) log2 n
                (if (<= a sqrt-phi-r-log2-n)
                  (if (binomial-test? a r n) (loop (+ a 1)) #f)
                  #t))))))))                   ; Step 6
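Python does have early returns, so the same six steps can be written in the order the algorithm states them. The sketch below is self-contained and makes its own assumptions (all names are illustrative): a perfect-power test by integer k-th roots in place of prime-power?, a gcd-counting phi, the same 4 (log2 n)^2 target for r, and polynomials stored as coefficient lists, lowest degree first, reduced modulo X^r - 1 and n. Its binomial test compares against X^(n mod r) + a, because X^n ≡ X^(n mod r) modulo X^r - 1. Like the Scheme version it multiplies polynomials naively, so it is only practical for small inputs.

```python
import math

def integer_root(n, k):
    """Largest m >= 1 with m**k <= n (assumes n >= 1, k >= 1), by binary search."""
    lo, hi = 1, 1 << (n.bit_length() // k + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** k <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo

def is_perfect_power(n):
    """True if n == m**k for some integers m >= 2, k >= 2."""
    return any(integer_root(n, k) ** k == n for k in range(2, n.bit_length() + 1))

def multiplicative_order(n, r):
    k, x = 1, n % r
    while x != 1:
        x = x * n % r
        k += 1
    return k

def compute_r(n):
    target = 4 * math.log2(n) ** 2          # same target as the Scheme compute-r
    r = 3
    while math.gcd(r, n) != 1 or multiplicative_order(n, r) <= target:
        r += 1
    return r

def phi(m):
    return sum(1 for i in range(1, m + 1) if math.gcd(i, m) == 1)

def poly_mul_mod(a, b, r, n):
    """Product of coefficient lists (lowest degree first) modulo (X^r - 1, n)."""
    res = [0] * r
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            res[(i + j) % r] += x * y        # X^(i+j) folds onto X^((i+j) mod r)
    return [c % n for c in res]

def poly_pow_mod(base, e, r, n):
    result = [1] + [0] * (r - 1)            # the constant polynomial 1
    while e:
        if e & 1:
            result = poly_mul_mod(result, base, r, n)
        e >>= 1
        if e:
            base = poly_mul_mod(base, base, r, n)
    return result

def binomial_ok(a, n, r):
    """Does (X + a)^n == X^(n mod r) + a hold modulo (X^r - 1, n)?"""
    lhs = poly_pow_mod([a % n, 1], n, r, n)
    rhs = [0] * r
    rhs[n % r] = 1
    rhs[0] = (rhs[0] + a) % n
    return lhs == rhs

def aks_prime(n):
    if n < 2:
        return False
    if is_perfect_power(n):                 # Step 1
        return False
    r = compute_r(n)                        # Step 2
    for a in range(1, r + 1):               # Step 3
        if 1 < math.gcd(a, n) < n:
            return False
    if n <= r:                              # Step 4
        return True
    limit = math.floor(math.sqrt(phi(r)) * math.log2(n))
    for a in range(1, limit + 1):           # Step 5
        if not binomial_ok(a, n, r):
            return False
    return True                             # Step 6
```

Each early return corresponds to one branch of the nested if chain in aks-prime?; the structure is flat only because Python allows leaving the function from inside a loop.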

The binomial test is provided by an auxiliary function:

(define (binomial-test? a r n)
  (not (equal? (poly-power-mod (list 1 a) n r n)
               (append (list 1)
                       (make-list (- r 1) 0)
                       (list a)))))
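For reference, the congruence at the heart of Step 5 is shown below. For prime n it holds for every a, and reducing modulo X^r - 1 turns X^n into X^(n mod r), so the right-hand side is a polynomial of degree less than r. The small case n = 4, r = 3, a = 1 is an illustration only (Step 1 would already reject 4 as a perfect power); it shows a composite failing the check.

```latex
\begin{align*}
(X + a)^n &\equiv X^n + a \equiv X^{n \bmod r} + a \pmod{X^r - 1,\ n} \\[4pt]
(X + 1)^4 &= X^4 + 4X^3 + 6X^2 + 4X + 1 \\
          &\equiv 6X^2 + 5X + 5 \pmod{X^3 - 1} \\
          &\equiv 2X^2 + X + 1 \pmod{X^3 - 1,\ 4} \\
          &\not\equiv X^{4 \bmod 3} + 1 = X + 1
\end{align*}
```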

Testing with n = 89 is quick because r = 191 is greater than n, so after the gcd checks of Step 3 the algorithm stops at Step 4, which amounts to trial division. Proving the primality of 887 takes longer, about 13.5 seconds of CPU time (nearly a minute of elapsed real time), because r = 389 is less than n, so all of the polynomial modular exponentiation tests must be performed:

> (compute-r 89)
191
> (aks-prime? 89)
#t
> (compute-r 887)
389
> (time (aks-prime? 887))
(time (aks-prime? 887))
    687 collections
    13506 ms elapsed cpu time, including 91 ms collecting
    58199 ms elapsed real time, including 345 ms collecting
    2894145264 bytes allocated, including 2894646840 bytes reclaimed
#t
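The difference between the two runs is easy to quantify. Taking the r values printed above, this sketch (binomial_tests_needed is a hypothetical helper, and phi here simply counts residues coprime to m) computes how many values of a the Step 5 loop of aks-prime? checks, using the same bound sqrt(phi(r)) · log2 n as the Scheme code:

```python
import math

def phi(m):
    """Euler's totient by counting 1 <= i <= m with gcd(i, m) == 1."""
    return sum(1 for i in range(1, m + 1) if math.gcd(i, m) == 1)

def binomial_tests_needed(n, r):
    """Number of a values Step 5 checks; zero when n <= r, since Step 4 returns first."""
    if n <= r:
        return 0
    return math.floor(math.sqrt(phi(r)) * math.log2(n))

print(binomial_tests_needed(89, 191))   # 0
print(binomial_tests_needed(887, 389))  # 192
```

Since 389 is prime, phi(389) = 388, and sqrt(388) · log2 887 ≈ 19.70 · 9.79 ≈ 192.9, so n = 887 needs 192 polynomial exponentiations with exponent 887 on polynomials of up to 389 coefficients, which is consistent with its much longer running time.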

We re-used much code from other sources. The functions split, make-list, square, log2, expm and ilog come from the Standard Prelude; td-prime?, uniq-factors, prime-power? and phi are variants of functions that we have seen in previous exercises; and ord, poly-mult-mod and poly-power-mod come from Part 1 of the AKS exercise. You can see all of the code assembled at http://programmingpraxis.codepad.org/6ZHrsEmx.

Our version of the AKS primality prover has running time O(log^12 n), and we made no attempt at tuning the algorithm: for instance, we used the naive O(r^2) polynomial multiplication algorithm instead of a faster O(r log r) algorithm, and the compute-r function tries every r instead of only prime r. There are better versions of the AKS algorithm, and much tuning is possible; Daniel Bernstein has an O(log^4 n) variant that he claims is two million times faster than the original. Even so, the AKS algorithm is extremely slow, and it is not used for serious primality proving.
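One way to avoid the quadratic coefficient-by-coefficient loop, sketched here in Python as an illustration rather than the method the text has in mind, is Kronecker substitution: pack each polynomial into one large integer with enough bytes per coefficient that no carry crosses slots, multiply the two integers once, then unpack and fold each exponent k >= r onto k mod r. CPython multiplies large integers with Karatsuba's method, so this is subquadratic but not the O(r log r) of FFT-based methods. The function name is hypothetical, and it assumes coefficients already reduced into range(n).

```python
def poly_mul_mod_kronecker(a, b, r, n):
    """Multiply coefficient lists a and b (lowest degree first, entries in
    range(n), n >= 2) modulo (X^r - 1, n) by Kronecker substitution."""
    if not a or not b:
        return [0] * r
    # Every coefficient of the exact product is a sum of at most
    # min(len(a), len(b)) terms, each at most (n - 1)**2; give each
    # coefficient a whole number of bytes that can hold that bound, so no
    # carry spills into the neighbouring slot.
    bound = min(len(a), len(b)) * (n - 1) ** 2
    width = (bound.bit_length() + 7) // 8 or 1     # bytes per coefficient

    def pack(p):
        # Evaluate p at X = 256**width by laying coefficients side by side.
        return int.from_bytes(b"".join(c.to_bytes(width, "little") for c in p), "little")

    length = len(a) + len(b) - 1                   # number of product coefficients
    raw = (pack(a) * pack(b)).to_bytes(length * width, "little")
    res = [0] * r
    for k in range(length):
        coeff = int.from_bytes(raw[k * width:(k + 1) * width], "little")
        res[k % r] += coeff                        # fold X^k onto X^(k mod r)
    return [c % n for c in res]
```

For example, poly_mul_mod_kronecker([1, 1, 1], [1, 1, 1], 3, 10) returns [3, 3, 3]: the exact product 1 + 2X + 3X^2 + 2X^3 + X^4 folds X^3 onto 1 and X^4 onto X.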


2 Responses to “AKS Primality Prover, Part 2”

  1. Paul said

    A version in Python. The multiplication of the polynomials and the division by
    X^r – 1 (keeping the remainder) is very slow. The division moves all amplitudes
    to position %r (the amplitude of X^k goes to X^(k%r)). I made a faster version
    that uses the numpy library to do the polynomial multiplication. Then all
    amplitudes for k >= r are moved to the range 0 to r-1.
    For my original version the time needed for n=887 was 51 seconds. This version
    needs 0.14 seconds.

    import math
    import itertools as IT
    import numpy as NP

    def ordr(r, n):
        # multiplicative order of n modulo r (gcd(n, r) must be 1); the order
        # may be 1 or 2, so the search starts at k = 1
        for k in IT.count(1):
            if pow(n, k, r) == 1:
                return k

    def isqrt(x):
        if x < 0:
            raise ValueError('square root not defined for negative numbers')
        n = int(x)
        if n == 0:
            return 0
        a, b = divmod(n.bit_length(), 2)
        x = 2 ** (a + b)
        while True:
            y = (x + n // x) // 2
            if y >= x:
                return x
            x = y

    def mmultn(a, b, r, n):
        """ Dividing by X^r - 1 is equivalent to shifting the amplitude from
            position k to k - r
            a and b are vectors of length r maximum
            convolve them (equivalent to polynomial mult) and add all amplitudes
            with exponent k >= r to exponent k - r
            After the multiplication all amplitudes are taken %n
        """
        # Before the final % n a coefficient can reach r*(n-1)**2, which must
        # fit in int64; beyond that NumPy silently overflows.
        res = NP.zeros(2 * r, dtype=NP.int64)
        res[:len(a)+len(b)-1] = NP.convolve(a, b)
        res = res[:-r] + res[-r:]
        return res % n

    def powmodn(pn, n, r, m):
        res = [1]
        while n:
            if n & 1:
                res = mmultn(res, pn, r, m)
            n //= 2
            if n:
                pn = mmultn(pn, pn, r, m)
        return res

    def testan(a, n, r):
        pp = powmodn([a, 1], n, r, n)
        pp[n%r] = (pp[n%r] - 1 ) % n # subtract X^n
        pp[0] = (pp[0] - a) % n      # subtract a
        return not any(pp)

    def phi(n):
        return sum(math.gcd(i, n) == 1 for i in range(1, n))

    def aks(n):
        # step 1: reject perfect powers a**b with b >= 2
        for a in range(2, isqrt(n) + 1):
            for b in range(2, n):
                t = a ** b
                if t == n:
                    return False
                if t > n:
                    break
        logn = math.log(n, 2)
        logn2 = logn ** 2
        # step 2: smallest r coprime to n with ord_r(n) > log2(n)**2
        for r in IT.count(3):
            if math.gcd(r, n) == 1 and ordr(r, n) > logn2:
                break
        for a in range(2, r + 1):
            if 1 < math.gcd(a, n) < n:
                return False
        if n <= r:
            return True
        # step 5 checks a = 1 .. floor(sqrt(phi(r)) * log2 n), inclusive
        for a in range(1, int(math.sqrt(phi(r)) * logn) + 1):
            if not testan(a, n, r):
                return False
        return True
    
  2. danaj said

    Paul, great job writing very straightforward code, and truly amazing speed out of Python by using numpy’s convolve. However, by putting off the modulo until after the convolve and fold, it does restrict the usable range, since the unreduced coefficients must fit in 64-bit integers. For example, aks(190000003) returns False on my 64-bit machine. As I mentioned before, for smaller numbers it’s rocking fast and (more important) very readable complete code.
