Adi Shamir’s Threshold Scheme

June 17, 2011

We first import two functions from the random module:

from random import randrange, sample

These will furnish us with the random integers we require: randrange(1, p) returns a uniformly distributed random integer in the half-open interval [1, p), while sample(xrange(1, p), n) selects n distinct integers from [1, p).
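The listings in this post are Python 2. In Python 3, xrange no longer exists, but sample accepts a range object directly. This quick sketch checks the two properties the scheme relies on. The small values of p and n are illustrative and do not come from the exercise:

```python
from random import randrange, sample

p, n = 23, 5                    # small illustrative values, not from the exercise

r = randrange(1, p)             # one integer with 1 <= r < p
xs = sample(range(1, p), n)     # n distinct integers drawn from [1, p)

assert 1 <= r < p
assert len(xs) == n and len(set(xs)) == n     # no repeats
assert all(1 <= x < p for x in xs)            # never 0, never p
```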

We evaluate our polynomial with a modified version of Horner's method:

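def horner_mod(coeffs, mod):
    # coeffs[i] is the coefficient of x**i; Horner's rule consumes them from
    # the highest degree down, reducing modulo mod at every step
    return lambda x: (x,
        reduce(lambda a, b: (a * x + b) % mod, reversed(coeffs), 0))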

Given a list of coefficients (lowest degree first) and a modulus, this function returns a function that maps x to the pair (x, y), where y is the value at x, modulo mod, of the polynomial whose coefficients are listed in coeffs.
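Here is the same idea as a Python 3 sketch, where reduce must be imported from functools. Horner's rule rewrites c0 + c1*x + c2*x**2 as c0 + x*(c1 + x*c2), so the coefficients are consumed highest degree first. Reducing after every multiply-add keeps intermediate values below mod. The polynomial and modulus in the example are arbitrary:

```python
from functools import reduce

def horner_mod(coeffs, mod):
    # coeffs[i] is the coefficient of x**i. Horner's rule starts from the
    # highest-degree coefficient, so the list is walked in reverse, and the
    # running value is reduced modulo mod after every multiply-add.
    return lambda x: (x, reduce(lambda a, b: (a * x + b) % mod,
                                reversed(coeffs), 0))

# P(x) = 3 + 2x + x^2 over the integers modulo 7
f = horner_mod([3, 2, 1], 7)
print(f(2))   # (2, 4): P(2) = 3 + 4 + 4 = 11, and 11 mod 7 = 4
```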

From there, Shamir's threshold scheme involves building our list of coefficients (the first, or really the zeroth, of which is the secret S), then evaluating the polynomial we've built at n distinct nonzero values of x and returning the (x, y) pairs:

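def shamir_threshold(S, k, n, p):
    coeffs = [S]                                            # constant term is the secret
    coeffs.extend(randrange(1, p) for _ in xrange(k - 1))   # k - 1 random coefficients
    # evaluate the polynomial at n distinct nonzero x values
    return map(horner_mod(coeffs, p), sample(xrange(1, p), n))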
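Below is a Python 3 sketch of the same function. Python 3's map is lazy, so the result is wrapped in list(...), and range replaces xrange. Because reduce moved to functools, horner_mod is repeated here so the example is self-contained. The values S = 42, k = 3, n = 6 and p = 101 (a prime) are illustrative only:

```python
from functools import reduce
from random import randrange, sample

def horner_mod(coeffs, mod):
    # Maps x to (x, P(x) mod mod), where coeffs[i] multiplies x**i.
    return lambda x: (x, reduce(lambda a, b: (a * x + b) % mod,
                                reversed(coeffs), 0))

def shamir_threshold(S, k, n, p):
    # coeffs[0] is the secret; the other k - 1 coefficients are random and
    # nonzero, giving a polynomial of degree k - 1.
    coeffs = [S] + [randrange(1, p) for _ in range(k - 1)]
    # Evaluate at n distinct, nonzero x values (x = 0 would give away S).
    return list(map(horner_mod(coeffs, p), sample(range(1, p), n)))

S, k, n, p = 42, 3, 6, 101
shares = shamir_threshold(S, k, n, p)
print(shares)   # six (x, y) pairs; the values change from run to run
```

Any 3 of the 6 pairs determine the degree 2 polynomial, and therefore S.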

Finally, we use Lagrange interpolation to recover the constant term S from a list of k or more (x, y) pairs. For good measure, we first check that we have enough pairs, then use just the first k to find S:

def interp_const(xy_pairs, k, p):
    # Recover S = P(0) from the first k pairs by Lagrange interpolation mod p
    assert len(xy_pairs) >= k, "Not enough points for interpolation"
    x = lambda i: xy_pairs[i][0]                # i-th x value
    y = lambda i: xy_pairs[i][1]                # i-th y value
    return sum(y(i) * prod(x(j) * mod_inv(x(j) - x(i), p) % p
                           for j in xrange(k) if j != i)
               for i in xrange(k)) % p

We first define inline functions that allow us to pick out the ith x and y values from our list of pairs, then use them to write the equation for S given on the previous page.
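For reference, this is the expression the code computes: the Lagrange interpolating polynomial evaluated at x = 0. The basis polynomial for point i is the product over j ≠ i of (x - x_j)/(x_i - x_j). At x = 0 it equals the product of x_j/(x_j - x_i). So, using the first k pairs (x_0, y_0), ..., (x_{k-1}, y_{k-1}):

```latex
S = P(0) \equiv \sum_{i=0}^{k-1} y_i \prod_{\substack{0 \le j \le k-1 \\ j \ne i}} \frac{x_j}{x_j - x_i} \pmod{p}
```

Division here means multiplying by an inverse modulo p, which is the job of mod_inv(x(j) - x(i), p). That inverse exists because p is prime and the x values are distinct.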

As for testing, let's encode the word "PRAXIS" as an integer by treating its characters as digits in base 36 (0-9, then A-Z). We'll take p to be the next prime after S. The choice is arbitrary: p just needs to be a prime larger than S (and larger than n, so that there are enough distinct nonzero x values). To keep things manageable, let's take n to be 20 and k to be 5, so we construct a degree 4 polynomial and hand out 20 (x, y) pairs.

if __name__ == "__main__":
    from pprint import pprint                   # Pretty printing
    S = int("PRAXIS", 36); print S              # Prints 1557514036
    n, k, p = 20, 5, 1557514061                 # p is the next prime after S
    xy_pairs = shamir_threshold(S, k, n, p)
    pprint(xy_pairs)                            # Prints all 20 (x, y) pairs
    print interp_const(xy_pairs, k, p)          # Should print 1557514036

which, on one run, pretty-prints the following pairs (they depend on the random polynomial and x values chosen):

[(697286162, 445615394L),
 (471866046, 757728985L),
 (112045393, 1132162792L),
 (397324764, 486286231L),
 (135120894, 1142009194L),
 (508637994, 1556915744L),
 (488738532, 834401917L),
 (1369874096, 1345716686L),
 (91597754, 487556032L),
 (970187759, 341284274L),
 (1102805729, 224871713L),
 (245100902, 1306749801L),
 (413372256, 568733054L),
 (1218343037, 63534734L),
 (442535975, 1060000953L),
 (1173207231, 400308586L),
 (515043844, 141960722L),
 (1162691976, 374990038L),
 (73252341, 785232686L),
 (934671161, 486917357L)]

(Ignore the L on the end of the numbers on the right; this just signifies that Python used long integer arithmetic to find them. Python 3.x, the newer iteration of Python, removes the distinction between integers and long integers.)
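The test encodes "PRAXIS" with int(s, 36), but Python has no built-in for the reverse direction, which you need to turn a recovered S back into text. to_base36 is a hypothetical helper name. This minimal sketch assumes a non-negative integer and emits upper-case letters:

```python
import string

# Digit alphabet for base 36: 0-9 followed by A-Z.
DIGITS = string.digits + string.ascii_uppercase

def to_base36(n):
    # Inverse of int(s, 36) for non-negative n: peel off base-36 digits
    # from least to most significant with divmod.
    out = ""
    while n:
        n, r = divmod(n, 36)
        out = DIGITS[r] + out
    return out or "0"

S = int("PRAXIS", 36)
print(S)              # 1557514036
print(to_base36(S))   # PRAXIS
```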

Our solution used mod_inv from a previous exercise; we also defined the function prod, analogous to sum from the Standard Prelude. You can run the full program at
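The mod_inv from the previous exercise is not reproduced here, so the version below is an assumption: a standard extended-Euclid inverse. One detail matters for interp_const. It passes x(j) - x(i), which is often negative, so the inverse must accept negative arguments; this sketch reduces modulo m first. prod is a reduce-based product. Both are hypothetical reconstructions, ported to Python 3 along with interp_const and checked on a small polynomial chosen for the example:

```python
from functools import reduce

def mod_inv(a, m):
    # Hypothetical stand-in for the earlier exercise's mod_inv: extended
    # Euclid on (a mod m, m), so negative a is fine. Requires gcd(a, m) == 1.
    old_r, r = a % m, m
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    if old_r != 1:
        raise ValueError("no inverse: a and m are not coprime")
    return old_s % m

def prod(xs):
    # Product of an iterable, analogous to sum (1 for an empty iterable).
    return reduce(lambda a, b: a * b, xs, 1)

def interp_const(xy_pairs, k, p):
    # Lagrange interpolation at x = 0 using the first k pairs, modulo p.
    assert len(xy_pairs) >= k, "Not enough points for interpolation"
    x = lambda i: xy_pairs[i][0]
    y = lambda i: xy_pairs[i][1]
    return sum(y(i) * prod(x(j) * mod_inv(x(j) - x(i), p) % p
                           for j in range(k) if j != i)
               for i in range(k)) % p

# P(x) = 42 + 5x + 7x^2 modulo the prime 101, sampled at x = 1, 2, 3
points = [(1, 54), (2, 80), (3, 19)]
print(interp_const(points, 3, 101))   # 42
```

On Python 3.8 or later, the built-in pow(a, -1, m) computes the same modular inverse, and math.prod can replace the reduce-based prod.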



5 Responses to “Adi Shamir’s Threshold Scheme”

  1. Graham said

    I also wrote a version in Common Lisp, since I’m trying to learn a new language. If there are any CL gurus in the audience, I’d appreciate any feedback!

  2. Graham said

    For good measure, here it is in C and Haskell.

  3. razvan said

    Here’s my solution in C:

    #include <stdio.h>
    #include <stdlib.h>
    #include <assert.h>
    #include <time.h>

    #define max(m, n) ((m) > (n) ? (m) : (n))

    typedef struct pair {
        int x;
        long y;
    } pair;

    /* x raised to the n-th power by repeated squaring */
    long ipow(int x, int n) {
        if(n == 0) return 1;
        if(n % 2 == 0) return ipow(x * x, n / 2);
        return x * ipow(x, n - 1);
    }

    /* Extended Euclid: sets *d = gcd(a, b) and a * (*x) + b * (*y) = *d */
    void euclid(int a, int b, int *d, int *x, int *y) {
        if (b == 0) {
            *d = a;
            *x = 1;
            *y = 0;
        } else {
            int x0, y0;
            euclid(b, a % b, d, &x0, &y0);
            *x = y0;
            *y = x0 - (a / b) * y0;
        }
    }

    /* Evaluates the degree-n polynomial with coefficients P[0..n] at x */
    long eval(int* P, int x, int n) {
        assert(P != NULL);
        assert(n >= 0);
        long res = 0;
        int i;
        for(i=0; i<=n; i++)
            res += P[i] * ipow(x, i);
        return res;
    }

    /* Splits S into n shares, any k of which recover S (p prime, p > max(S, n)) */
    pair* encrypt(int S, int n, int k, int p) {
        assert(k <= n);
        assert(p > max(S, n));
        int* P = malloc(k * sizeof(int));
        if(P == NULL)
            return NULL;
        int i;
        P[0] = S;
        for(i=1; i<=k-1; i++)
            P[i] = rand() % p;
        pair* out = malloc(n * sizeof(pair));
        if(out == NULL) {
            free(P);
            return NULL;
        }
        for(i=0; i<n; i++) {
            /* nonzero x: a share at x = 0 would be (0, S) itself */
            out[i].x = 1 + rand() % (p - 1);
            int unique = 0;
            while(!unique) {
                unique = 1;
                int j;
                for(j=0; j<i; j++)
                    if(out[j].x == out[i].x) {
                        unique = 0;
                        out[i].x = 1 + rand() % (p - 1);
                    }
            }
            out[i].y = eval(P, out[i].x, k-1) % p;
        }
        free(P);
        return out;
    }

    /* Recovers S = P(0) from the first k shares by Lagrange interpolation mod p */
    int decrypt(pair* pairs, int k, int p) {
        assert(p > k);
        assert(pairs != NULL);
        int S = 0;
        int i, j;
        for(i=0; i<k; i++) {
            int prod = pairs[i].y;
            for(j=0; j<k; j++) {
                if(j == i) continue;
                prod *= pairs[j].x;
                int d, x, y;
                /* y is the inverse of (x_j - x_i) modulo p */
                euclid(p, (pairs[j].x - pairs[i].x + p) % p, &d, &x, &y);
                prod *= (y + p) % p;
                prod %= p;
            }
            S += prod;
        }
        S %= p;
        return S;
    }

    int main(int argc, char **argv) {
        int S = 17, n = 3, k = 2, p = 23;
        srand(time(NULL));
        pair* pairs = encrypt(S, n, k, p);
        if(pairs == NULL)
            return 1;
        int i;
        for(i=0; i<n; i++)
            printf("(%d, %ld)\n", pairs[i].x, pairs[i].y);
        printf("\nS = %d", decrypt(pairs, n, p));
        free(pairs);
        return 0;
    }

  4. razvan said

    Err, that should rather be

    printf("\nS = %d", decrypt(pairs, k + 1, p));

  5. Graham said

    I changed GitHub accounts, so the gists I linked to above are dead (apologies). Here are new ones:
    Common Lisp
    C, version 1
    C, version 2 (improved with a little help from Razvan’s solution).
