## Sum of squares of two largest of three values

### March 16, 2012

Before you read further, check your function to see if it returns the proper response of 25 for the inputs 3, 3, and 4.

This problem recently came through on Stack Overflow, and I was reminded of a great discussion of the problem on the *Usenet* group `comp.lang.scheme` several years ago. The solution that the authors of SICP wanted is probably something like this (we call our function `f` because `sum-of-squares-of-max-2-of-3` is a little bit unwieldy):

```scheme
(define (f x y z)
  (if (< x y)
      (if (< x z)
          (+ (* y y) (* z z))
          (+ (* x x) (* y y)))
      (if (< y z)
          (+ (* x x) (* z z))
          (+ (* x x) (* y y)))))
```

An alternate solution uses features of Scheme that hadn’t been introduced at that point in SICP; it calls itself recursively, rotating the arguments until the first argument is the minimum, when it takes the sum of the squares of the other two:

```scheme
(define (f x y z)
  (if (= x (min x y z))
      (+ (* y y) (* z z))
      (f y z x)))
```

The consensus answer in the thread on `comp.lang.scheme` was this solution:

```scheme
(define (f x y z)
  (cond ((and (< x y) (< x z)) (+ (* y y) (* z z)))
        ((and (< y x) (< y z)) (+ (* x x) (* z z)))
        (else (+ (* x x) (* y y)))))
```

The idea was to make a three-way choice to determine which element is the minimum, then take the sum of the squares of the other two. But that fails on `(f 3 3 4)`. Do you see why?

Since this problem is hard, it makes sense to have a proper test suite. There are six permutations of three different values, three permutations where the minimum value is duplicated, three permutations where the maximum value is duplicated, and one permutation where all three values are the same, a total of thirteen permutations. It’s not hard to test all thirteen:

```scheme
(define (test f)
  (assert (f 3 4 5) 41)
  (assert (f 3 5 4) 41)
  (assert (f 4 3 5) 41)
  (assert (f 4 5 3) 41)
  (assert (f 5 3 4) 41)
  (assert (f 5 4 3) 41)
  (assert (f 3 3 4) 25)
  (assert (f 3 4 3) 25)
  (assert (f 4 3 3) 25)
  (assert (f 3 4 4) 32)
  (assert (f 4 3 4) 32)
  (assert (f 4 4 3) 32)
  (assert (f 3 3 3) 18))
```

When run with the third version of `f`, `test` shows the error:

```
> (test f)
failed assertion:
(f 3 3 4)
expected: 25
returned: 18
```

You can run the program at http://programmingpraxis.codepad.org/l2YWpGaf.
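The tie problem is easy to reproduce outside Scheme. The Python sketch below has function names invented for illustration. It builds the thirteen cases counted above. It then runs a transcription of the consensus `cond` with strict `<`. Finally it runs the same three-way choice with `<=`, which lets a tied minimum pass its test:

```python
from itertools import permutations

def expected(x, y, z):
    # Reference answer: sort, then square the two largest values.
    a, b = sorted((x, y, z))[1:]
    return a * a + b * b

def consensus_strict(x, y, z):
    # Transcription of the consensus cond: strict < in every test.
    if x < y and x < z:
        return y * y + z * z
    if y < x and y < z:
        return x * x + z * z
    return x * x + y * y

def consensus_le(x, y, z):
    # Same three-way choice, but <= lets a tied minimum win its test.
    if x <= y and x <= z:
        return y * y + z * z
    if y <= x and y <= z:
        return x * x + z * z
    return x * x + y * y

# Six permutations of distinct values, three each with a duplicated
# minimum or maximum, and one with all three values equal.
cases = sorted(set(permutations((3, 4, 5))) | set(permutations((3, 3, 4)))
               | set(permutations((3, 4, 4))) | {(3, 3, 3)})

failures = [c for c in cases if consensus_strict(*c) != expected(*c)]
print(len(cases), failures)
print(all(consensus_le(*c) == expected(*c) for c in cases))
```

Only `(3, 3, 4)` fails under strict `<`. When `x` and `y` tie for the minimum, neither of the first two tests holds, so the `else` branch wrongly drops `z`.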

Knuth says the median of x, y, z is “probably the most important ternary operation in the entire universe” (TAoCP volume 4, fascicle 0), so I take this opportunity to give it some visibility. Spread the word.

Note that at the point where the exercise appears in SICP, the functions `min` and `max` haven’t been introduced.

In the context of SICP it must therefore be considered cheating to use `min` and `max`!

That said, the simplest expression to use is

$$x^2 + y^2 + z^2 - \min(x, y, z)^2$$

```scheme
(define (wev2 x y z)
  (+ (expt x 2) (expt y 2) (expt z 2)
     (- (expt (min x y z) 2))))
```

N.B. (x > z z and z < y)
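This formula gives the squares of the two largest *values*, which is what the exercise asks for. It does not give the two largest *squares*, and the two differ once negative inputs appear. Here is a quick Python sketch: `wev2` mirrors the Scheme above, and `top2_squares` is a made-up function for contrast.

```python
def wev2(x, y, z):
    # x^2 + y^2 + z^2 - min(x, y, z)^2
    return x * x + y * y + z * z - min(x, y, z) ** 2

def top2_squares(x, y, z):
    # For contrast: the two largest squares, ignoring sign.
    a, b = sorted((x * x, y * y, z * z))[1:]
    return a + b

print(wev2(3, 3, 4))           # 25
print(wev2(-5, 1, 2))          # 5: squares of 1 and 2
print(top2_squares(-5, 1, 2))  # 29: squares of -5 and 2
```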

It doesn’t do any parameter checking, but it should always return the sum of the squares of the two largest values.

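```lisp
;; SQUARE is not a standard Common Lisp function, so define it here.
(defun square (x) (* x x))

;; The median is the largest of the three pairwise minimums.
(defun med (x y z) (max (min x y) (min x z) (min y z)))

(defun sum-squares-of-max-2 (x y z)
  (apply #'+ (mapcar #'square (list (max x y z) (med x y z)))))
```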

Yet another “sort first” approach.

Is there a smart solution with no branching whatsoever? The one with subtracting the smallest square comes close, but min() uses branching internally.
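For integers, one well-known way to write `min` without an explicit branch is to build a mask from the comparison. In two's complement, `-(x < y)` is all one bits when `x < y` and zero otherwise, so `y ^ ((x ^ y) & mask)` yields `x` or `y`. The Python sketch below plugs that trick into the subtract-the-smallest-square formula. The helper names are hypothetical, and the trick only works for integers. It removes the `if` from the source but guarantees nothing about the machine code an interpreter or compiler actually emits.

```python
def branchless_min(x, y):
    # (x < y) is True or False; negating gives -1 (all one bits) or 0.
    mask = -(x < y)
    # mask == -1: y ^ (x ^ y) == x;  mask == 0: y ^ 0 == y.
    return y ^ ((x ^ y) & mask)

def sum_sq_top2(x, y, z):
    # x^2 + y^2 + z^2 minus the square of the smallest value.
    m = branchless_min(branchless_min(x, y), z)
    return x * x + y * y + z * z - m * m

print(sum_sq_top2(3, 3, 4))  # 25
print(sum_sq_top2(5, 4, 3))  # 41
```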

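```python
def sumSquares(num1, num2, num3):
    numList = [num1, num2, num3]
    # Take the largest value, remove one copy of it,
    # then take the largest of the remaining two.
    chosen1 = max(numList)
    numList.remove(chosen1)
    chosen2 = max(numList)
    return chosen1**2 + chosen2**2
```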

Oh, I’ve just realized that you *can* write min() without branching! I hope I didn’t make any mistake here:


Tomasz, I highly doubt it’s possible to avoid branching altogether, especially if your input is real numbers (not just integers). However, I took your approach and considered an integer-only solution:

The return statement could probably use some explanation. To avoid computing the square of all terms and subtracting off the minimum squared, only the two greater values are squared and summed.

This should help protect against overflow, and it should slightly increase performance.

The logic is all tied into a boolean expression that always evaluates to 1. It’s important to note how the values of j and k are initialized at the beginning of the function.

The “(x = y)|1” expression is used to handle the case where y = 0. We still want the assignment to evaluate as true, so the bitwise “OR 1” is appended to the expression.

Using proper control structures, the return statement’s logic could be rewritten as follows:

I’m not entirely convinced the compiler will not break up this mess and generate branching code. Maybe someone more familiar with this sort of thing has some insight?

And just for fun, here’s a golf’d version of basically the same algorithm. The assignment ordering is a little different, since we don’t have the wide variable placeholders j and k. It’s also not quite as safe, since the types are all implicitly declared `int`, so you can easily overflow.

However, it is valid C code (built using cygwin, gcc 4.3.4), and it works correctly assuming the absolute value of each input is less than sqrt(2^31-1)/2.

Branch-free algorithm for computing sum of squares of two largest of three values (with non-golf’d test code):

As long as you’re (ab)using the short-circuit behavior of logical expressions, you might as well go all the way and (ab)use the fact that logical expressions have integer values (1 for true, 0 for false) and integers have logical values (false for 0, true for nonzero):

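I believe the equivalent C code would be as follows. In C, `||` yields 0 or 1 rather than the value of an operand, so the cases are instead summed under mutually exclusive 0/1 flags:

```c
int f(int x,int y,int z){int a=x>=y&z>=y,b=!a&x>=z;
return a*(x*x+z*z)+b*(x*x+y*y)+(!a&!b)*(y*y+z*z);}
```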
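The value-returning `and`/`or` idiom described above has its own trap in languages such as Python, where `0` is false. If the selected sum is zero, evaluation falls through to the next alternative. Here is a hypothetical Python rendering of the expression (the name `f_andor` is made up):

```python
def f_andor(x, y, z):
    # "and" binds tighter than "or"; each "cond and value" yields value
    # when cond holds, but a value of 0 is falsy, so "or" moves on anyway.
    return ((x >= y and z >= y) and x * x + z * z
            or x >= z and x * x + y * y
            or y * y + z * z)

print(f_andor(3, 3, 4))   # 25
print(f_andor(0, -1, 0))  # 1, although the correct answer is 0
```

Scheme does not have this trap, since only `#f` is false there.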


Here is a short Java method to achieve the same:

```java
public static int sumOfSquares(int a, int b, int c)
{
    // Smallest of the three; its square is subtracted from the total.
    int min = (a < b) ? ((a < c) ? a : c) : ((b < c) ? b : c);
    // Plain integer arithmetic instead of a round trip through Math.pow.
    return a*a + b*b + c*c - min*min;
}
```

Wow, I feel like a doofus, Mike. Your approach is so much cleaner and simpler!

By the way, I do use the integer value of C’s logical expressions in the arithmetic of the return statement, I just force the expression to always evaluate to true.

PHP CLI:

C:

In Forth you can do this without variables.

```forth
: sqr ( n -- n*n ) DUP * ;
: sum-of-squares ( n1 n2 n3 -- ) 2DUP > IF SWAP THEN sqr -ROT MAX sqr + . ;
```
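After the conditional `SWAP`, the Forth word above has max(n2, n3) on top. `-ROT MAX` then pairs n1 with the smaller of n2 and n3. So the two values it squares are max(n2, n3) and max(n1, min(n2, n3)). Here is a Python sketch of that same selection, checked against sorting. The names `top2_forth_style` and `by_sorting` are made up.

```python
from itertools import permutations

def top2_forth_style(n1, n2, n3):
    # After the conditional SWAP, the larger of n2 and n3 is on top.
    hi = max(n2, n3)
    # -ROT MAX pairs n1 with the smaller of n2 and n3.
    second = max(n1, min(n2, n3))
    # Only two values are ever squared, as in the Forth version.
    return hi * hi + second * second

def by_sorting(x, y, z):
    a, b = sorted((x, y, z))[1:]
    return a * a + b * b

for triple in [(3, 4, 5), (3, 3, 4), (3, 4, 4), (3, 3, 3), (-5, 1, 2)]:
    for p in set(permutations(triple)):
        assert top2_forth_style(*p) == by_sorting(*p)

print(top2_forth_style(3, 3, 4))  # 25
```

Because only the two selected values are squared, no intermediate result exceeds the final answer. That is the same overflow argument made earlier in the thread for fixed-width C integers.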

Simple implementation in Java (Amandeep Dhanjal):

```java
public static int calculateSum(int a, int b, int c) {
    // Use <= so that a tied minimum is still recognized (e.g. 3, 3, 4).
    if (a <= b && a <= c) {
        return b*b + c*c;   // a is a minimum
    } else if (b <= a && b <= c) {
        return a*a + c*c;   // b is a minimum
    } else {
        return a*a + b*b;   // c is the minimum
    }
}
```

In Java:

```java
public int squareOfLargestTwo(int i, int j, int k){
    int sum = 0;
    sum = sum + (i>j?i*i:j*j) + (j>k?j*j:k*k);
    return sum;
}
```

Dinesh Damaraju, that solution will not work for all inputs i, j, and k.

When j is the largest value, your function will return j * j + j * j.

Short-circuit boolean computations *are* branching: `(or (a) (b))` is compiled exactly the same as `(let ((r (a))) (if r r (b)))`. If it’s possible to do away with the branching (which I suspect it is), you’d probably have to use a bunch of bitwise operations combined with shifts and multiplications. The result will probably have lots of operations, though, and of course to avoid branching in a loop, you’ll probably have to unroll it. Take this with a grain of salt: I’m no assembly programmer.

Some CPUs support conditional execution of instructions based on CPU status flags. See http://everything2.com/title/conditional+execution for some more info.

Here is an example branchless implementation in ARM assembly language.

(All values, whether input, intermediate, or output, must fit in a 32-bit register.)

The comments show an “equivalent” Python statement for each ARM instruction.

My solution written in Go: https://gist.github.com/2942682 (I took the approach used by rainer, March 16, 2012 at 9:35 AM).

Ruby solution: http://codepad.org/tLtSsKBQ

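```c
/* likely() is not standard C; this is the usual GCC/Clang definition. */
#define likely(x) __builtin_expect(!!(x), 1)

typedef double num; /* or int, or whatever */

num ssq(num a, num b, num c) {
    return a >= b ? a*a + (likely(c >= b) ? c*c : b*b)
                  : b*b + (likely(c >= a) ? c*c : a*a);
}
```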
