Folds
February 2, 2018
Folds, in the parlance of functional programming, are a way to reduce a list to a single value, possibly of some other type; a fold applies a binary function to each element of the list in turn, together with an accumulator, then returns the accumulator when the list is exhausted. The fundamental fold is foldr:
foldr f a [x1, x2, ..., xn] = f x1 (f x2 (... (f xn a)...))
Here, f is a function with two arguments, a is an initial value, and [x1, x2, ..., xn] is the input list. The name foldr stands for “fold right”, because the parentheses stack on the right side of the expansion, the items in the list are processed right-to-left, and the accumulator is on the right side of the binary function. Foldl is similar:
foldl f a [x1, x2, ..., xn] = f (... (f (f a x1) x2) ...) xn
The arguments have the same meaning, with “fold left” referring to the fact that the parentheses stack on the left, the items in the list are processed left-to-right, and the accumulator is on the left side of the binary function. Note that foldl and foldr have different types, because the binary function takes its arguments in opposite order. In some cases, that makes a difference; for instance, when f is cons, you must use foldr. But when the function is associative and the initial value is its identity element, such as + with 0, you can use either foldl or foldr. Here are some examples:
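foldr + 0 [1,2,3,4] → 10
foldl + 0 [1,2,3,4] → 10
foldr cons [] [1,2,3,4] → [1,2,3,4]
foldl cons [] [1,2,3,4] → [[[[[],1],2],3],4]
foldr plusone 0 [1,2,3,4] → 4
foldl snoc [] [1,2,3,4] → [4,3,2,1]
(Here plusone x a = a + 1 ignores the list item and just counts it, and snoc a x = cons x a is cons with its arguments swapped.)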
Sometimes there is no obvious starting value. For instance, if you want to find the maximum item in a list, there is no “guaranteed to be less than anything else” value to use for a. In that case you can use the foldl1 and foldr1 variants. These use the first item of the list (foldl1) or the last item (foldr1) as the initial value and fold over the remaining items. Here, max is a binary function that takes two numbers and returns the larger (and min returns the smaller); it is applied pair-wise at each item in the list (we ignore the fact that the built-in max can take more than two arguments):
foldr1 max [1,2,3,4] → 4
foldl1 min [1,2,3,4] → 1
Related to foldl is scan, which returns the results of applying foldl to every initial segment of a list, from the empty segment up to the whole list:
scan f a [x1, x2, ..., xn] = [a, f(a, x1), f(f(a, x1), x2), ..., f(...f(f(a, x1), x2)..., xn)]
For instance:
scan + 0 [1,2,3,4] → [0,1,3,6,10]
scan snoc [] [1,2,3,4] → [[],[1],[2,1],[3,2,1],[4,3,2,1]]
Your task is to implement all the various folds shown above; if your language provides them natively, you should re-implement them from first principles. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.
Here’s a fully tail-recursive solution in standard R7RS Scheme.
(import (scheme base) (scheme write))

;; Return a procedure that calls PROC with its arguments in reverse order.
(define (reverse-arguments proc)
  (lambda args
    (apply proc (reverse args))))

;; snoc is cons with the list first and the new item second: (snoc a x) = (cons x a).
(define snoc (reverse-arguments cons))

;; For each expression, print its source text on one line and its value on the next.
(define-syntax show/result
  (syntax-rules ()
    ((_ e ...)
     (begin
       (begin (display 'e) (newline)
              (display e) (newline))
       ...))))

;; Left fold: combine items left-to-right; F receives (accumulator item).
(define (foldl f a xs)
  (let loop ((acc a) (rest xs))
    (if (null? rest)
        acc
        (loop (f acc (car rest)) (cdr rest)))))

(show/result (foldl cons '() '(1 2 3)) (foldl snoc '() '(1 2 3)))

;; Right fold: combine items right-to-left; F receives (item accumulator).
;; Walking the reversed list keeps this tail-recursive.
(define (foldr f a xs)
  (foldl (lambda (acc x) (f x acc)) a (reverse xs)))

(show/result (foldr + 0 '(1 2 3 4))
             (foldl + 0 '(1 2 3 4))
             (foldr cons '() '(1 2 3 4))
             (foldl cons '() '(1 2 3 4))
             (foldr (lambda (x y) (+ y 1)) 0 '(1 2 3 4))
             (foldl snoc '() '(1 2 3 4)))

;; foldl1: the first item seeds the accumulator; the rest are folded left.
(define (foldl1 f xs)
  (foldl f (car xs) (cdr xs)))

;; foldr1: the last item seeds the accumulator; the rest are folded right.
(define (foldr1 f xs)
  (let ((backwards (reverse xs)))
    (foldl (lambda (acc x) (f x acc)) (car backwards) (cdr backwards))))

(show/result (foldr1 max '(1 2 3 4))
             (foldl1 min '(1 2 3 4))
             (foldr1 cons '(1 2 3 4))
             (foldl1 cons '(1 2 3 4)))

;; scan: the list of every intermediate left-fold accumulator, starting with A.
(define (scan f a xs)
  (let loop ((rest xs) (acc a) (seen (list a)))
    (if (null? rest)
        (reverse seen)
        (let ((next (f acc (car rest))))
          (loop (cdr rest) next (cons next seen))))))

(show/result (scan + 0 '(1 2 3 4)) (scan snoc '() '(1 2 3 4)))
Here’s a solution in C.
#include <stdio.h> #include <stdlib.h> #include <string.h> void foldr(void (*function)(const void* x, void* accumulatorp), void* accumulatorp, void* array, size_t nel, size_t width) { for (size_t i = 0; i < nel; ++i) { char* p = (char*)array; function(p + width * (nel - 1 - i), accumulatorp); } } void foldl(void (*function)(void* accumulatorp, const void* x), void* accumulatorp, void* array, size_t nel, size_t width) { for (size_t i = 0; i < nel; ++i) { char* p = (char*)array; function(accumulatorp, p + width * i); } } void foldr1(void (*function)(const void* x, void* accumulatorp), void* accumulatorp, void* array, size_t nel, size_t width) { memcpy(accumulatorp, array, width); foldr(function, accumulatorp, array, nel, width); } void foldl1(void (*function)(void* accumulatorp, const void* x), void* accumulatorp, void* array, size_t nel, size_t width) { memcpy(accumulatorp, array, width); foldl(function, accumulatorp, array, nel, width); } void scan(void (*function)(void* accumulatorp, const void* x), void* accumulatorp, void* input, size_t input_width, void* output, size_t output_width, size_t nel) { char* pin = (char*)input; char* pout = (char*)output; memcpy(pout, accumulatorp, output_width); for (size_t i = 0; i < nel; ++i) { function(accumulatorp, pin + input_width * i); memcpy(pout + output_width * (i + 1), accumulatorp, output_width); } } void addr(const void* x, void* accumulatorp) { *(int*)accumulatorp += *(int*)x; } void addl(void* accumulatorp, const void* x) { *(int*)accumulatorp += *(int*)x; } void appendr(const void* x, void* accumulatorp) { **((int**)accumulatorp) = *(int*)x; (*(int**)accumulatorp)++; } void appendl(void* accumulatorp, const void* x) { appendr(x, accumulatorp); } void plusoner(const void* x, void* accumulatorp) { (void)x; (*(int*)accumulatorp)++; } void maxr(const void* x, void* accumulatorp) { if (*(int*)x > *(int*)accumulatorp) { *(int*)accumulatorp = *(int*)x; } } void minl(void* accumulatorp, const void* x) { if (*(int*)x < 
*(int*)accumulatorp) { *(int*)accumulatorp = *(int*)x; } } void print_array(int* array, size_t nel) { printf("{"); for (size_t i = 0; i < nel; ++i) { if (i > 0) printf(","); printf("%d", array[i]); } printf("}"); } int main(void) { int array[] = {1,2,3,4}; size_t nel = sizeof(array) / sizeof(int); { int accumulator = 0; foldr(addr, &accumulator, array, nel, sizeof(int)); printf("foldr add 0 {1,2,3,4}\n "); printf("%d\n", accumulator); } { int* accumulator_base = alloca(sizeof(int) * nel); int* accumulator = accumulator_base; foldr(appendr, &accumulator_base, array, nel, sizeof(int)); printf("foldr append {...} {1,2,3,4}\n "); print_array(accumulator, nel); printf("\n"); } { int* accumulator_base = alloca(sizeof(int) * nel); int* accumulator = accumulator_base; foldl(appendl, &accumulator_base, array, nel, sizeof(int)); printf("foldl append {...} {1,2,3,4}\n "); print_array(accumulator, nel); printf("\n"); } { int accumulator = 0; foldr(plusoner, &accumulator, array, nel, sizeof(int)); printf("foldr plusone 0 {1,2,3,4}\n "); printf("%d\n", accumulator); } { int accumulator; foldr1(maxr, &accumulator, array, nel, sizeof(int)); printf("foldr1 max {1,2,3,4}\n "); printf("%d\n", accumulator); } { int accumulator; foldl1(minl, &accumulator, array, nel, sizeof(int)); printf("foldl1 min {1,2,3,4}\n "); printf("%d\n", accumulator); } { int accumulator = 0; int output[nel+1]; scan(addl, &accumulator, array, sizeof(int), output, sizeof(int), nel); printf("scan add {1,2,3,4}\n "); print_array(output, nel+1); printf("\n"); } return 0; }Output:
foldr add 0 {1,2,3,4}
 10
foldr append {...} {1,2,3,4}
 {4,3,2,1}
foldl append {...} {1,2,3,4}
 {1,2,3,4}
foldr plusone 0 {1,2,3,4}
 4
foldr1 max {1,2,3,4}
 4
foldl1 min {1,2,3,4}
 1
scan add {1,2,3,4}
 {0,1,3,6,10}
Folds make sense for any algebraic data type, so we can fold over trees, for example:
Implementation of the various derived folds is left as an exercise.
Here’s one that takes a two-argument function:
Here’s some Haskell…
import Prelude hiding (foldl, foldl1, foldr, foldr1, scanl, scanr, all, any, map, maximum, minimum)
import qualified Data.List

-- | Right fold: foldr f e [x1, ..., xn] == x1 `f` (x2 `f` (... (xn `f` e))).
foldr :: (a -> b -> b) -> b -> [a] -> b
foldr _ e []     = e
foldr f e (x:xs) = x `f` foldr f e xs

-- | Right fold without an initial value; the LAST element seeds the fold:
-- foldr1 f [x1, ..., xn] == x1 `f` (x2 `f` (... (x(n-1) `f` xn))).
-- (Seeding with the first element instead gives wrong answers for
-- non-commutative functions such as (-).)
foldr1 :: (a -> a -> a) -> [a] -> a
foldr1 _ []     = error "called foldr1 with empty list"
foldr1 _ [x]    = x
foldr1 f (x:xs) = x `f` foldr1 f xs

-- | Right scan: every intermediate result of foldr, ending with e.
-- The lazy pattern (~) lets the head be produced without first forcing the
-- scan of the tail, as Data.List.scanr does.  The accumulator is never
-- empty, so the pattern cannot fail.
scanr :: (a -> b -> b) -> b -> [a] -> [b]
scanr f e = foldr (\x ~(y:ys) -> x `f` y : y : ys) [e]

-- | Left fold: foldl f e [x1, ..., xn] == (((e `f` x1) `f` x2) ...) `f` xn.
foldl :: (b -> a -> b) -> b -> [a] -> b
foldl _ e []     = e
foldl f e (x:xs) = foldl f (e `f` x) xs

-- | Left fold without an initial value; the FIRST element seeds the fold.
foldl1 :: (a -> a -> a) -> [a] -> a
foldl1 _ []     = error "called foldl1 with empty list"
foldl1 f (x:xs) = foldl f x xs

-- | Left scan: every intermediate result of foldl, starting with e.
scanl :: (b -> a -> b) -> b -> [a] -> [b]
scanl _ e []     = [e]
scanl f e (x:xs) = e : scanl f (e `f` x) xs

--------------------------------------------------------------------------------
-- Here are various familiar functions implemented using our folds and scans.
--
-- Note that they may not be the most efficient implementations in Haskell,
-- especially those based on left folds.

map :: (a -> b) -> [a] -> [b]
map f = foldr (\x xs -> f x : xs) []

maximum :: Ord a => [a] -> a
maximum = foldr1 max

minimum :: Ord a => [a] -> a
minimum = foldl1 min

any :: (a -> Bool) -> [a] -> Bool
any p = foldr (\x b -> p x || b) False

all :: (a -> Bool) -> [a] -> Bool
all p = foldl (\b x -> p x && b) True

inits :: [a] -> [[a]]
inits = scanl (\xs x -> xs ++ [x]) []

tails :: [a] -> [[a]]
tails = scanr (:) []

--------------------------------------------------------------------------------

-- | Run two fold-like functions on the same arguments and report whether
-- their results agree.
test :: (Eq d, Show d) => (a -> b -> c -> d) -> (a -> b -> c -> d) -> a -> b -> c -> IO ()
test f1 f2 g e xs =
  let r1 = f1 g e xs
      r2 = f2 g e xs
  in putStrLn $ show r1 ++ " == " ++ show r2 ++ " ? " ++ show (r1 == r2)

main :: IO ()
main = do
  let xs = [1..5] :: [Int]

  -- Check that our functions give the same results as the standard ones.
  test Data.List.foldr foldr (+) 0 xs                    -- sum
  test Data.List.foldr foldr (*) 1 xs                    -- product
  test Data.List.foldr foldr (-) 3 xs
  test Data.List.foldr foldr (:) [] xs                   -- id
  test Data.List.foldr foldr (\y ys -> ys ++ [y]) [] xs  -- reverse
  test Data.List.foldl foldl (+) 0 xs                    -- sum
  test Data.List.foldl foldl (*) 1 xs                    -- product
  test Data.List.foldl foldl (-) 3 xs
  test Data.List.foldl foldl (\ys y -> ys ++ [y]) [] xs  -- id
  test Data.List.foldl foldl (flip (:)) [] xs            -- reverse

  -- The 1-variants must also agree for a non-commutative operator such as
  -- (-); this catches seeding foldr1 from the wrong end of the list.
  print $ foldr1 (-) xs == Data.List.foldr1 (-) xs
  print $ foldl1 (-) xs == Data.List.foldl1 (-) xs

  -- Exercise our versions of some common functions.
  print $ map (+2) xs
  print $ maximum xs
  print $ minimum xs
  print $ any (> 3) xs
  print $ all (> 3) xs
  print $ inits xs
  print $ tails xs