Symbolic Differentiation

March 24, 2017

We copy our solution directly from SICP, changing only the call to the error function in deriv, for which Chez Scheme differs from the MIT Scheme dialect used in SICP:

;; A variable is represented directly by a Scheme symbol.
(define variable?
  (lambda (x) (symbol? x)))

;; True when v1 and v2 are both variables (symbols) and are the same symbol;
;; #f otherwise.
(define (same-variable? v1 v2)
  (if (and (variable? v1) (variable? v2))
      (eq? v1 v2)
      #f))

;; True when exp is a number numerically equal to num; #f for non-numbers.
(define (=number? exp num)
  (if (number? exp)
      (= exp num)
      #f))

;; Build the sum a1 + a2 with light simplification: a zero operand is
;; dropped, two numeric operands are added, otherwise a (+ a1 a2) list is
;; returned.  The tests are tried in this order, so a zero first operand
;; returns a2 unchanged.
(define (make-sum a1 a2)
  (let ((first-zero? (=number? a1 0))
        (second-zero? (=number? a2 0))
        (both-numeric? (and (number? a1) (number? a2))))
    (cond (first-zero? a2)
          (second-zero? a1)
          (both-numeric? (+ a1 a2))
          (else `(+ ,a1 ,a2)))))

;; Build the product m1 * m2 with light simplification: a zero operand
;; gives 0, a unit operand is dropped, two numeric operands are multiplied,
;; otherwise a (* m1 m2) list is returned.
(define (make-product m1 m2)
  (let ((both-numeric? (and (number? m1) (number? m2))))
    (cond ((or (=number? m1 0) (=number? m2 0)) 0)
          ((=number? m1 1) m2)
          ((=number? m2 1) m1)
          (both-numeric? (* m1 m2))
          (else `(* ,m1 ,m2)))))

;; A sum is a list whose first element is the symbol +.
(define (sum? x)
  (if (pair? x)
      (eq? '+ (car x))
      #f))

;; The first operand of a sum: the element right after the + symbol.
(define (addend s)
  (car (cdr s)))

;; The second operand of a sum.  For a binary sum (+ a b) this is b.  A sum
;; with more operands, such as (+ x y z), is read as x + (y + z), so its
;; augend is the sum of all remaining operands, (+ y z).  (The plain caddr
;; version would silently drop z; this is the extension of SICP Exercise 2.57.)
(define (augend s)
  (let ((rest (cddr s)))
    (if (null? (cdr rest))
        (car rest)
        (cons '+ rest))))

;; A product is a list whose first element is the symbol *.
(define (product? x)
  (if (pair? x)
      (eq? '* (car x))
      #f))

;; The first operand of a product: the element right after the * symbol.
(define (multiplier p)
  (car (cdr p)))

;; The second operand of a product.  For a binary product (* a b) this is b.
;; A product with more operands, such as (* x y z), is read as x * (y * z),
;; so its multiplicand is the product of all remaining operands, (* y z).
;; (The plain caddr version would silently drop z; cf. SICP Exercise 2.57.)
(define (multiplicand p)
  (let ((rest (cddr p)))
    (if (null? (cdr rest))
        (car rest)
        (cons '* rest))))

;; Symbolic derivative of exp with respect to the variable var.
;; Numbers differentiate to 0, a variable to 1 or 0, sums by the sum rule
;; and products by the product rule; make-sum and make-product apply the
;; light simplifications.  Any other expression signals an error.
(define (deriv exp var)
  (cond ((number? exp) 0)
        ((variable? exp) (if (same-variable? exp var) 1 0))
        ((sum? exp)
         (let ((u (addend exp))
               (v (augend exp)))
           (make-sum (deriv u var) (deriv v var))))
        ((product? exp)
         (let ((u (multiplier exp))
               (v (multiplicand exp)))
           ;; (u v)' = u v' + u' v
           (make-sum (make-product u (deriv v var))
                     (make-product (deriv u var) v))))
        (else (error 'deriv "unknown expression type"))))

This is the version of the code that produces “simplified” results. I would show you my version of the code, but it’s ugly, as I stuffed pretty much everything into a single function.

Here are some examples:

> (deriv '(+ x 3) 'x)
1
> (deriv '(* x y) 'x)
y
> (deriv '(* (* x y) (+ x 3)) 'x)
(+ (* x y) (* y (+ x 3)))
> (deriv '(log x) 'x)
Exception in deriv: unknown expression type

You can run the program at http://ideone.com/Ac80Vb. You might also be interested in the SICP exercises that extend the solution given above.

Pages: 1 2

3 Responses to “Symbolic Differentiation”

  1. Jussi Piitulainen said

    I used to know Prolog. And I used to be curious about Mercury, but now I digress – the following is in standard Prolog, tested with variations on a polynomial expression in GNU Prolog.

    % dee(+Var, +Expr, -Deriv): Deriv is the (unsimplified) derivative of
    % Expr with respect to Var.

    % Var itself differentiates to 1; any other atomic term to 0.
    dee(Var, Var, 1).
    dee(Var, Const, 0) :-
        atomic(Const),
        \+ Var = Const.

    % Sum, difference and product rules.
    dee(Var, A + B, DA + DB) :- aux(Var, A, B, DA, DB).
    dee(Var, A - B, DA - DB) :- aux(Var, A, B, DA, DB).

    dee(Var, A * B, A * DB + B * DA) :- aux(Var, A, B, DA, DB).

    % Unary negation.
    dee(Var, -A, -DA) :- dee(Var, A, DA).

    % aux(+Var, +A, +B, -DA, -DB): differentiate both operands of a binary term.
    aux(Var, A, B, DA, DB) :-
        dee(Var, A, DA),
        dee(Var, B, DB).
    
    % Rather than a larger class of symbolic expressions,
    % one might be tempted to implement a simplifier :-)
    % 
    % | ?- dee(x, (x + c) * (x - 1), D),             
    %      dee(x, x * (x - 1) + c * (x - 1), E),
    %      dee(x, x * x - x + c * x - c, F).    
    %
    % D = (x+c)*(1-0)+(x-1)*(1+0)
    % E = x*(1-0)+(x-1)*1+(c*(1-0)+(x-1)*0)
    % F = x*1+x*1-1+(c*1+x*0)-0 ? ;
    % 
    % no
    
  2. Globules said

    Here’s a Haskell version. Perhaps the most notable thing is making (Expr a) an instance of the Num typeclass, which allows us to create expressions using the standard arithmetic operators. I was playing around with simplifying expressions to try to clean up the output a bit. E.g. to avoid things like “… 1*(0-x) …”, etc. Since that’s mostly boilerplate I put it at the end. If I were starting over I might be tempted to get rid of subtraction, take advantage of associativity and commutativity to have Add and Mul take lists of expressions and add Pow so that consecutive variables, v, could be written as v^n. Then simplification would convert an expression to a sum of multiplications, sort each component to gather together like variables and their powers, sort the components, add like components, etc.

    -- A symbolic expression whose numeric constants have type a.
    data Expr a = Const a               -- a constant number
                | Neg (Expr a)          -- negating an expression
                | Add (Expr a) (Expr a) -- addition
                | Sub (Expr a) (Expr a) -- subtraction
                | Mul (Expr a) (Expr a) -- multiplication
                | Var String            -- a variable, identified by its name
               deriving (Eq)
    
    -- Derivative of an expression with respect to the variable whose name is
    -- the first argument: d "x" expr is d(expr)/dx.  The result is assembled
    -- with the Num operators and is not simplified.
    d :: Num a => String -> Expr a -> Expr a
    d v expr =
      case expr of
        Const _ -> 0
        Neg u   -> negate (go u)
        Add u w -> go u + go w
        Sub u w -> go u - go w
        Mul u w -> (u * go w) + (w * go u)   -- product rule
        Var n   -> if n == v then 1 else 0
      where
        go = d v
    
    -- A Num instance lets expressions be written with the ordinary arithmetic
    -- operators and integer literals (a literal n becomes Const n).  abs and
    -- signum have no symbolic meaning here and are left undefined.
    instance Num a => Num (Expr a) where
      p + q         = Add p q
      p * q         = Mul p q
      p - q         = Sub p q
      negate p      = Neg p
      fromInteger n = Const (fromInteger n)
      abs    = error "abs not implemented!"
      signum = error "signum not implemented!"
    
    -- Render an expression as text.  Every binary operation is wrapped in
    -- parentheses so the printed form is unambiguous; negation is a bare "-".
    instance Show a => Show (Expr a) where
      show expr = case expr of
          Const c -> show c
          Neg u   -> '-' : show u
          Add u w -> binary "+" u w
          Sub u w -> binary "-" u w
          Mul u w -> binary "*" u w
          Var v   -> v
        where
          binary op u w = concat ["(", show u, op, show w, ")"]
    
    -- Print one line "d/d<var> <expr> = <derivative>", where the derivative
    -- is computed by dfn and then simplified.
    demo :: String -> (String -> Expr Int -> Expr Int) -> Expr Int -> IO ()
    demo var dfn expr =
      let derivative = simplify (dfn var expr)
      in putStrLn (concat ["d/d", var, " ", show expr, " = ", show derivative])
    
    main :: IO ()
    main = do
      let x = Var "x"
          y = Var "y"
          -- Because Exprs are Nums we get exponentiation for free.
          e = x*9 - 5*x + x*y + x^3 + (5-x^2)*y^2
      -- Print each sample expression with its simplified derivative, in order.
      mapM_ (demo "x" d) [5, x*x, y, x + y, y - x, x * y, e]
    
    -------------------------------------------------------------------------------
    
    -- A grab bag of expression simplification rules.  One call makes a single
    -- top-down pass: the first matching rule rewrites the node and (in most
    -- rules) recurses into the children.  'simplify' repeats this until the
    -- result stops changing.
    --
    -- Literal patterns are compared with (==) against values built by the Num
    -- instance, so 0 and 1 match Const 0 and Const 1.  Without NegativeLiterals,
    -- (-1) means negate (fromInteger 1), i.e. it matches Neg (Const 1), not
    -- Const (-1).
    s :: (Eq a, Num a) => Expr a -> Expr a
    s (Neg             (Neg x)) = s x
    s (Neg                  x ) = negate (s x)  -- keep simplifying under a negation
    s (Add        0         x ) = s x
    s (Add        x         0 ) = s x
    s (Add (Const x) (Const y)) = Const (x + y)
    s (Add        x    (Neg y)) = s x - s y
    s (Add        x         y ) = s x + s y
    s (Sub        0         x ) = negate (s x)
    s (Sub        x         0 ) = s x
    s (Sub (Const x) (Const y)) = Const (x - y)
    s (Sub        x         y ) = s x - s y
    s (Mul        0         _ ) = 0
    s (Mul        _         0 ) = 0
    s (Mul        1         x ) = s x
    s (Mul      (-1)        x ) = negate (s x)
    s (Mul        x         1 ) = s x
    s (Mul        x       (-1)) = negate (s x)
    s (Mul (Const x) (Const y)) = Const (x * y)
    s (Mul        x  (Neg   y)) = negate (s x * s y)
    s (Mul   (Neg x)        y ) = negate (s x * s y)
    s (Mul        x         y ) = s x * s y
    s x = x
    
    -- Apply 's' repeatedly and return the first expression that one more pass
    -- of 's' leaves unchanged (a fixed point of 's').
    simplify :: (Eq a, Num a) => Expr a -> Expr a
    simplify e
      | e' == e   = e
      | otherwise = simplify e'
      where
        e' = s e
    
    $ ./symdiff 
    d/dx 5 = 0
    d/dx (x*x) = (x+x)
    d/dx y = 0
    d/dx (x+y) = 1
    d/dx (y-x) = -1
    d/dx (x*y) = y
    d/dx (((((x*9)-(5*x))+(x*y))+((x*x)*x))+((5-(x*x))*(y*y))) = (((4+y)+((x*x)+(x*(x+x))))-((y*y)*(x+x)))
    
  3. john said

    Two solutions: Prolog and C11.

    Here is the Prolog solution:


    % d(+X, +Expr, -Deriv): Deriv is the (unsimplified) derivative of Expr
    % with respect to the atom X.  Supports *, +, binary - and unary -.
    % Fails for any other compound term (e.g. x/2, x^2, sin(x)) rather than
    % wrongly answering 0.

    % Product rule: (Y*Z)' = Y*Z' + Z*Y'.
    d(X, Y * Z, Y * DZ + Z * DY) :-
        d(X, Y, DY),
        d(X, Z, DZ),
        !.

    d(X, Y + Z, DY + DZ) :-
        d(X, Y, DY),
        d(X, Z, DZ),
        !.

    d(X, Y - Z, DY - DZ) :-
        d(X, Y, DY),
        d(X, Z, DZ),
        !.

    % Unary minus: (-Y)' = -(Y').
    d(X, -Y, -DY) :-
        d(X, Y, DY),
        !.

    d(X, X, 1) :-
        !.

    % Any other atomic term (a constant or a different variable) has
    % derivative 0.  The atomic/1 guard stops unsupported compound terms
    % from falling through to this clause.
    d(_, Y, 0) :-
        atomic(Y),
        !.

    An example:


    ?- d(x, x * (x - 1) + c * (x - 1), D).
    D = x* (1-0)+ (x-1)*1+ (c* (1-0)+ (x-1)*0).

    (Example polynomial taken from Jussi Piitulainen’s example.)

    And here is the C11 solution:


    #include <iso646.h>
    #include <stdbool.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    #include "gc.h"

    /* A node of an expression tree.  Nodes built by add()/mult() hold the
       operator text "+" or "*" in value and their two operands in l and r;
       nodes built by literal() hold a variable name or number text in value
       and have l and r set to NULL. */
    typedef struct tree {
      char *value;         /* operator, variable name or number, as text */
      struct tree *l, *r;  /* operands; NULL for leaves */
    } tree_s;

    /* Allocate a "+" node with operands l and r (the operands are not copied). */
    tree_s *add(tree_s *l, tree_s *r) {
      tree_s *node = GC_MALLOC(sizeof *node);
      *node = (tree_s){ .value = "+", .l = l, .r = r };
      return node;
    }

    /* Allocate a "*" node with operands l and r (the operands are not copied). */
    tree_s *mult(tree_s *l, tree_s *r) {
      tree_s *node = GC_MALLOC(sizeof *node);
      *node = (tree_s){ .value = "*", .l = l, .r = r };
      return node;
    }

    #define $(x) literal(x)
    /* Allocate a leaf node holding the given text (a variable or a number);
       the string itself is referenced, not copied.  $("...") is shorthand. */
    tree_s *literal(char *value) {
      tree_s *leaf = GC_MALLOC(sizeof *leaf);
      leaf->l = leaf->r = NULL;
      leaf->value = value;
      return leaf;
    }

    /* Derivative of t with respect to the variable named x, returned as a new
       (unsimplified) tree that shares subtrees of t. */
    tree_s *d(tree_s *t, char *x) {
      /* Sum rule: (u + v)' = u' + v' */
      if (strcmp(t->value, "+") == 0)
        return add(d(t->l, x), d(t->r, x));

      /* Product rule: (u * v)' = u * v' + v * u' */
      if (strcmp(t->value, "*") == 0)
        return add(mult(t->l, d(t->r, x)), mult(t->r, d(t->l, x)));

      /* Leaf: the variable itself gives 1, anything else gives 0. */
      return literal(strcmp(t->value, x) == 0 ? "1" : "0");
    }

    /* Print t to stdout in fully parenthesized infix form, e.g. "(x + 1)". */
    void print_tree(tree_s *t) {
      const char *op = NULL;

      if (strcmp(t->value, "+") == 0) {
        op = "+";
      } else if (strcmp(t->value, "*") == 0) {
        op = "*";
      }

      if (op == NULL) {          /* leaf: just its text */
        printf("%s", t->value);
        return;
      }

      putchar('(');
      print_tree(t->l);
      printf(" %s ", op);
      print_tree(t->r);
      putchar(')');
    }

    /* Structural equality: both trees empty, or same root text with equal
       left and right subtrees. */
    bool equal(tree_s *x, tree_s *y) {
      if (x == NULL or y == NULL) {
        return x == y;            /* equal only if both are empty */
      }
      if (strcmp(x->value, y->value) != 0) {
        return false;
      }
      return equal(x->l, y->l) and equal(x->r, y->r);
    }

    /* Return a simplified version of x using x+0 = 0+x = x, x*0 = 0*x = 0 and
       x*1 = 1*x = x, applied bottom-up.  The input tree is NOT modified: a
       node whose children change is rebuilt as a new node, and unchanged
       subtrees are shared with the input.  This matters because main() prints
       simplify(exp) and afterwards differentiates the same exp tree, whose
       nodes are also shared between several parents. */
    tree_s *simplify(tree_s *x) {
    #define match(s1, s2) (strcmp(s1, s2) == 0)
      if (x == NULL) {
        return NULL;
      }

      tree_s *l = simplify(x->l);
      tree_s *r = simplify(x->r);

      if (match(x->value, "+")) {
        if (match(l->value, "0")) { return r; }
        else if (match(r->value, "0")) { return l; }
      } else if (match(x->value, "*")) {
        if (match(l->value, "0") or match(r->value, "0")) { return $("0"); }
        else if (match(l->value, "1")) { return r; }
        else if (match(r->value, "1")) { return l; }
      }

      /* Nothing changed below this node: reuse it as is. */
      if (l == x->l and r == x->r) {
        return x;
      }

      /* Rebuild this node around the simplified children instead of
         overwriting x's own fields. */
      tree_s *copy = $(x->value);
      copy->l = l;
      copy->r = r;
      return copy;
    #undef match
    }

    /* Build a fresh copy of (x + y) * ((x * 2) + (x * (y * z))). */
    static tree_s *sample_term(void) {
      return mult(add($("x"), $("y")),
                  add(mult($("x"), $("2")),
                      mult($("x"), mult($("y"), $("z")))));
    }

    int main(int argc, char **argv) {
      GC_INIT();

      char *var = "x";

      /* Sum of two separately allocated copies of the sample term. */
      tree_s *exp = add(sample_term(), sample_term());

      /* exp + exp * exp: all three operands refer to the same nodes. */
      exp = add(exp, mult(exp, exp));

      puts("The derivative of\n");
      print_tree(simplify(exp));
      printf("\n\nwith respect to %s is\n\n", var);
      print_tree(simplify(d(exp, var)));
      puts("\n");

      exit(0);
    }

Leave a comment