## Sum of squares of two largest of three values

### March 16, 2012

Before you read further, check your function to see if it returns the proper response of 25 for the inputs 3, 3, and 4.

This problem recently came through on Stack Overflow, and I was reminded of a great discussion of the problem on the Usenet group `comp.lang.scheme` several years ago. The solution that the authors of SICP wanted is probably something like this (we call our function `f` because `sum-of-squares-of-max-2-of-3` is a little bit unwieldy):

```
(define (f x y z)
  (if (< x y)
      (if (< x z)
          (+ (* y y) (* z z))
          (+ (* x x) (* y y)))
      (if (< y z)
          (+ (* x x) (* z z))
          (+ (* x x) (* y y)))))
```

An alternate solution uses features of Scheme that hadn’t been introduced at that point in SICP; it calls itself recursively, rotating the arguments until the first argument is the minimum, at which point it returns the sum of the squares of the other two:

```
(define (f x y z)
  (if (= x (min x y z))
      (+ (* y y) (* z z))
      (f y z x)))
```
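It may help to see why the rotation always stops: the minimum is one of the three arguments, so at most two rotations bring it to the front. The following is only an illustrative Python port of the same idea, not code from SICP or from the original thread:

```python
def f(x, y, z):
    # Rotate the arguments until the first one is a minimum,
    # then return the sum of the squares of the other two.
    if x == min(x, y, z):
        return y * y + z * z
    return f(y, z, x)


print(f(5, 4, 3))  # 41, reached after two rotations
print(f(3, 3, 4))  # 25, no rotation needed
```

Because the test is equality with the minimum rather than a strict comparison, a duplicated minimum is handled without a special case.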

The consensus answer in the thread on `comp.lang.scheme` was this solution:

```
(define (f x y z)
  (cond ((and (< x y) (< x z)) (+ (* y y) (* z z)))
        ((and (< y x) (< y z)) (+ (* x x) (* z z)))
        (else (+ (* x x) (* y y)))))
```

The idea was to make a three-way choice to determine which element is the minimum, then take the sum of the squares of the other two. But that fails on `(f 3 3 4)`. Do you see why?
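One way to see the failure: when the minimum appears in both of the first two positions, neither strict test succeeds, so the final clause squares both copies of the minimum. The sketch below is a Python transcription of that three-way choice next to a variant that uses `<=`. The names `f_strict` and `f_fixed` are made up for this illustration, and the `<=` variant is one possible repair, not the one proposed in the thread.

```python
def f_strict(x, y, z):
    # Direct port of the three-way choice: strict comparisons throughout.
    if x < y and x < z:
        return y * y + z * z
    if y < x and y < z:
        return x * x + z * z
    return x * x + y * y


def f_fixed(x, y, z):
    # Same three-way choice, but a tie for the minimum now counts.
    if x <= y and x <= z:
        return y * y + z * z
    if y <= x and y <= z:
        return x * x + z * z
    return x * x + y * y


print(f_strict(3, 3, 4))  # 18: falls through to the last clause
print(f_fixed(3, 3, 4))   # 25
```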

Since this problem is hard, it makes sense to have a proper test suite. There are six permutations of three different values, three permutations where the minimum value is duplicated, three permutations where the maximum value is duplicated, and one permutation where all three values are the same, a total of thirteen permutations. It’s not hard to test all thirteen:

```
(define (test f)
  (assert (f 3 4 5) 41)
  (assert (f 3 5 4) 41)
  (assert (f 4 3 5) 41)
  (assert (f 4 5 3) 41)
  (assert (f 5 3 4) 41)
  (assert (f 5 4 3) 41)
  (assert (f 3 3 4) 25)
  (assert (f 3 4 3) 25)
  (assert (f 4 3 3) 25)
  (assert (f 3 4 4) 32)
  (assert (f 4 3 4) 32)
  (assert (f 4 4 3) 32)
  (assert (f 3 3 3) 18))
```

When run with the third version of `f`, `test` shows the error:

```
> (test f)
failed assertion:
(f 3 3 4)
expected: 25
returned: 18
```

You can run the program at http://programmingpraxis.codepad.org/l2YWpGaf.
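The thirteen cases can also be generated rather than typed out. The following Python sketch is an assumption-laden illustration, not part of the original program: `reference`, `thirteen_cases` and `check` are hypothetical names, and the reference answer is computed by sorting.

```python
from itertools import permutations


def reference(x, y, z):
    # Sort, drop the smallest value, and square the remaining two.
    _, a, b = sorted((x, y, z))
    return a * a + b * b


def thirteen_cases():
    # Distinct orderings of: all different, minimum duplicated,
    # maximum duplicated, and all three equal.
    cases = set()
    for values in [(3, 4, 5), (3, 3, 4), (3, 4, 4), (3, 3, 3)]:
        cases.update(permutations(values))
    return sorted(cases)


def check(f):
    # Return the argument triples on which f disagrees with the reference.
    return [c for c in thirteen_cases() if f(*c) != reference(*c)]


print(len(thirteen_cases()))  # 13
```

Calling `check` on a candidate function returns an empty list when all thirteen cases agree with the sorted reference.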


### 31 Responses to “Sum of squares of two largest of three values”

1. rainer said
```
sum_of_squares: procedure
  parse arg a,b,c
  return (max(a,b)**2)+(max(min(a,b),c))**2
```
2. DGel said
```
sumSquares a b c = sq (max a b) + sq (max (min a b) c)
  where sq a = a*a
```
3. hannesty said
```
def two_largest_square_sum(a, b, c)
  [ a, b, c ].sort[1..2].map { |x| x * x }.inject(:+)
end
```
4. Jussi Piitulainen said
```
(define (med x y z) (max (min x y) (min x z) (min y z)))
(define (wev x y z) (+ (expt (med x y z) 2) (expt (max x y z) 2)))
```

Knuth says the median of x, y, z is “probably the most important ternary operation in the entire universe” (TAoCP volume 4 fascicle 0) so I take this opportunity to give it some visibility. Spread the word.
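For readers who do not read Scheme, here is a rough Python rendering of the same median trick, offered as a sketch only. The median is the largest of the three pairwise minimums, and the answer is the square of the median plus the square of the maximum. The function names simply mirror the Scheme ones.

```python
def med(x, y, z):
    # Median of three: the largest of the three pairwise minimums.
    return max(min(x, y), min(x, z), min(y, z))


def wev(x, y, z):
    # Sum of the squares of the median and the maximum.
    return med(x, y, z) ** 2 + max(x, y, z) ** 2


print(med(3, 3, 4), wev(3, 3, 4))  # 3 25
```

Since the median and the maximum are picked by value and never by position, duplicated inputs need no special handling.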

5. Johann Hibschman said
```
sumLargestSquares=: 3 : '+/*:}:\:~y'
```
6. Note that at the point where the exercise appears in SICP, the functions min and max haven’t been introduced.
In the context of SICP it must therefore be considered cheating to use min and max!

That said, the simplest expression to use is

x^2 + y^2 + z^2 - min(x,y,z)^2

7. Anonymous Coward said

`(define (wev2 x y z) (+ (expt x 2) (expt y 2) (expt z 2) (- (expt (min x y z) 2))))`

8. Mike said
```
def sum_square(x,y,z):
    return (x*x+y*y if x>z<y else
            z*z+x*x if z>y<x else
            y*y+z*z)
```

N.B. (x > z < y) is equivalent to (x > z and z < y)

9. philipgoh said
```
def sum_squared(*args):
    x = sorted(args, reverse=True)
    return x[0] * x[0] + x[1] * x[1]
```

It doesn’t do any parameter checking, but it should always return the sum of the two largest squares.

10. Pablo B said

(defun square (x) (* x x)) ; assumed helper: SQUARE is not part of standard Common Lisp
(defun med (x y z) (max (min x y) (min x z) (min y z)))
(defun sum-squares-of-max-2 (x y z)
(apply #'+ (mapcar #'square (list (max x y z) (med x y z)))))

11. Yet another “sort first” approach.

```
#include <utility>  // std::swap

int lrgsq(int a, int b, int c)
{
    if (a < b) std::swap(a, b);  // now a >= b
    if (b < c) std::swap(c, b);  // now b is the larger of the remaining two
    return a * a + b * b;
}
```

Is there a smart solution with no branching whatsoever? The one with subtracting the smallest square comes close, but min() uses branching internally.

12. Alex said

    def sumSquares(num1, num2, num3):
        numList = [num1,num2,num3]
        chosen1 = max(numList)
        numList.remove(chosen1)
        chosen2 = max(numList)
        return chosen1**2 + chosen2**2

13. Oh, I’ve just realized that you *can* write min() without branching! I hope I didn’t make any mistake here:

```
#include <stdlib.h>  /* abs */

int f(int a, int b, int c)
{
    return a*a+b*b+c*c-(c+((((b+((a-b)-abs(a-b))/2))-c)-abs(((b+((a-b)-abs(a-b))/2))-c))/2)*(c+((((b+((a-b)-abs(a-b))/2))-c)-abs(((b+((a-b)-abs(a-b))/2))-c))/2);
}
```
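The long expression is easier to check once the two-argument minimum is pulled out. The following is a Python sketch of the same identity, assuming integer inputs only; `min2` is a name invented here. The quantity `(a - b) - abs(a - b)` is `2 * (a - b)` when `a < b` and `0` otherwise, so halving it and adding `b` yields the smaller value.

```python
def min2(a, b):
    # b + ((a - b) - |a - b|) / 2 equals a when a < b, and b otherwise.
    # The bracketed term is always even, so integer division is exact.
    return b + ((a - b) - abs(a - b)) // 2


def f(a, b, c):
    # Sum of all three squares minus the square of the minimum.
    m = min2(min2(a, b), c)
    return a * a + b * b + c * c - m * m


print(f(3, 3, 4))  # 25
```

Whether `abs` itself compiles to a branch depends on the platform, so this only moves the question rather than settling it.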
14. […] solution, posted in comments, to the following problem needs some explanation. First of all, the task is to take three numbers as arguments and return the […]

15. ardnew said

Tomasz, I highly doubt it’s possible to avoid branching altogether, especially if your input is real numbers (not just integers). However, I took your approach and considered an integer-only solution:

```
//
// function f(a, b, c):
//     accepts three signed integers a, b, and c as parameters and returns
//     the sum of the two largest of these integers squared
//
unsigned long long f(const signed int a,
                     const signed int b,
                     const signed int c)
{
    signed int m;
    unsigned long long j = b, // wider types to avoid casting and overflow
                       k = c; //

    //
    // branchless minimum of two integers:
    //     min(x, y) = y ^ ((x ^ y) & -(x < y));
    //
    // performed twice to find minimum of three integers
    //
    m = b ^ ((a ^ b) & -(a < b));
    m = c ^ ((m ^ c) & -(m < c));

    //
    // select the two largest integers, j and k, from the three input integers
    // and return the sum of their squares. without branching!
    //
    // basically, we abuse the short-circuiting behavior of C's logical
    // operators to simulate a proper if-else control structure.
    //
    return (m == a || (j = a)|1 && m == b || (k = b)|1) * j * j + k * k;
}
```

The return statement could probably use some explanation. To avoid computing the square of all terms and subtracting off the minimum squared, only the two greater values are squared and summed.

This should help protect against overflow, and it should slightly increase performance.

The logic is all tied into a boolean expression that will always return 1. It’s important to note how the values of j and k are initialized at the beginning of the function.

The “(x = y)|1” expression is used to handle the case where y = 0. We still want the assignment to evaluate as true, so the bitwise “OR 1” is appended to the expression.

Using proper control structures, the return statement’s logic could be rewritten as follows:

```
j = b;
k = c;

if (m != a)
{
    j = a;

    if (m != b)
    {
        k = b;
    }
}

return j * j + k * k;
```

I’m not entirely convinced the compiler will not break up this mess and generate branching code. Maybe someone more familiar with this sort of thing has some insight?
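The masking step can be tried in isolation. Below is a hedged Python sketch of the same bit trick. It relies on Python treating `-(x < y)` as `-1` (all bits set) or `0`, and it pairs the trick with the plain if/else form of the selection logic shown above. The names `bit_min` and `sum_sq_two_largest` are invented for this illustration.

```python
def bit_min(x, y):
    # -(x < y) is -1 (all bits set) when x < y, else 0.
    # With the mask set, y ^ (x ^ y) collapses to x; otherwise y ^ 0 is y.
    return y ^ ((x ^ y) & -(x < y))


def sum_sq_two_largest(a, b, c):
    m = bit_min(bit_min(a, b), c)
    # Start by assuming a is the minimum, then correct the assumption.
    j, k = b, c
    if m != a:
        j = a
        if m != b:
            k = b
    return j * j + k * k


print(sum_sq_two_largest(3, 3, 4))    # 25
print(sum_sq_two_largest(-2, 0, -3))  # 4
```

This only demonstrates that the arithmetic is right. It says nothing about whether a C compiler would emit branch-free machine code for the original.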

16. ardnew said

And just for fun, here’s a golf’d version of basically the same algorithm. the assignment ordering is a little different since we don’t have the wide variable placeholders j and k. it’s also not quite as safe since the types are all implicitly declared “int”, so you can easily overflow.

However, it is valid C code (built using cygwin, gcc 4.3.4), and it works correctly assuming the absolute value of each input is less than sqrt(2^31-1)/2.

Branch-free algorithm for computing sum of squares of two largest of three values (with non-golf’d test code):

```
#include <stdio.h>

//
// lol wat
//
f(a,b,c){int m=b^((a^b)&-(a<b));m=c^((m^c)&-(m<c));return(m==a||m==b&&(b=a)|1||(c=b)|(b=a)|1)*b*b+c*c;}

//
// function test(a, b, c, r):
//     verifies the calculation of f(a, b, c) equals the expected value r
//
void test(int a, int b, int c, int r)
{
    int s = f(a, b, c);
    printf("\nmin(%d, %d, %d)\n%12s = %d\n%12s = %d%s",
           a, b, c, "expected", r, "returned", s, r == s ? "\n" : " <= BANANAS!\n");
}

//
// function main()
//     driver program for testing various inputs with known solutions
int main()
{
    test(3, 4, 5, 41);
    test(3, 5, 4, 41);
    test(4, 3, 5, 41);
    test(4, 5, 3, 41);
    test(5, 3, 4, 41);
    test(5, 4, 3, 41);
    test(3, 3, 4, 25);
    test(3, 4, 3, 25);
    test(4, 3, 3, 25);
    test(3, 4, 4, 32);
    test(4, 3, 4, 32);
    test(4, 4, 3, 32);
    test(3, 3, 3, 18);
    test(-2, 0, -3, 4);
    test(0, 2147483647, 2147483647, 9223372028264841218ULL);

    return 0;
}
```
17. Mike said

As long as you’re (ab)using short-circuit behavior of logical expressions, you might as well go all the way and (ab)use the fact that logical expressions have integer values (1 for true, 0 for false) and integers have logical values (false for 0, true for non-zero):

```
f=lambda x,y,z:(x>=y<=z)*(x*x+z*z)or(x>=z)*(x*x+y*y)or y*y+z*z

def test(reps=1000):
    from itertools import permutations
    from random import sample

    def ref(x,y,z):
        a,b = sorted((x,y,z))[1:]
        return a*a + b*b

    population = [x/100.0 for x in range(-1000,1001)]

    for _ in range(reps):
        for x,y,z in permutations(sample(population, 3)):
            if f(x,y,z) == ref(x,y,z):
                continue

            print(x,y,z, f(x,y,z), ref(x,y,z))
```

I believe the equivalent C code would be:

`f(x,y,z){return(x>=y&z>=y)*(x*x+z*z)||(x>=z)*(x*x+y*y)|| y*y+z*z;}`

18. Mait said

Here is a short Java method to achieve the same:

    public static int sumOfSquares(int a, int b, int c)
    {
        int min = (a<b)?((a<c)?a:(c<b)?c:c):((b<c)?b:c);
        //System.out.println("min : "+min);
        return (int) (Math.pow(a, 2)+Math.pow(b, 2)+Math.pow(c, 2)-Math.pow(min, 2));
    }

19. ardnew said

Wow I feel like a doofus, Mike. Your approach is so much more clean and simple!

By the way, I do use the integer value of C’s logical expressions in the arithmetic of the return statement, I just force the expression to always evaluate to true.

20. PHP CLI :

```
<? sort($argv,SORT_NUMERIC);echo pow($argv,2)+pow($argv,2); ?>
```

C:

```
#include<stdio.h>
#include<stdlib.h> /* atoi */
int main(int argc,char* argv[])
{
    int a=atoi(argv[1]),b=atoi(argv[2]),c=atoi(argv[3]);
    printf("%d",c<a&&c<b?a*a+b*b:a<c&&a<b?b*b+c*c:a*a+c*c);
}
```
21. jonathanjohansen said
`(defun max22 (a b c) (reduce #'+ (remove (min a b c) (list a b c)) :key (lambda (x) (* x x))))`
22. In Forth you can do this without variables.

: sqr ( u -- du ) DUP * ;
: sum-of-squares ( n1 n2 n3 -- ) 2DUP > IF SWAP ENDIF sqr -ROT MAX sqr + . ;

23. Simple implementation in java. – Amandeep Dhanjal

    public static int calculateSum(int a, int b, int c){
        int sum = 0;

        if(a < b && a < c){
            System.out.println("Nums: "+b+" : "+c);
            sum = b*b + c*c;
        }else if(b < c && b < a){
            System.out.println("Nums: "+a+" : "+c);
            sum = a*a + c*c;
        }else{
            System.out.println("Nums: "+a+" : "+b);
            sum = a*a + b*b;
        }

        return sum;
    }

24. Dinesh Damaraju said

In Java,

    public int squareOfLargestTwo(int i, int j, int k){
        int sum = 0;
        sum = sum + (i>j?i*i:j*j) + (j>k?j*j:k*k);
        return sum;
    }

25. ardnew said

Dinesh Damaraju, that solution will not work for all input i, j, and k.

When j is the largest value, your function will return j * j + j * j.

26. treeowl said

Short-circuit boolean computations *are* branching: (or (a) (b)) is compiled exactly the same as (let ((r (a))) (if r r (b))). If it’s possible to do away with the branching (which I suspect it is), you’d probably have to use a bunch of bitwise operations combined with shifts and multiplications: the result will probably have lots of operations, though, and of course to avoid branching in a loop, you’ll probably have to unroll one. Take with a grain of salt: I’m no assembly programmer.

27. Mike said

Some CPUs support conditional execution of instructions based on CPU status.

Here is an example branchless implementation in ARM assembly language.
(All values, whether input, intermediate, or output, must fit in a 32-bit register.)
The comments show an “equivalent” Python statement of the ARM instruction.

```
# On Entry: Registers R0, R1, R2 contain the numbers to process
# On Exit: R0 contains the sum of the squares of the two larger numbers

# The first 8 instructions work kind of like a bubble sort, to move the smallest
# number to R2 and the two larger numbers to R0 and R1
CMP     R0, R1          # LT = R0 < R1
MOVLT   R4, R0          # if LT: R4 = R0  \  these swap R0 and R1,
MOVLT   R0, R1          # if LT: R0 = R1   +   but only execute
MOVLT   R1, R4          # if LT: R1 = R4  /      if R0 < R1.

CMP     R1, R2          # LT = R1 < R2
MOVLT   R4, R1          # if LT: R4 = R1  \  these swap R1 and R2
MOVLT   R1, R2          # if LT: R1 = R2   +   but only execute
MOVLT   R2, R4          # if LT: R2 = R4  /      if R1 < R2

# these instructions square the two larger values and add them together
MUL     R0, R0, R0      # R0 = R0 * R0
MUL     R1, R1, R1      # R1 = R1 * R1
ADD     R0, R0, R1      # R0 = R0 + R1

```
28. Christian Siegert said

My solution written in Go lang: https://gist.github.com/2942682 (I took the approach used by rainer March 16, 2012 at 9:35 AM).

29. treeowl said

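```
typedef double num; /* or int, or whatever */

/* Assumption: likely() is the usual GCC/Clang branch-prediction hint. */
#define likely(x) __builtin_expect(!!(x), 1)

num ssq(num a, num b, num c) { return (a>=b ? a*a + (likely(c>=b)?c*c:b*b) : b*b + (likely(c>=a)?c*c:a*a)); }
```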

30. […] Another clever way to do this question is by using rotations, as specified in the blog ProgrammingPraxis. […]