Folds

February 2, 2018

Folds, in the parlance of functional programming, are a way to convert lists to a value of some other type; a fold applies a function pair-wise to each element of a list and an accumulator, then returns the accumulator when the list is exhausted. The fundamental fold is foldr:

foldr f a [x1, x2, ..., xn] = f x1 (f x2 (... (f xn a)...))

Here, f is a function with two arguments, a is an initial value, and [x1, x2, ..., xn] is the input list. The name foldr stands for “fold right”, because the parentheses stack on the right side of the expansion, the items in the list are processed right-to-left, and the accumulator is on the right side of the binary function. Foldl is similar:

foldl f a [x1, x2, ..., xn] = (...((f a x1) x2) ... xn)

The arguments have the same meaning, with “fold left” referring to the fact that the parentheses stack on the left, the items in the list are processed left-to-right, and the accumulator is on the left side of the binary function. Note that foldl and foldr have different types, because the binary function takes its arguments in opposite order. In some cases, that makes a difference; for instance, when f is cons, you must use foldr. But when the function is both associative and commutative, such as +, you can use either foldl or foldr. Here are some examples:

foldr + 0 [1,2,3,4] → 10
foldl + 0 [1,2,3,4] → 10
foldr cons [] [1,2,3,4] → [1,2,3,4]
foldl cons [] [1,2,3,4] → [[[[[],1],2],3],4]
foldr plusone 0 [1,2,3,4] → 4
foldl snoc [] [1,2,3,4] → [4,3,2,1]

Sometimes there is no obvious starting value. For instance, if you want to find the maximum item in a list, there is no “guaranteed to be less than anything else” value to use for a. In that case you can use the foldl1 and foldr1 variants, which take the initial value from the list itself: foldl1 starts with the first item and foldr1 with the last. Here, max is a binary function that takes two numbers and returns the larger; it is applied pair-wise at each item in the list (we ignore the fact that the built-in max can take more than two arguments):

foldr1 max [1,2,3,4] → 4
foldl1 min [1,2,3,4] → 1

Related to foldl is scan, which applies foldl to every initial segment of a list:

scan f a [x1, x2, ..., xn] = [a, f(a, x1), f(f(a, x1), x2), ..., f(...f(f(a, x1), x2)..., xn)]

For instance:

scan + 0 [1,2,3,4] → [0,1,3,6,10]
scan snoc [] [1,2,3,4] → [[], [1], [2,1],[3,2,1],[4,3,2,1]]

Your task is to implement all the various folds shown above; if your language provides them natively, you should re-implement them from first principles. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.

Advertisement

Pages: 1 2

5 Responses to “Folds”

  1. chaw said

    Here’s a fully tail-recursive solution in standard R7RS Scheme.

    (import (scheme base)
            (scheme write))
    
    ;; Wrap PROC in a variadic procedure that hands its arguments to PROC
    ;; in the opposite order from the one in which they were received.
    (define (reverse-arguments proc)
      (lambda received
        (let ((flipped (reverse received)))
          (apply proc flipped))))
    
    ;; cons with its two arguments swapped: (snoc tail head) => (head . tail).
    (define (snoc tail head)
      (cons head tail))
    
    ;; (show/result e ...)
    ;; For each expression e, in order: display the quoted source form of e,
    ;; a newline, then the value of e, and another newline.
    (define-syntax show/result
      (syntax-rules ()
        ;; The inner begin groups the four output steps for one expression;
        ;; the trailing ellipsis repeats that group once per argument.
        ((_ e ...) (begin (begin (display 'e)
                                 (newline)
                                 (display e)
                                 (newline))
                          ...))))
    
    ;; Left fold: (foldl f a '(x1 x2 ... xn)) computes
    ;; (f (... (f (f a x1) x2) ...) xn), and returns a for the empty list.
    ;; The loop is a tail call, so it runs in constant stack space.
    (define (foldl f a xs)
      (let walk ((acc a)
                 (rest xs))
        (cond ((null? rest) acc)
              (else (walk (f acc (car rest))
                          (cdr rest))))))
    
    ;; Demo: with plain cons the left fold nests pairs to the left; with snoc
    ;; (cons with swapped arguments) it builds the reversed list.
    (show/result (foldl cons '() '(1 2 3)))
    (show/result (foldl snoc '() '(1 2 3)))
    
    ;; Right fold: (foldr f a '(x1 x2 ... xn)) computes
    ;; (f x1 (f x2 (... (f xn a) ...))), and returns a for the empty list.
    ;; The list is reversed up front so the fold itself is a plain loop that
    ;; consumes elements from xn down to x1.
    (define (foldr f a xs)
      (do ((pending (reverse xs) (cdr pending))
           (acc a (f (car pending) acc)))
          ((null? pending) acc)))
    
    ;; Demo: the examples from the article, foldr and foldl side by side.
    (show/result (foldr + 0 '(1 2 3 4))
                 (foldl + 0 '(1 2 3 4)))
    (show/result (foldr cons '() '(1 2 3 4))
                 (foldl cons '() '(1 2 3 4)))
    ;; The lambda ignores each element and just counts it ("plusone").
    (show/result (foldr (lambda (x y) (+ y 1)) 0 '(1 2 3 4))
                 (foldl snoc '() '(1 2 3 4)))
    
    ;; Left fold with no explicit initial value: the first element of xs is
    ;; the seed and the remaining elements are folded in left to right.
    ;; xs must be non-empty (car of the empty list is an error).
    (define (foldl1 f xs)
      (let ((seed (car xs))
            (others (cdr xs)))
        (foldl f seed others)))
    
    ;; Right fold with no explicit initial value: the last element of xs is
    ;; the seed, and the others are folded in from right to left, giving
    ;; (f x1 (f x2 (... (f xn-1 xn)))). xs must be non-empty.
    (define (foldr1 f xs)
      (let ((backwards (reverse xs)))
        (do ((pending (cdr backwards) (cdr pending))
             (acc (car backwards) (f (car pending) acc)))
            ((null? pending) acc))))
    
    ;; Demo: the seedless folds, first with max/min and then with cons to
    ;; make the nesting order visible.
    (show/result (foldr1 max '(1 2 3 4))
                 (foldl1 min '(1 2 3 4)))
    (show/result (foldr1 cons '(1 2 3 4))
                 (foldl1 cons '(1 2 3 4)))
    
    ;; Running left fold: returns the list
    ;; (a (f a x1) (f (f a x1) x2) ...), one entry per prefix of xs.
    ;; Results are pushed onto the front of a list (so the newest is always
    ;; the car) and the list is reversed once at the end.
    (define (scan f a xs)
      (do ((pending xs (cdr pending))
           (results (list a)
                    (cons (f (car results) (car pending))
                          results)))
          ((null? pending) (reverse results))))
    
    ;; Demo: running sums, then the successively longer reversed prefixes.
    (show/result (scan + 0 '(1 2 3 4)))
    (show/result (scan snoc '() '(1 2 3 4)))
    

  2. Daniel said

    Here’s a solution in C.

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    
    void foldr(void (*function)(const void* x, void* accumulatorp),
               void* accumulatorp,
               void* array,
               size_t nel,
               size_t width) {
      for (size_t i = 0; i < nel; ++i) {
        char* p = (char*)array;
        function(p + width * (nel - 1 - i), accumulatorp);
      }
    }
    
    void foldl(void (*function)(void* accumulatorp, const void* x),
               void* accumulatorp,
               void* array,
               size_t nel,
               size_t width) {
      for (size_t i = 0; i < nel; ++i) {
        char* p = (char*)array;
        function(accumulatorp, p + width * i);
      }
    }
    
    void foldr1(void (*function)(const void* x, void* accumulatorp),
                void* accumulatorp,
                void* array,
                size_t nel,
                size_t width) {
      memcpy(accumulatorp, array, width);
      foldr(function, accumulatorp, array, nel, width);
    }
    
    void foldl1(void (*function)(void* accumulatorp, const void* x),
                void* accumulatorp,
                void* array,
                size_t nel,
                size_t width) {
      memcpy(accumulatorp, array, width);
      foldl(function, accumulatorp, array, nel, width);
    }
    
    void scan(void (*function)(void* accumulatorp, const void* x),
              void* accumulatorp,
              void* input,
              size_t input_width,
              void* output,
              size_t output_width,
              size_t nel) {
      char* pin = (char*)input;
      char* pout = (char*)output;
      memcpy(pout, accumulatorp, output_width);
      for (size_t i = 0; i < nel; ++i) {
        function(accumulatorp, pin + input_width * i);
        memcpy(pout + output_width * (i + 1), accumulatorp, output_width);
      }
    }
    
    /* foldr callback: adds the int at x to the int accumulator. */
    void addr(const void* x, void* accumulatorp) {
      int* total = (int*)accumulatorp;
      const int* term = (const int*)x;
      *total = *total + *term;
    }
    
    /* foldl/scan callback: adds the int at x to the int accumulator. */
    void addl(void* accumulatorp, const void* x) {
      int* total = (int*)accumulatorp;
      const int* term = (const int*)x;
      *total = *total + *term;
    }
    
    /*
     * foldr callback: the accumulator is a write cursor (an int* stored at
     * accumulatorp). Stores the int at x where the cursor points, then
     * advances the cursor by one element.
     */
    void appendr(const void* x, void* accumulatorp) {
      int** cursor = (int**)accumulatorp;
      int* slot = *cursor;
      *slot = *(const int*)x;
      *cursor = slot + 1;
    }
    
    /*
     * foldl callback: same effect as appendr with the arguments in foldl
     * order. Stores the int at x where the cursor (an int* stored at
     * accumulatorp) points, then advances the cursor by one element.
     */
    void appendl(void* accumulatorp, const void* x) {
      int** cursor = (int**)accumulatorp;
      int* slot = *cursor;
      *slot = *(const int*)x;
      *cursor = slot + 1;
    }
    
    /* foldr callback: counts elements; the element value itself is ignored. */
    void plusoner(const void* x, void* accumulatorp) {
      int* count = (int*)accumulatorp;
      (void)x;
      *count += 1;
    }
    
    /* foldr callback: keeps the larger of the int at x and the accumulator. */
    void maxr(const void* x, void* accumulatorp) {
      const int candidate = *(const int*)x;
      int* best = (int*)accumulatorp;
      if (candidate <= *best) {
        return;
      }
      *best = candidate;
    }
    
    /* foldl callback: keeps the smaller of the int at x and the accumulator. */
    void minl(void* accumulatorp, const void* x) {
      const int candidate = *(const int*)x;
      int* best = (int*)accumulatorp;
      if (candidate >= *best) {
        return;
      }
      *best = candidate;
    }
    
    /*
     * Prints the nel ints of array to stdout as "{a,b,c}" with no trailing
     * newline. An empty array prints as "{}".
     */
    void print_array(int* array, size_t nel) {
      printf("{");
      if (nel > 0) {
        /* First element has no separator in front of it. */
        printf("%d", array[0]);
        for (size_t i = 1; i < nel; ++i) {
          printf(",");
          printf("%d", array[i]);
        }
      }
      printf("}");
    }
    
    /*
     * Demo driver: runs each fold over {1,2,3,4} and prints a label followed
     * by the result.
     */
    int main(void) {
      int array[] = {1,2,3,4};
      /*
       * Element count as a compile-time constant, so the scratch buffers
       * below are ordinary fixed-size arrays. This avoids alloca(), which is
       * not part of standard C and is not guaranteed to be declared by
       * <stdlib.h>, and it avoids variable-length arrays.
       */
      enum { NEL = sizeof(array) / sizeof(array[0]) };
      size_t nel = NEL;
      {
        int accumulator = 0;
        foldr(addr, &accumulator, array, nel, sizeof(int));
        printf("foldr add 0 {1,2,3,4}\n  ");
        printf("%d\n", accumulator);
      }
      {
        /* The accumulator is the write cursor; the buffer start stays put. */
        int collected[NEL];
        int* cursor = collected;
        foldr(appendr, &cursor, array, nel, sizeof(int));
        printf("foldr append {...} {1,2,3,4}\n  ");
        print_array(collected, nel);
        printf("\n");
      }
      {
        int collected[NEL];
        int* cursor = collected;
        foldl(appendl, &cursor, array, nel, sizeof(int));
        printf("foldl append {...} {1,2,3,4}\n  ");
        print_array(collected, nel);
        printf("\n");
      }
      {
        int accumulator = 0;
        foldr(plusoner, &accumulator, array, nel, sizeof(int));
        printf("foldr plusone 0 {1,2,3,4}\n  ");
        printf("%d\n", accumulator);
      }
      {
        /* No initial value: foldr1 fills the accumulator from the array. */
        int accumulator;
        foldr1(maxr, &accumulator, array, nel, sizeof(int));
        printf("foldr1 max {1,2,3,4}\n  ");
        printf("%d\n", accumulator);
      }
      {
        /* No initial value: foldl1 fills the accumulator from the array. */
        int accumulator;
        foldl1(minl, &accumulator, array, nel, sizeof(int));
        printf("foldl1 min {1,2,3,4}\n  ");
        printf("%d\n", accumulator);
      }
      {
        /* scan emits the initial value plus one entry per element. */
        int accumulator = 0;
        int output[NEL + 1];
        scan(addl, &accumulator, array, sizeof(int), output, sizeof(int), nel);
        printf("scan add {1,2,3,4}\n  ");
        print_array(output, nel+1);
        printf("\n");
      }
      return 0;
    }
    

    Output:

    foldr add 0 {1,2,3,4}
      10
    foldr append {...} {1,2,3,4}
      {4,3,2,1}
    foldl append {...} {1,2,3,4}
      {1,2,3,4}
    foldr plusone 0 {1,2,3,4}
      4
    foldr1 max {1,2,3,4}
      4
    foldl1 min {1,2,3,4}
      1
    scan add {1,2,3,4}
      {0,1,3,6,10}
    
  3. matthew said

    Folds make sense for any algebraic data type, so we can fold over trees, for example:

    -- | Binary tree: either empty ('Nil') or a node holding a left subtree,
    --   a value, and a right subtree.
    data Tree a
      = Nil
      | Tree (Tree a) a (Tree a)
      deriving Show
    -- | Tree fold: replaces every 'Nil' with @x@ and every node with the
    --   three-argument function @f@, applied to the folded left subtree,
    --   the node value, and the folded right subtree.
    fold f x tree = case tree of
      Nil              -> x
      Tree left v right -> f (fold f x left) v (fold f x right)
    
    -- | Rebuilds the tree node for node, giving a structural copy.
    dup = fold rebuild Nil where
      rebuild left v right = Tree left v right
    -- | Mirror image of the tree: left and right subtrees are swapped at
    --   every node.
    rev = fold mirror Nil where
      mirror left v right = Tree right v left
    -- | Rearranges the tree so that every left subtree is 'Nil', keeping the
    --   in-order sequence of values.
    flatten = fold graft Nil where
      -- Hang @v@ and @right@ at the far right end of the left tree.
      graft left v right = case left of
        Nil             -> Tree Nil v right
        Tree l1 v1 r1   -> Tree l1 v1 (graft r1 v right)
    -- | Sum of all values in the tree; an empty tree sums to 0.
    tsum = fold (\left v right -> left + v + right) 0
    
    -- | Sample tree whose in-order traversal is 1, 2, 3, 4.
    t :: Tree Integer
    t = Tree (leaf 1) 2 (Tree (leaf 3) 4 Nil)
      where
        leaf v = Tree Nil v Nil
    
    -- | Prints the sum, the copy, the mirror image and the flattened form of
    --   the sample tree, one per line.
    main = do
      print (tsum t)
      print (dup t)
      print (rev t)
      print (flatten t)
      return ()
    

    Implementation of the various derived folds is left as an exercise.

  4. matthew said

    Here’s one that takes a two-argument function:

    -- | Tree fold with a two-argument function. The accumulator is threaded
    --   through the left subtree, then combined with the node value, then
    --   threaded through the right subtree.
    mfold :: (a -> b -> b) -> b -> Tree a -> b
    mfold f x tree = case tree of
      Nil              -> x
      Tree left v right ->
        let afterLeft = mfold f x left
            afterNode = f v afterLeft
        in  mfold f afterNode right
    -- add up elements, again
    tsum = mfold plus 0 where
      plus v total = v + total
    -- Values in in-order sequence, reversed: each value is consed onto the
    -- front of what has been collected so far.
    tolist = mfold (\v collected -> v : collected) []
    
  5. Globules said

    Here’s some Haskell…

    import Prelude hiding (foldl, foldl1, foldr, foldr1, scanl, scanr,
                           all, any, map, maximum, minimum)
    import qualified Data.List
    
    -- | Right fold: @foldr f e [x1, ..., xn] = x1 `f` (x2 `f` (... (xn `f` e)))@.
    --   Returns @e@ for the empty list.
    foldr :: (a -> b -> b) -> b -> [a] -> b
    foldr f e = go
      where
        go []     = e
        go (x:xs) = f x (go xs)
    
    -- | Right fold seeded from the list itself:
    --   @foldr1 f [x1, ..., xn] = x1 `f` (x2 `f` (... (x(n-1) `f` xn)))@.
    --   The LAST element is the initial accumulator, so the result is also
    --   correct for functions that are not commutative.
    --   Calls 'error' on the empty list.
    foldr1 :: (a -> a -> a) -> [a] -> a
    foldr1 _ []     = error "called foldr1 with empty list"
    foldr1 _ [x]    = x
    foldr1 f (x:xs) = x `f` foldr1 f xs
    
    -- | Right scan: the results of 'foldr' over every suffix of the list,
    --   longest suffix first; the last entry is @e@.
    scanr :: (a -> b -> b) -> b -> [a] -> [b]
    scanr f e = foldr step [e]
      where
        -- The head of the accumulated list is the fold of the suffix so far.
        step x acc@(y:_) = f x y : acc
    
    -- | Left fold: @foldl f e [x1, ..., xn] = (... ((e `f` x1) `f` x2) ...) `f` xn@.
    --   Returns @e@ for the empty list.
    foldl :: (b -> a -> b) -> b -> [a] -> b
    foldl f = go
      where
        go acc []     = acc
        go acc (x:xs) = go (f acc x) xs
    
    -- | Left fold seeded with the first element of the list.
    --   Calls 'error' on the empty list.
    foldl1 :: (a -> a -> a) -> [a] -> a
    foldl1 f list = case list of
      []            -> error "called foldl1 with empty list"
      (seed : rest) -> foldl f seed rest
    
    -- | Left scan: @e@ followed by the result of 'foldl' over each
    --   successively longer prefix of the list.
    scanl :: (b -> a -> b) -> b -> [a] -> [b]
    scanl f e list = case list of
      []       -> [e]
      (x : xs) -> e : scanl f (f e x) xs
    
    --------------------------------------------------------------------------------
    
    -- Here are various familiar functions implemented using our folds and scans.
    -- 
    -- Note that they may not be the most efficient implementations in Haskell,
    -- especially those based on left folds.
    
    -- | Applies @f@ to every element: a right fold that conses each
    --   transformed element onto the already-mapped tail.
    map :: (a -> b) -> [a] -> [b]
    map f = foldr ((:) . f) []
    
    -- | Largest element of a non-empty list, via 'foldr1' with 'max'.
    maximum :: Ord a => [a] -> a
    maximum xs = foldr1 max xs
    
    -- | Smallest element of a non-empty list, via 'foldl1' with 'min'.
    minimum :: Ord a => [a] -> a
    minimum xs = foldl1 min xs
    
    -- | True when at least one element satisfies @p@; False for the empty
    --   list. Built on a right fold with @||@.
    any :: (a -> Bool) -> [a] -> Bool
    any p = foldr ((||) . p) False
    
    -- | True when every element satisfies @p@; True for the empty list.
    --   Built on a left fold with @&&@.
    all :: (a -> Bool) -> [a] -> Bool
    all p = foldl step True
      where
        step soFar x = p x && soFar
    
    -- | All prefixes of a list, shortest first, starting with the empty
    --   list. Each step of the left scan appends one more element.
    inits :: [a] -> [[a]]
    inits = scanl extend []
      where
        extend prefix x = prefix ++ [x]
    
    -- | All suffixes of a list, longest first, ending with the empty list.
    tails :: [a] -> [[a]]
    tails = scanr (\x suffix -> x : suffix) []
    
    --------------------------------------------------------------------------------
    
    -- | Applies two implementations @f1@ and @f2@ to the same three
    --   arguments and prints both results followed by whether they are equal.
    test :: (Eq d, Show d) => (a -> b -> c -> d) ->
                              (a -> b -> c -> d) ->
                              a -> b -> c ->
                              IO ()
    test f1 f2 g e xs = putStrLn report
      where
        r1     = f1 g e xs
        r2     = f2 g e xs
        report = concat [show r1, " == ", show r2, " ?  ", show (r1 == r2)]
    
    -- | Compares the local folds with the Data.List versions, then prints
    --   the results of the functions built on top of the local folds.
    main :: IO ()
    main = do
      let xs = [1..5] :: [Int]

      -- Check that our functions give the same results as the standard ones.

      test Data.List.foldr foldr (+) 0 xs                   -- sum
      test Data.List.foldr foldr (*) 1 xs                   -- product
      test Data.List.foldr foldr (-) 3 xs
      test Data.List.foldr foldr (:) [] xs                  -- id
      test Data.List.foldr foldr (\y ys -> ys ++ [y]) [] xs -- reverse

      test Data.List.foldl foldl (+) 0 xs                   -- sum
      test Data.List.foldl foldl (*) 1 xs                   -- product
      test Data.List.foldl foldl (-) 3 xs
      test Data.List.foldl foldl (\ys y -> ys ++ [y]) [] xs -- id
      test Data.List.foldl foldl (flip (:)) [] xs           -- reverse

      -- Exercise our versions of some common functions.

      print (map (+2) xs)
      mapM_ print [maximum xs, minimum xs]
      mapM_ print [any (> 3) xs, all (> 3) xs]
      mapM_ print [inits xs, tails xs]
    
    $ ./folds 
    15 == 15 ?  True
    120 == 120 ?  True
    0 == 0 ?  True
    [1,2,3,4,5] == [1,2,3,4,5] ?  True
    [5,4,3,2,1] == [5,4,3,2,1] ?  True
    15 == 15 ?  True
    120 == 120 ?  True
    -12 == -12 ?  True
    [1,2,3,4,5] == [1,2,3,4,5] ?  True
    [5,4,3,2,1] == [5,4,3,2,1] ?  True
    [3,4,5,6,7]
    5
    1
    True
    False
    [[],[1],[1,2],[1,2,3],[1,2,3,4],[1,2,3,4,5]]
    [[1,2,3,4,5],[2,3,4,5],[3,4,5],[4,5],[5],[]]
    

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s

%d bloggers like this: