### June 17, 2011

We first import two functions from the `random` module:

`from random import randrange, sample`

These will furnish us with the random integers we require: `randrange(1, p)` returns a uniformly distributed random integer in the half-open interval [1, p), while `sample(xrange(1, p), n)` selects n distinct integers from [1, p).
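As a quick illustration of those two calls, here is a small sketch written for Python 3 (an assumption; the post itself targets Python 2, where `xrange` plays the role of `range`). The values `p = 23` and `n = 5` are arbitrary and chosen only for the demonstration.

```python
from random import randrange, sample

p, n = 23, 5  # illustrative values, not taken from the post

# randrange(1, p) draws from 1, 2, ..., p - 1; the upper bound p is excluded.
r = randrange(1, p)

# sample(range(1, p), n) draws n distinct values from the same range.
xs = sample(range(1, p), n)

print(r, xs)
```

Excluding 0 from the sampled x values matters here: a share at x = 0 would be the constant term of the polynomial, which is the secret itself.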

We build our polynomial via a modified version of Horner's Method:

```
def horner_mod(coeffs, mod):
    return lambda x: (x,
        reduce(lambda a, b: (a * x + b) % mod, reversed(coeffs), 0))
```

This function, given a list of coefficients (constant term first) and a modulus, returns the function that maps x to the pair (x, y), where y is the value at x of the polynomial whose coefficients are listed in `coeffs`, taken modulo `mod`.
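To make the evaluation order concrete, the following sketch (Python 3, so `reduce` is imported from `functools`) compares a Horner-style fold against direct term-by-term evaluation. The names `horner_eval` and `naive_eval` and the sample coefficients are illustrative and not part of the original solution.

```python
from functools import reduce

def horner_eval(coeffs, x, mod):
    """Evaluate coeffs[0] + coeffs[1]*x + ... modulo mod using Horner's rule."""
    # Fold from the highest-degree coefficient down, reducing at every step.
    return reduce(lambda acc, c: (acc * x + c) % mod, reversed(coeffs), 0)

def naive_eval(coeffs, x, mod):
    """Evaluate the same polynomial by summing each term separately."""
    return sum(c * x ** i for i, c in enumerate(coeffs)) % mod

coeffs = [17, 5, 3]                  # 17 + 5x + 3x^2
print(horner_eval(coeffs, 2, 23))    # 17 + 10 + 12 = 39, and 39 mod 23 = 16
print(naive_eval(coeffs, 2, 23))     # same value
```

Reducing at every step keeps the intermediate values below `mod * x`, which is the practical reason to fold the modulus into Horner's rule.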

From there, Shamir's Threshold scheme involves building our list of coefficients (the first — really zeroth — of which is the secret S), then evaluating the polynomial we've built n times and returning the (x, y) pairs:

```
def shamir_threshold(S, k, n, p):
    coeffs = [S]
    coeffs.extend(randrange(1, p) for _ in xrange(k - 1))
    return map(horner_mod(coeffs, p), sample(xrange(1, p), n))
```

Finally, we use Lagrange Interpolation to find the constant term S given a list of k or more (x, y) pairs. For good measure, we first check that we have enough pairs, then use just the first k to find S:

```
def interp_const(xy_pairs, k, p):
    assert len(xy_pairs) >= k, "Not enough points for interpolation"
    x = lambda i: xy_pairs[i][0]
    y = lambda i: xy_pairs[i][1]
    return sum(y(i) * prod(x(j) * mod_inv(x(j) - x(i), p) % p for j in xrange(k)
        if j != i) for i in xrange(k)) % p
```

We first define inline functions that allow us to pick out the ith x and y values from our list of pairs, then use them to write the equation for S given on the previous page.
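Written out, the sum that `interp_const` computes is the Lagrange interpolating polynomial evaluated at 0, using the first k pairs $(x_i, y_i)$ and taking inverses modulo p. The formula below is reconstructed from the code on this page, not copied from the previous page, so the notation there may differ:

$$S \equiv \sum_{i=0}^{k-1} y_i \prod_{\substack{j=0 \\ j \neq i}}^{k-1} x_j \,(x_j - x_i)^{-1} \pmod{p}$$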

As for testing, let's encode the word "PRAXIS" as an integer by treating digits and letters as base-36 digits. We'll take p to be the next prime after S; this choice is arbitrary, since any prime bigger than S would do. To keep things manageable, let's take n to be 20 and k to be 5, so we're constructing a degree 4 polynomial and handing out 20 (x, y) pairs.
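The base-36 encoding can be checked in both directions. Python's built-in `int(text, 36)` handles the conversion from text to integer; the inverse helper `to_base36` below is a hypothetical name introduced for this sketch (Python 3), not something from the original solution.

```python
import string

DIGITS = string.digits + string.ascii_uppercase  # "0"-"9" then "A"-"Z": 36 symbols

def to_base36(n):
    """Convert a non-negative integer to its base-36 string."""
    if n == 0:
        return "0"
    out = []
    while n > 0:
        n, r = divmod(n, 36)   # peel off the least significant digit
        out.append(DIGITS[r])
    return "".join(reversed(out))

S = int("PRAXIS", 36)
print(S)             # 1557514036
print(to_base36(S))  # PRAXIS
```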

```
if __name__ == "__main__":
    from pprint import pprint                   # Pretty printing
    S = int("PRAXIS", 36); print S              # Prints 1557514036
    n, k, p = 20, 5, 1557514061                 # p is the next prime after S
    xy_pairs = shamir_threshold(S, k, n, p)
    pprint(xy_pairs)                            # Prints all 20 (x, y) pairs
    print interp_const(xy_pairs, k, p)          # Should print 1557514036
```

which gives the following output (on one run; the pairs depend on the random polynomial created):

```
1557514036
[(697286162, 445615394L),
 (471866046, 757728985L),
 (112045393, 1132162792L),
 (397324764, 486286231L),
 (135120894, 1142009194L),
 (508637994, 1556915744L),
 (488738532, 834401917L),
 (1369874096, 1345716686L),
 (91597754, 487556032L),
 (970187759, 341284274L),
 (1102805729, 224871713L),
 (245100902, 1306749801L),
 (413372256, 568733054L),
 (1218343037, 63534734L),
 (442535975, 1060000953L),
 (1173207231, 400308586L),
 (515043844, 141960722L),
 (1162691976, 374990038L),
 (73252341, 785232686L),
 (934671161, 486917357L)]
1557514036
```

(Ignore the L on the end of the numbers on the right; this just signifies that Python used long integer arithmetic to find them. Python 3.x, the newer iteration of Python, removes the distinction between integers and long integers.)

Our solution used `mod_inv` from a previous exercise; we also defined the function `prod`, analogous to `sum` from the Standard Prelude. You can run the full program at http://codepad.org/NvYVdZFT.
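The definitions of `mod_inv` and `prod` are not shown on this page. The following is one possible sketch of them in Python 3, and not the versions from the earlier exercise or the Standard Prelude: `mod_inv` here relies on Fermat's little theorem and therefore assumes that p is prime, and `prod` is a plain fold with `operator.mul`. A Python 3 transcription of `interp_const` is included only so the helpers can be exercised on a tiny example.

```python
from functools import reduce
from operator import mul

def mod_inv(a, p):
    """Inverse of a modulo p via a^(p-2) mod p.

    Assumes p is prime and a is not a multiple of p.
    """
    return pow(a % p, p - 2, p)

def prod(iterable):
    """Product of the items in iterable; 1 for an empty iterable."""
    return reduce(mul, iterable, 1)

def interp_const(xy_pairs, k, p):
    """Recover the constant term from the first k (x, y) pairs."""
    assert len(xy_pairs) >= k, "Not enough points for interpolation"
    xs = [x for x, _ in xy_pairs[:k]]
    ys = [y for _, y in xy_pairs[:k]]
    return sum(
        ys[i] * prod(xs[j] * mod_inv(xs[j] - xs[i], p) % p
                     for j in range(k) if j != i)
        for i in range(k)
    ) % p

# Two points on 17 + 5x (mod 23), so k = 2 and the secret is 17.
pairs = [(1, 22), (2, 4)]
print(interp_const(pairs, 2, 23))  # 17
```

An extended-Euclid inverse would also work and does not need p to be prime; the Fermat version is used here only because it fits in one line.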

### 5 Responses to “Adi Shamir’s Threshold Scheme”

1. Graham said

I also wrote a version in Common Lisp, since I’m trying to learn a new language. If there are any CL gurus in the audience, I’d appreciate any feedback!

2. Graham said

For good measure, here it is in C and Haskell.

3. razvan said

Here’s my solution in C:

```
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <time.h>
#define max(m, n) ((m) > (n) ? (m) : (n))

typedef struct pair
{
    int x;
    long y;
} pair;

long ipow(int x, int n)
{
    if(n == 0) return 1;
    if(n % 2 == 0) return ipow(x * x, n / 2);
    return x * ipow(x, n - 1);
}

void euclid(int a, int b, int *d, int *x, int *y)
{
    if (b == 0)
    {
        *d = a;
        *x = 1;
        *y = 0;
    }
    else {
        int x0, y0;
        euclid(b, a % b, d, &x0, &y0);
        *x = y0;
        *y = x0 - (a / b) * y0;
    }
}

long eval(int* P, int x, int n)
{
    assert(P != NULL);
    assert(n >= 0);
    int res = 0;
    int i;
    for(i=0; i<=n; i++)
        res += P[i] * ipow(x, i);
    return res;
}

pair* encrypt(int S, int n, int k, int p)
{
    assert(k <= n);
    assert(p > max(S, n));
    int* P = malloc(k * sizeof(int));
    if(P == NULL)
        exit(1);
    int i;
    P[0] = S;
    for(i=1; i<=k-1; i++)
        P[i] = rand() % p;
    pair* out = malloc(n * sizeof(pair));
    if(out == NULL)
        exit(1);
    for(i=0; i<n; i++)
    {
        out[i].x = rand() % p;
        int unique = 0;
        while(!unique)
        {
            unique = 1;
            int j;
            for(j=0; j<i; j++)
                if(out[j].x == out[i].x)
                {
                    unique = 0;
                    out[i].x = rand() % p;
                    break;
                }
        }
        out[i].y = eval(P, out[i].x, k-1) % p;
    }
    return out;
}

int decrypt(pair* pairs, int k, int p)
{
    assert(p > k);
    assert(pairs != NULL);
    int S = 0;
    int i, j;
    for(i=0; i<k; i++)
    {
        int prod = pairs[i].y;
        for(j=0; j<k; j++)
        {
            if(j == i) continue;
            prod *= pairs[j].x;
            int d, x, y;
            euclid(p, (pairs[j].x - pairs[i].x + p) % p, &d, &x, &y);
            prod *= (y + p) % p;
            prod %= p;
        }
        S += prod;
    }
    S %= p;
    return S;
}

int main(int argc, char **argv)
{
    srand(time(NULL));
    int S = 17, n = 3, k = 2, p = 23;
    pair* pairs = encrypt(S, n, k, p);
    int i;
    for(i=0; i<n; i++)
        printf("(%d, %ld)\n", pairs[i].x, pairs[i].y);
    printf("\nS = %d", decrypt(pairs, n, p));
    free(pairs);
    return 0;
}
```
4. razvan said

Err, that should rather be

`printf("\nS = %d", decrypt(pairs, k + 1, p));`
5. I changed github accounts, so the gists I linked to above are dead (apologies). Here are new ones:
Common Lisp