Sum of squares of two largest of three values
March 16, 2012
Today’s exercise comes to us from the book Structure and Interpretation of Computer Programs by Abelson and Sussman (exercise 1.3):
Define a procedure that takes three numbers as arguments and returns the sum of the squares of the two larger numbers.
Your task is to write the indicated function. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.
REXX:
sum_of_squares: procedure
  parse arg a,b,c
  return (max(a,b)**2)+(max(min(a,b),c))**2
Ruby:
def two_largest_square_sum(a, b, c)
  [ a, b, c ].sort[1..2].map { |x| x * x }.inject(:+)
end
Knuth says the median of x, y, z is “probably the most important ternary operation in the entire universe” (TAoCP volume 4, fascicle 0), so I take this opportunity to give it some visibility. Spread the word.
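One common way to compute the median of three values is as the largest of the three pairwise minima. Since the two largest of three values are exactly the maximum and the median, that gives another solution. This is a minimal Python sketch; median3 and sum_sq_two_largest are hypothetical names chosen for illustration.

```python
def median3(x, y, z):
    # The median is the largest of the three pairwise minima.
    return max(min(x, y), min(x, z), min(y, z))


def sum_sq_two_largest(x, y, z):
    # The two largest values are the maximum and the median.
    m = max(x, y, z)
    d = median3(x, y, z)
    return m * m + d * d


print(sum_sq_two_largest(3, 4, 5))    # 41
print(sum_sq_two_largest(-2, 0, -3))  # 4
```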
Note that at the point where the exercise appears in SICP, the functions min and max haven’t been introduced.
In the context of SICP it must therefore be considered cheating to use min and max!
That said, the simplest expression to use is
x^2 + y^2 + z^2 - min(x,y,z)^2
(define (wev2 x y z)
  (+ (expt x 2) (expt y 2) (expt z 2)
     (- (expt (min x y z) 2))))
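The identity works because removing the square of the smallest value leaves exactly the squares of the other two, negatives and ties included. A quick way to gain confidence is an exhaustive comparison against a sort-based reference on a small range; this Python sketch does that (the function names are hypothetical).

```python
from itertools import product


def sum_sq_minus_min(x, y, z):
    # Square everything, then remove the square of the smallest value.
    return x * x + y * y + z * z - min(x, y, z) ** 2


def reference(x, y, z):
    # Straightforward check: sort and keep the two largest values.
    a, b = sorted((x, y, z))[1:]
    return a * a + b * b


# Exhaustive check over a small range that includes ties and negatives.
for x, y, z in product(range(-3, 4), repeat=3):
    assert sum_sq_minus_min(x, y, z) == reference(x, y, z)
print("identity holds on all", 7 ** 3, "triples")
```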
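def sum_square(x,y,z):
    return (x*x+y*y if x>=z<=y else z*z+x*x if z>=y<=x else y*y+z*z)
N.B. x>=z<=y is a Python chained comparison, equivalent to (x >= z and z <= y). Using >= and <= rather than > and < keeps ties such as (3, 1, 1) correct.
It doesn’t do any parameter checking, but it should always return the sum of the squares of the two largest values.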
(defun square (x) (* x x))
(defun med (x y z) (max (min x y) (min x z) (min y z)))
(defun sum-squares-of-max-2 (x y z)
  (apply #'+ (mapcar #'square (list (max x y z) (med x y z)))))
Yet another “sort first” approach:
#include <utility>
int lrgsq(int a, int b, int c) {
  if (a < b) std::swap(a, b);  // now a >= b
  if (b < c) std::swap(c, b);  // now b is the larger of the remaining two
  return a * a + b * b;
}
Is there a smart solution with no branching whatsoever? The one that subtracts the smallest square comes close, but min() uses branching internally.
def sumSquares(num1, num2, num3):
    numList = [num1, num2, num3]
    chosen1 = max(numList)
    numList.remove(chosen1)
    chosen2 = max(numList)
    return chosen1**2 + chosen2**2
Oh, I’ve just realized that you *can* write min() without branching! I hope I didn’t make any mistake here:
#include <stdlib.h>  /* abs */
int f(int a, int b, int c) {
  return a*a+b*b+c*c-(c+((((b+((a-b)-abs(a-b))/2))-c)-abs(((b+((a-b)-abs(a-b))/2))-c))/2)*(c+((((b+((a-b)-abs(a-b))/2))-c)-abs(((b+((a-b)-abs(a-b))/2))-c))/2);
}
[…] solution, posted in comments, to the following problem needs some explanation. First of all, the task is to take three numbers as arguments and return the […]
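The branchless min() expression above rests on the identity min(a, b) = (a + b - |a - b|) / 2: if a >= b the numerator is 2b, otherwise it is 2a, so for integers the division is exact. The nested C expression is the same identity rearranged as b + ((a - b) - |a - b|) / 2, applied twice. A small Python sketch of the idea, with hypothetical names min2_abs and sum_sq_abs (whether abs itself compiles to a branch depends on the compiler and target):

```python
def min2_abs(a, b):
    # a + b - |a - b| is 2*min(a, b), always even for integers,
    # so floor division by 2 is exact.
    return (a + b - abs(a - b)) // 2


def sum_sq_abs(a, b, c):
    m = min2_abs(min2_abs(a, b), c)  # minimum of three via two pairwise minima
    return a * a + b * b + c * c - m * m
```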
Tomasz, I highly doubt it’s possible to avoid branching altogether, especially if your input is real numbers (not just integers). However, I took your approach and considered an integer-only solution:
//
// function f(a, b, c):
// accepts three signed integers a, b, and c as parameters and returns
// the sum of the squares of the two largest of these integers
//
unsigned long long f(const signed int a, const signed int b, const signed int c)
{
    signed int m;
    unsigned long long j = b, // wider types to avoid casting and overflow
                       k = c,
                       t;     // holds the always-1 flag computed below
    //
    // branchless minimum of two integers:
    // min(x, y) = y ^ ((x ^ y) & -(x < y));
    //
    // performed twice to find minimum of three integers
    //
    m = b ^ ((a ^ b) & -(a < b));
    m = c ^ ((m ^ c) & -(m < c));
    //
    // select the two largest integers, j and k, from the three input integers
    // and return the sum of their squares. without branching!
    //
    // basically, we abuse the short-circuiting behavior of C's logical
    // operators to simulate a proper if-else control structure.
    // the comma operator finishes the flag (and its assignments to j and k)
    // before j and k are read; multiplying the flag by j directly would be
    // an unsequenced read and write of j, i.e. undefined behavior.
    //
    return t = (m == a || (j = a)|1 && m == b || (k = b)|1), t * j * j + k * k;
}
The return statement could probably use some explanation. To avoid computing the square of all terms and subtracting off the minimum squared, only the two greater values are squared and summed.
This should help protect against overflow, and it should slightly increase performance.
The logic is all tied into a boolean expression that always evaluates to 1. It’s important to note how j and k are initialized at the beginning of the function: to b and c, which is already the answer when a is the minimum.
The “(x = y)|1” expression is used to handle the case where y = 0. We still want the assignment to evaluate as true, so the bitwise “OR 1” is appended to the expression.
Using proper control structures, the return statement’s logic could be rewritten as follows:
j = b;
k = c;
if (m != a) {
    j = a;
    if (m != b) {
        k = b;
    }
}
return j * j + k * k;
I’m not entirely convinced the compiler won’t turn this mess back into branching code. Maybe someone more familiar with this sort of thing has some insight?
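The XOR trick for the minimum works because -(x < y) is either -1 (all bits set) or 0: with all bits set the mask keeps x ^ y and y ^ (x ^ y) collapses to x, and with 0 it leaves y unchanged. This Python sketch reproduces that, then uses the same selection as the if/else rewrite of the return statement. The names branchless_min and sum_sq_two_largest_xor are hypothetical, and Python integers behave like infinite two's complement under bitwise operators, so negatives work too.

```python
def branchless_min(x, y):
    # -(x < y) is -1 (all bits set) when x < y and 0 otherwise, so the mask
    # either keeps (x ^ y), giving y ^ x ^ y == x, or clears it, giving y.
    return y ^ ((x ^ y) & -(x < y))


def sum_sq_two_largest_xor(a, b, c):
    m = branchless_min(branchless_min(a, b), c)
    # Start with (b, c), the answer when a is the minimum, then overwrite
    # exactly as in the if/else version of the return-statement logic.
    j, k = b, c
    if m != a:
        j = a
        if m != b:
            k = b
    return j * j + k * k
```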
And just for fun, here’s a golf’d version of basically the same algorithm. The assignment ordering is a little different since we don’t have the wide variable placeholders j and k. It’s also not quite as safe since the types are all implicitly declared “int”, so you can easily overflow.
However, it compiles as C (built using Cygwin, gcc 4.3.4, whose default gnu89 mode still accepts implicit int), and it works correctly assuming the absolute value of each input is less than sqrt(2^31-1)/2.
Branch-free algorithm for computing the sum of the squares of the two largest of three values (with non-golf’d test code):
#include <stdio.h>
//
// lol wat
// (the flag is stored in m by its own statement, and (c=b,b=a) uses the
// comma operator, so no variable is read and written without a sequence point)
//
f(a,b,c){int m=b^((a^b)&-(a<b));m=c^((m^c)&-(m<c));m=m==a||m==b&&(b=a)|1||(c=b,b=a)|1;return m*b*b+c*c;}
//
// function test(a, b, c, r):
// verifies the calculation of f(a, b, c) equals the expected value r
//
void test(int a, int b, int c, int r)
{
    int s = f(a, b, c);
    printf("\nf(%d, %d, %d)\n%12s = %d\n%12s = %d%s", a, b, c,
           "expected", r, "returned", s, r == s ? "\n" : " <= BANANAS!\n");
}
//
// function main()
// driver program for testing various inputs with known solutions
//
int main()
{
    test(3, 4, 5, 41); test(3, 5, 4, 41); test(4, 3, 5, 41);
    test(4, 5, 3, 41); test(5, 3, 4, 41); test(5, 4, 3, 41);
    test(3, 3, 4, 25); test(3, 4, 3, 25); test(4, 3, 3, 25);
    test(3, 4, 4, 32); test(4, 3, 4, 32); test(4, 4, 3, 32);
    test(3, 3, 3, 18); test(-2, 0, -3, 4);
    return 0;
}
As long as you’re (ab)using the short-circuit behavior of logical expressions, you might as well go all the way and (ab)use the fact that logical expressions have integer values (1 for true, 0 for false) and integers have logical values (false for 0, true for non-zero):
# caveat: if the selected sum is 0 (e.g. f(0, -1, 0)), `or` falls through
# to the next branch, so the result is wrong in that corner case
f = lambda x,y,z: (x>=y<=z)*(x*x+z*z) or (x>=z)*(x*x+y*y) or y*y+z*z

def test(reps=1000):
    from itertools import permutations
    from random import sample
    def ref(x,y,z):
        a,b = sorted((x,y,z))[1:]
        return a*a + b*b
    population = [x/100.0 for x in range(-1000,1001)]
    for _ in range(reps):
        for x,y,z in permutations(sample(population, 3)):
            if f(x,y,z) == ref(x,y,z):
                continue
            print(x, y, z, f(x,y,z), ref(x,y,z))
I believe the equivalent C code would be:
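int f(int x, int y, int z)
{
    /* C's || yields 0 or 1 rather than an operand's value, so the Python
       or-chain is emulated by multiplying each candidate by a 0/1 flag. */
    int ymin = (x >= y) & (z >= y);
    return ymin * (x*x + z*z)
         + !ymin * ((x >= z) * (x*x + y*y) + (x < z) * (y*y + z*z));
}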
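The `or` chain in the Python one-liner only works while every selected sum is non-zero: when the correct answer is 0, `or` treats it as false and moves on. This Python sketch shows the corner case and a variant (hypothetical name f_masked) that selects with 0/1 multipliers instead, so a 0 result cannot fall through.

```python
# The or-chain one-liner from the comment above, unchanged.
f = lambda x, y, z: (x >= y <= z) * (x*x + z*z) or (x >= z) * (x*x + y*y) or y*y + z*z


def f_masked(x, y, z):
    # Select by multiplying with 0/1 flags instead of relying on `or`,
    # so a legitimate result of 0 cannot fall through to the next branch.
    y_min = int(x >= y <= z)
    return (y_min * (x*x + z*z)
            + (1 - y_min) * ((x >= z) * (x*x + y*y) + (x < z) * (y*y + z*z)))


print(f(0, -1, 0), f_masked(0, -1, 0))  # 1 0 (the correct answer is 0)
```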
Here is a short Java method that achieves the same:
public static int sumOfSquares(int a, int b, int c)
{
    // smallest of the three values
    int min = (a < b) ? ((a < c) ? a : c) : ((b < c) ? b : c);
    // integer arithmetic instead of Math.pow, which works on doubles
    return a * a + b * b + c * c - min * min;
}
Wow, I feel like a doofus, Mike. Your approach is so much cleaner and simpler!
By the way, I do use the integer value of C’s logical expressions in the arithmetic of the return statement; I just force the expression to always evaluate to true.
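PHP CLI:
C:
#include <stdio.h>
#include <stdlib.h>  /* atoi */
int main(int argc, char* argv[])
{
    if (argc < 4) return 1;  /* expects three integer arguments */
    int a = atoi(argv[1]), b = atoi(argv[2]), c = atoi(argv[3]);
    printf("%d\n", c<=a&&c<=b ? a*a+b*b : a<=b&&a<=c ? b*b+c*c : a*a+c*c);
    return 0;
}
In Forth you can do this without variables.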
: sqr ( n -- n*n ) DUP * ;
: sum-of-squares ( n1 n2 n3 -- ) 2DUP > IF SWAP THEN sqr -ROT MAX sqr + . ;
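To see why the Forth word needs no variables, it helps to follow the stack. This Python sketch emulates each word with a list whose end is the top of the stack (sum_of_squares_stack is a hypothetical name, and Python's True stands in for Forth's true flag). After the conditional SWAP, the larger of n2 and n3 is on top; after -ROT, the remaining candidate for second place is max(n1, smaller of n2 and n3).

```python
def sum_of_squares_stack(n1, n2, n3):
    # Each step mirrors one Forth word; the list end is the top of the stack.
    s = [n1, n2, n3]
    s += s[-2:]                       # 2DUP   n1 n2 n3 n2 n3
    b, a = s.pop(), s.pop()
    s.append(a > b)                   # >      n1 n2 n3 flag
    if s.pop():                       # IF SWAP THEN: larger of n2, n3 on top
        s[-1], s[-2] = s[-2], s[-1]
    s.append(s.pop() ** 2)            # sqr    n1 lo hi^2
    c, b, a = s.pop(), s.pop(), s.pop()
    s += [c, a, b]                    # -ROT   hi^2 n1 lo
    s.append(max(s.pop(), s.pop()))   # MAX    hi^2 max(n1, lo)
    s.append(s.pop() ** 2)            # sqr
    s.append(s.pop() + s.pop())       # +
    return s.pop()                    # .  (Forth prints the value instead)
```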
Simple implementation in Java (Amandeep Dhanjal):
public static int calculateSum(int a, int b, int c){
    int sum = 0;
    // <= rather than < so that a tie for the smallest value is still excluded
    if(a <= b && a <= c){
        System.out.println("Nums: "+b+" : "+c);
        sum = b*b + c*c;
    }else if(b <= a && b <= c){
        System.out.println("Nums: "+a+" : "+c);
        sum = a*a + c*c;
    }else{
        System.out.println("Nums: "+a+" : "+b);
        sum = a*a + b*b;
    }
    return sum;
}
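If-chains that pick the minimum with strict comparisons tend to break exactly when two inputs tie for the smallest value, e.g. (1, 1, 3). A small exhaustive harness over a range with repeated values catches this quickly. This Python sketch compares a strict-comparison chain and a <= chain against a sort-based reference; all names here are hypothetical.

```python
from itertools import product


def reference(a, b, c):
    # Sort and keep the two largest values.
    x, y = sorted((a, b, c))[1:]
    return x * x + y * y


def calculate_sum_strict(a, b, c):
    # If-chain written with strict < comparisons.
    if a < b and a < c:
        return b * b + c * c
    elif b < c and b < a:
        return a * a + c * c
    return a * a + b * b


def calculate_sum_ties(a, b, c):
    # Same chain with <=, so a tie for the minimum still selects it.
    if a <= b and a <= c:
        return b * b + c * c
    elif b <= a and b <= c:
        return a * a + c * c
    return a * a + b * b


def failures(fn, lo=-2, hi=3):
    # All triples with entries in range(lo, hi) where fn disagrees with the reference.
    return [t for t in product(range(lo, hi), repeat=3) if fn(*t) != reference(*t)]


print(len(failures(calculate_sum_strict)), len(failures(calculate_sum_ties)))
```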
In Java:
public int squareOfLargestTwo(int i, int j, int k){
int sum = 0;
sum = sum + (i>j?i*i:j*j) + (j>k?j*j:k*k);
return sum;
}
Dinesh Damaraju, that solution will not work for all inputs i, j, and k.
When j is the largest value, your function will return j * j + j * j.
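The failure is easy to reproduce: with (1, 3, 2) both conditionals pick j. This Python sketch translates the method directly and adds one possible fix of the same shape as the REXX solution near the top of the thread, max(i, j)^2 + max(min(i, j), k)^2. Both function names are hypothetical.

```python
def square_of_largest_two(i, j, k):
    # Direct translation of the Java method under discussion.
    return (i * i if i > j else j * j) + (j * j if j > k else k * k)


def square_of_largest_two_fixed(i, j, k):
    # Larger of i, j, plus the larger of (smaller of i, j) and k.
    return max(i, j) ** 2 + max(min(i, j), k) ** 2


print(square_of_largest_two(1, 3, 2))        # 18, i.e. 9 + 9
print(square_of_largest_two_fixed(1, 3, 2))  # 13, i.e. 9 + 4
```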
Short-circuit boolean computations *are* branching: (or (a) (b)) is compiled exactly the same as (let ((r (a))) (if r r (b))). If it’s possible to do away with the branching (which I suspect it is), you’d probably have to use a bunch of bitwise operations combined with shifts and multiplications: the result will probably have lots of operations, though, and of course to avoid branching in a loop, you’ll probably have to unroll it. Take with a grain of salt: I’m no assembly programmer.
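One shift-based form of that idea: the sign of a - b, smeared across the word by an arithmetic right shift, gives a mask that is all ones when a < b and zero otherwise, so min(a, b) = b + ((a - b) & mask). This is a minimal Python sketch assuming 32-bit signed inputs (so a - b fits in 64 bits); mask_min and sum_sq_mask are hypothetical names. In C, right-shifting a negative signed value is implementation-defined, although common compilers shift arithmetically.

```python
def mask_min(a, b):
    # Assumes a and b are 32-bit signed ints, so a - b fits in 64 bits.
    # Python's >> is an arithmetic shift: the result is -1 (all ones)
    # when a - b is negative and 0 otherwise.
    diff = a - b
    mask = diff >> 63
    return b + (diff & mask)


def sum_sq_mask(a, b, c):
    m = mask_min(mask_min(a, b), c)
    return a * a + b * b + c * c - m * m
```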
Some CPUs support conditional execution of instructions based on CPU status flags. See http://everything2.com/title/conditional+execution for some more info.
Here is an example branchless implementation in ARM assembly language.
(All values, whether input, intermediate, or output, must fit in a 32-bit register.)
The comments show an “equivalent” Python statement for each ARM instruction.
@ On Entry: Registers R0, R1, R2 contain the numbers to process
@ On Exit:  R0 contains the sum of the squares of the two larger numbers
@ The first 8 instructions work kind of like a bubble sort, to move the smallest
@ number to R2 and the two larger numbers to R0 and R1
        CMP   R0, R1        @ LT = R0 < R1
        MOVLT R4, R0        @ if LT: R4 = R0  \ these swap R0 and R1,
        MOVLT R0, R1        @ if LT: R0 = R1  + but only execute
        MOVLT R1, R4        @ if LT: R1 = R4  / if R0 < R1.
        CMP   R1, R2        @ LT = R1 < R2
        MOVLT R4, R1        @ if LT: R4 = R1  \ these swap R1 and R2,
        MOVLT R1, R2        @ if LT: R1 = R2  + but only execute
        MOVLT R2, R4        @ if LT: R2 = R4  / if R1 < R2.
@ these instructions square the two larger values and add them together
        MUL   R3, R0, R0    @ R3 = R0 * R0
        MUL   R4, R1, R1    @ R4 = R1 * R1
        ADD   R0, R3, R4    @ R0 = R3 + R4
My solution written in Go: https://gist.github.com/2942682 (I took the approach used by rainer March 16, 2012 at 9:35 AM).
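The conditional moves in the ARM listing act as a compare-and-swap that never branches. Without conditional execution, the same effect can be had with a mask: t = (x ^ y) & -(x < y) is either x ^ y or 0, and XOR-ing t into both values swaps or keeps them. This Python sketch mirrors the two compare/swap stages of the listing; cond_swap and sum_sq_ssort are hypothetical names.

```python
def cond_swap(x, y):
    # Swap x and y when x < y, without an if: -(x < y) is -1 or 0, so t is
    # either x ^ y or 0, and XOR-ing t into both values swaps or keeps them.
    t = (x ^ y) & -(x < y)
    return x ^ t, y ^ t


def sum_sq_ssort(r0, r1, r2):
    r0, r1 = cond_swap(r0, r1)  # like CMP R0, R1 plus the three MOVLTs
    r1, r2 = cond_swap(r1, r2)  # like CMP R1, R2 plus the three MOVLTs
    return r0 * r0 + r1 * r1    # r2 now holds the smallest value
```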
Ruby solution (http://codepad.org/tLtSsKBQ)
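#define likely(x) __builtin_expect(!!(x), 1) /* GCC/Clang builtin; define as (x) elsewhere */
typedef double num; /* or int, or whatever */
num ssq(num a, num b, num c) {
  return (a >= b ? a*a + (likely(c >= b) ? c*c : b*b)
                 : b*b + (likely(c >= a) ? c*c : a*a));
}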
[…] Another clever way to do this question is by using rotations, as specified in the blog ProgrammingPraxis. […]