Square-Sum Puzzle
January 16, 2018
I don’t watch a lot of television, but the YouTube channel Numberphile is one of the places I am careful not to miss. Numberphile recently had an episode called “The Square-Sum Puzzle” that makes a good exercise:
Rearrange the numbers from 1 to 15 so that every pair of adjacent numbers sums to a square number.
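Put another way, a valid answer is a permutation of 1..n in which each pair of neighbours sums to a perfect square. A small checker makes that precise. This is a sketch, and the function name `is_square_sum` is hypothetical.

```python
from math import isqrt


def is_square_sum(seq, n):
    """Return True if seq is a permutation of 1..n whose adjacent pairs all sum to a square."""
    if sorted(seq) != list(range(1, n + 1)):
        return False  # not a rearrangement of 1..n
    for a, b in zip(seq, seq[1:]):
        root = isqrt(a + b)
        if root * root != a + b:
            return False  # this pair does not sum to a perfect square
    return True


print(is_square_sum([8, 1, 15, 10, 6, 3, 13, 12, 4, 5, 11, 14, 2, 7, 9], 15))  # → True
```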
Your task is to write a program to solve the Numberphile square-sum puzzle. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.
In Python. This is easy to solve by hand, as there are not many possibilities, so a brute-force solution should do the trick. There appears to be only one solution.
```python
numbers = list(range(1, 16))
squares = set(i ** 2 for i in range(6))
# successors[i]: the numbers that may sit next to i (their sum with i is a square)
successors = {i: set(j for j in numbers if i != j and i + j in squares) for i in numbers}

# apparently 8 and 9 should be at beginning and end (only 1 possibility)
def solve():
    Q = [([8], set(numbers) - {8})]  # stack of (partial row, numbers still unused)
    while Q:
        part, remain = Q.pop()
        if not remain:
            print(part)
            continue
        for i in successors[part[-1]] & remain:
            Q.append((part + [i], remain - {i}))

solve()  # -> [8, 1, 15, 10, 6, 3, 13, 12, 4, 5, 11, 14, 2, 7, 9]
```

@Paul, there are two solutions: the one you listed and its reverse.
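The comment in that code, that 8 and 9 belong at the ends, follows from counting neighbours: treat each number as a vertex and join two numbers when they sum to a square. A number with only one neighbour can only appear at an end of the row. A minimal sketch; the helper name `square_graph` is hypothetical.

```python
from math import isqrt


def square_graph(n):
    """Map each of 1..n to the other numbers it can sit next to (pair sum is a perfect square)."""
    def is_square(k):
        return isqrt(k) ** 2 == k

    return {i: [j for j in range(1, n + 1) if j != i and is_square(i + j)]
            for i in range(1, n + 1)}


graph = square_graph(15)
ends = [i for i, nbrs in graph.items() if len(nbrs) == 1]
print(ends)                 # → [8, 9]
print(graph[8], graph[9])   # → [1] [7]
```

A row has exactly two ends, so 8 and 9 must be them. Every solution therefore starts with 8 or with 9, and reversing a row that starts with 9 gives one that starts with 8, which is why searching from 8 alone is enough.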
Here’s mine.
```python
#!/usr/bin/env python3
# https://programmingpraxis.com/2018/01/16/square-sum-puzzle/

N = 15
A = [None] * N
squares = {i * i for i in range(N) if i * i < N + N}

# unused is a bitmask: bit k is set while number k has not been placed yet.
def place(pos, prev, unused):
    if unused == 0:
        # All numbers used, solution found
        print(A)
        assert all(a + aa in squares for (a, aa) in zip(A[:-1], A[1:]))
    elif pos == 0:
        # First number: any of 1..N (the range end is exclusive, hence N + 1)
        for i in range(1, N + 1):
            bit = 1 << i
            if unused & bit:
                A[pos] = i
                place(pos + 1, i, unused & ~bit)
    else:
        # General case: the next number is some square minus the previous number
        for sq in squares:
            dif = sq - prev
            if dif < 0 or dif > N:
                continue
            bit = 1 << dif
            if unused & bit:
                A[pos] = dif
                place(pos + 1, dif, unused & ~bit)

place(0, 0, (1 << N + 1) - 2)  # bits 1..N set: every number still unused
```

@kernelbob: it is indeed more correct to say that there are two solutions, but they are essentially the same.
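To confirm the count, one can enumerate every arrangement with a plain backtracking search and then identify each row with its reverse. A minimal sketch using only the standard library; the function name `all_arrangements` is hypothetical.

```python
from math import isqrt


def all_arrangements(n):
    """Return every ordering of 1..n in which adjacent numbers sum to a perfect square."""
    nbrs = {i: [j for j in range(1, n + 1) if j != i and isqrt(i + j) ** 2 == i + j]
            for i in range(1, n + 1)}
    found = []

    def extend(path, unused):
        if not unused:
            found.append(list(path))  # every number placed: record a copy
            return
        for j in nbrs[path[-1]]:
            if j in unused:
                unused.remove(j)
                path.append(j)
                extend(path, unused)
                path.pop()  # backtrack
                unused.add(j)

    for start in range(1, n + 1):
        extend([start], set(range(1, n + 1)) - {start})
    return found


rows = all_arrangements(15)
# Identify each row with its reverse by keeping the smaller of the two tuples.
distinct = {min(tuple(r), tuple(reversed(r))) for r in rows}
print(len(rows), len(distinct))  # → 2 1
```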
```python
MAX = 15
numCalls = 0

class Node:
    def __init__(self, value):
        self.value = value
        self.neighbors = []

    def addNeighbor(self, neighbor):
        self.neighbors.append(neighbor)

    def __add__(self, other):
        return self.value + other.value

    def __repr__(self):
        return 'Node(' + str(self.value) + ')'

    def findPath(self, used):
        # Depth-first search; returns the path (in reverse order) once MAX nodes are used.
        global numCalls
        numCalls += 1
        for neighbor in self.neighbors:
            if neighbor not in used:
                used.add(neighbor)
                if len(used) == MAX:
                    return [neighbor]
                else:
                    result = neighbor.findPath(used)
                    if result:
                        result.append(neighbor)
                        return result
                    else:
                        used.remove(neighbor)
        return False

values = tuple(range(1, MAX + 1))
nodes = [Node(i) for i in values]
squares = tuple(i * i for i in values)

# Connect every pair of numbers whose sum is a square
for i in values:
    for j in values[i:]:
        value = i + j
        if value > squares[-1]:
            break
        if value in squares:
            nodes[i - 1].addNeighbor(nodes[j - 1])
            nodes[j - 1].addNeighbor(nodes[i - 1])

for node in nodes:
    path = node.findPath(set())
    if path:
        assert all(path[i] + path[i + 1] in squares for i in range(MAX - 1))
        print([p.value for p in path])
        print()
        print('MAX:', MAX, 'numCalls:', numCalls)
        break
```

Finding a path for 15 is pretty easy: my depth-first search needs only 48 tries. However, the Numberphile2 channel says they searched up to 299. What sort of pathfinding algorithm did they use? My laptop handles 42 in about a second with 945,000 calls, but even the short jump to 45 takes a long time, with over 43,000,000 calls. There must be a smarter way to start winding through the paths. Any ideas?
A little more digging shows that Charlie used Sage and its hamiltonian_path() function, which found a path through 150 nodes in 90 ms. (Bang fist on table) There must be a better way.
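One way to see where the calls go, and how much candidate order matters, is to instrument a depth-first search and run it with two orderings: plain numeric order, and fewest-unused-neighbours first (similar in spirit to Warnsdorff's rule for knight's tours). This is a sketch; the function name `solve_counting` and its `fewest_first` parameter are hypothetical, and no particular call counts or timings are claimed.

```python
from math import isqrt


def solve_counting(n, fewest_first):
    """Depth-first search for one square-sum row of 1..n.

    Returns (path or None, number of calls to extend). With fewest_first=True,
    candidates are tried in order of how many unused neighbours they have left.
    """
    nbrs = {i: [j for j in range(1, n + 1) if j != i and isqrt(i + j) ** 2 == i + j]
            for i in range(1, n + 1)}
    unused = set(range(1, n + 1))
    path = []
    calls = 0

    def onward(v):
        return sum(1 for w in nbrs[v] if w in unused)

    def extend(candidates):
        nonlocal calls
        calls += 1
        if not unused:
            return True
        options = [c for c in candidates if c in unused]
        if fewest_first:
            options.sort(key=onward)  # most constrained candidate first
        for v in options:
            unused.remove(v)
            path.append(v)
            if extend(nbrs[v]):
                return True
            path.pop()  # backtrack
            unused.add(v)
        return False

    found = extend(range(1, n + 1))
    return (path if found else None), calls


for n in (15, 23):
    for order in (False, True):
        path, calls = solve_counting(n, order)
        print(n, "fewest-first" if order else "numeric", calls, path is not None)
```

Both orderings explore the same tree and so find a row whenever one exists; the ordering only changes how soon a dead end is hit, which is what drives the call count.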
In Python, using DFS. The functions count and takewhile come from the itertools module. This code finds the first solution for 299 in 69 ms. I got an enormous speed-up by first searching the nodes that have the fewest successors left (the last line of the code).
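```python
from itertools import count, takewhile

def solve(max):
    numbers = set(range(1, max + 1))
    squares = set(takewhile(lambda i: i < 2 * max, (i ** 2 for i in count(1))))
    successors = {i: set(j for j in numbers if i != j and i + j in squares) for i in numbers}
    Q = [([], None, set(numbers))]  # (partial row, last number placed, numbers still unused)
    while Q:
        solution, last, remain = Q.pop()
        if not remain:
            print(solution)
            break
        else:
            # With last=None, .get() falls back to all numbers as possible starts.
            cand = [(solution + [i], i, remain - set([i])) for i in successors.get(last, numbers) & remain]
            # Fewest remaining successors are pushed last, so they are popped (tried) first.
            Q += sorted(cand, key=lambda x: len(successors[x[1]] & remain), reverse=True)
```

Hi Luke,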
There are some options for pruning in your algorithm. In the final assert you check a path that has already been fully constructed, but you can check along the way.

For example, add a check in the neighbour loop of findPath that does something like: once an even number of nodes has been used, check whether the last two numbers add up to a square.
```python
# inside the neighbour loop of findPath, before recursing
if len(used) % 2 == 0 and self.value + neighbor.value not in squares:
    continue  # stop here, no need to construct the rest of the path
```

Here's a solution in C++ that generates all possible paths, using Algorithm X from TAOCP 7.2.1.2 to iterate over permutations (constructing only permutations with valid prefixes).
I started writing it in C, then switched to C++ for the STL.
```cpp
/* square_sum_puzzle.cpp */
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <unordered_set>
#include <vector>

using std::cerr;
using std::cout;
using std::endl;
using std::uint32_t;
using std::unordered_set;
using std::vector;

// Collects every arrangement of 1..n whose adjacent pairs sum to a square.
// Algorithm X: l[] links the unused values, u[k] saves the list position at
// level k, and a[1..k] is the current prefix.
bool arrange(const uint32_t n, vector<vector<uint32_t>>* result) {
  if (n == 0) return false;

  // Sums of two distinct numbers in 1..n lie in [3, 2n - 1].
  unordered_set<uint32_t> squares;
  for (uint32_t i = 2;; ++i) {
    uint32_t square = i * i;
    if (square >= 2 * n) break;
    squares.insert(square);
  }

  // std::vector replaces variable-length arrays, which are not standard C++.
  vector<uint32_t> a(n + 1), l(n + 1), u(n + 1);
  uint32_t k, p, q;

x1:
  for (k = 0; k < n; ++k) {
    l[k] = k + 1;
  }
  l[n] = 0;
  k = 1;
x2:
  p = 0;
  q = l[0];
x3:
  a[k] = q;
  // Test a[1], ..., a[k]: only the newest pair a[k-1] + a[k] needs checking.
  if (k > 1) {
    uint32_t sum = a[k] + a[k-1];
    if (squares.find(sum) == squares.end()) goto x5;
  }
  if (k == n) {
    result->push_back(vector<uint32_t>(a.begin() + 1, a.end()));
    goto x6;
  }
x4:
  u[k] = p;
  l[p] = l[q];
  ++k;
  goto x2;
x5:
  p = q;
  q = l[p];
  if (q != 0) goto x3;
x6:
  --k;
  if (k == 0) goto done;
  p = u[k];
  q = a[k];
  l[p] = q;
  goto x5;
done:
  return !result->empty();
}

int main(int argc, char* argv[]) {
  if (argc != 2) {
    cerr << "Usage: " << argv[0] << " <N>" << endl;
    return EXIT_FAILURE;
  }
  const uint32_t n = static_cast<uint32_t>(std::strtoul(argv[1], nullptr, 10));
  vector<vector<uint32_t>> result;
  arrange(n, &result);
  for (const vector<uint32_t>& v : result) {
    for (std::size_t i = 0; i < v.size(); ++i) {
      if (i > 0) cout << ",";
      cout << v[i];
    }
    cout << endl;
  }
  return EXIT_SUCCESS;
}
```

Example: