Symbolic Differentiation

March 24, 2017

One of the original motivations for the Lisp language was to write a program that performed symbolic differentiation. In this exercise we look at the symbolic differentiator in Section 2.3.2 of SICP (Structure and Interpretation of Computer Programs), which handles expressions containing only addition and multiplication according to the following rules:

\frac{dc}{dx} = 0, where c is a constant or a variable other than x,

\frac{dx}{dx} = 1,

\frac{d(u+v)}{dx} = \frac{du}{dx} + \frac{dv}{dx}, and

\frac{d(uv)}{dx} = u\frac{dv}{dx} + v\frac{du}{dx}.

Your task is to write a program that performs symbolic differentiation according to the rules given above. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.

Advertisement

Pages: 1 2

3 Responses to “Symbolic Differentiation”

  1. Jussi Piitulainen said

    I used to know Prolog. And I used to be curious about Mercury, but now I digress – the following is in standard Prolog, tested with variations on a polynomial expression in GNU Prolog.

    % dee(+Var, +Expr, -Deriv): Deriv is the (unsimplified) derivative of
    % Expr with respect to the atom Var.  Handles +, binary and unary -, and
    % *; any other atomic term is treated as a constant.  No cuts are used,
    % so the predicate also behaves in checking mode: dee(x, x, 0) fails.

    dee(Var, Var, 1).
    dee(Var, Term, 0) :- atomic(Term), Var \= Term.

    dee(Var, A + B, DA + DB) :- aux(Var, A, B, DA, DB).
    dee(Var, A - B, DA - DB) :- aux(Var, A, B, DA, DB).

    dee(Var, A * B, A * DB + B * DA) :- aux(Var, A, B, DA, DB).

    dee(Var, -A, -DA) :- dee(Var, A, DA).

    % aux(+Var, +A, +B, -DA, -DB): differentiate both operands of a binary
    % node with respect to Var.
    aux(Var, A, B, DA, DB) :-
        dee(Var, A, DA),
        dee(Var, B, DB).

    % No simplifier is included, so results keep terms such as 1-0 and x*0
    % (a simplifier is the obvious next step :-)
    %
    % | ?- dee(x, (x + c) * (x - 1), D),
    %      dee(x, x * (x - 1) + c * (x - 1), E),
    %      dee(x, x * x - x + c * x - c, F).
    %
    % D = (x+c)*(1-0)+(x-1)*(1+0)
    % E = x*(1-0)+(x-1)*1+(c*(1-0)+(x-1)*0)
    % F = x*1+x*1-1+(c*1+x*0)-0 ? ;
    %
    % no
    
  2. Globules said

    Here’s a Haskell version. Perhaps the most notable thing is making (Expr a) an instance of the Num typeclass, which allows us to create expressions using the standard arithmetic operators. I was playing around with simplifying expressions to try to clean up the output a bit. E.g. to avoid things like “… 1*(0-x) …”, etc. Since that’s mostly boilerplate I put it at the end. If I were starting over I might be tempted to get rid of subtraction, take advantage of associativity and commutativity to have Add and Mul take lists of expressions and add Pow so that consecutive variables, v, could be written as v^n. Then simplification would convert an expression to a sum of multiplications, sort each component to gather together like variables and their powers, sort the components, add like components, etc.

    -- | Abstract syntax of arithmetic expressions whose constants have type @a@.
    data Expr a
      = Const a               -- ^ a constant number
      | Neg (Expr a)          -- ^ negation of an expression
      | Add (Expr a) (Expr a) -- ^ addition
      | Sub (Expr a) (Expr a) -- ^ subtraction
      | Mul (Expr a) (Expr a) -- ^ multiplication
      | Var String            -- ^ a named variable
      deriving Eq
    
    -- Differentiate an expression with respect to a named variable:
    -- @d "x" expr@ is d(expr)/dx.  The negation, sum, difference and product
    -- rules are applied structurally; the result is not simplified.
    d :: Num a => String -> Expr a -> Expr a
    d v = go
      where
        go expr = case expr of
          Const _ -> 0
          Neg e   -> negate (go e)
          Add p q -> go p + go q
          Sub p q -> go p - go q
          Mul p q -> p * go q + q * go p
          Var w
            | w == v    -> 1
            | otherwise -> 0
    
    -- Arithmetic on expressions builds syntax trees rather than computing
    -- values, so ordinary operators (and derived ones such as '^') construct
    -- 'Expr' nodes.  'abs' and 'signum' have no corresponding constructor.
    instance Num a => Num (Expr a) where
      fromInteger n = Const (fromInteger n)
      negate        = Neg
      l + r         = Add l r
      l - r         = Sub l r
      l * r         = Mul l r
      abs    = error "abs not implemented!"
      signum = error "signum not implemented!"
    
    -- Render an expression as text.  Every binary operation is wrapped in
    -- parentheses so the output is unambiguous without precedence rules;
    -- negation is a bare leading "-".  The output is assembled from 'ShowS'
    -- pieces, so nested subexpressions are not re-copied at every level.
    instance Show a => Show (Expr a) where
      showsPrec _ expr = case expr of
          Const c -> shows c
          Neg x   -> showChar '-' . shows x
          Add x y -> bracket '+' x y
          Sub x y -> bracket '-' x y
          Mul x y -> bracket '*' x y
          Var v   -> showString v
        where
          bracket op x y =
            showChar '(' . shows x . showChar op . shows y . showChar ')'
    
    -- Print one line of the form "d/dVAR EXPR = RESULT", where RESULT is the
    -- simplified output of the supplied differentiation function.
    demo :: String -> (String -> Expr Int -> Expr Int) -> Expr Int -> IO ()
    demo var dfn expr =
      putStrLn (concat ["d/d", var, " ", show expr, " = ", show derivative])
      where
        derivative = simplify (dfn var expr)
    
    main :: IO ()
    main = do
      let x = Var "x"
          y = Var "y"
      -- Simple cases, each differentiated with respect to x, in order.
      mapM_ (demo "x" d) [5, x*x, y, x + y, y - x, x * y]

      -- Because Exprs are Nums we get exponentiation for free.
      let e = x*9 - 5*x + x*y + x^3 + (5-x^2)*y^2
      demo "x" d e
    
    -------------------------------------------------------------------------------
    
    -- A grab bag of expression simplification rules, applied in one pass.
    -- Equations are tried in order, so more specific patterns must come
    -- first.  'simplify' repeats this pass until nothing changes.
    --
    -- Numeric literal patterns such as @0@ and @1@ are compared with '(==)'
    -- against @fromInteger n@, i.e. they match @Const 0@ / @Const 1@.  The
    -- negative literal pattern @(-1)@ is compared against
    -- @negate (fromInteger 1)@, which for 'Expr' is @Neg (Const 1)@; it does
    -- not match @Const (-1)@.
    s :: (Eq a, Num a) => Expr a -> Expr a
    s (Neg             (Neg x)) = s x
    -- Descend into any other negation.  Without this equation a negated
    -- subexpression fell through to the catch-all @s x = x@ and was never
    -- simplified (e.g. @Neg (Add (Mul x 1) (Mul x 1))@ stayed as is).
    s (Neg                  x ) = negate (s x)
    s (Add        0         x ) = s x
    s (Add        x         0 ) = s x
    s (Add (Const x) (Const y)) = Const (x + y)
    s (Add        x    (Neg y)) = s x - s y
    s (Add        x         y ) = s x + s y
    s (Sub        0         x ) = negate (s x)
    s (Sub        x         0 ) = s x
    s (Sub (Const x) (Const y)) = Const (x - y)
    s (Sub        x         y ) = s x - s y
    s (Mul        0         _ ) = 0
    s (Mul        _         0 ) = 0
    s (Mul        1         x ) = s x
    s (Mul      (-1)        x ) = negate (s x)
    s (Mul        x         1 ) = s x
    s (Mul        x       (-1)) = negate (s x)
    s (Mul (Const x) (Const y)) = Const (x * y)
    s (Mul        x  (Neg   y)) = negate (s x * s y)
    s (Mul   (Neg x)        y ) = negate (s x * s y)
    s (Mul        x         y ) = s x * s y
    -- Only constants and variables reach this point.
    s x = x
    
    -- Apply the rule set 's' repeatedly and return the first expression that
    -- a further pass leaves unchanged (a fixed point of 's').
    simplify :: (Eq a, Num a) => Expr a -> Expr a
    simplify e
      | e' == e   = e
      | otherwise = simplify e'
      where
        e' = s e
    
    $ ./symdiff 
    d/dx 5 = 0
    d/dx (x*x) = (x+x)
    d/dx y = 0
    d/dx (x+y) = 1
    d/dx (y-x) = -1
    d/dx (x*y) = y
    d/dx (((((x*9)-(5*x))+(x*y))+((x*x)*x))+((5-(x*x))*(y*y))) = (((4+y)+((x*x)+(x*(x+x))))-((y*y)*(x+x)))
    
  3. john said

    Two solutions: Prolog and C11.

    Here is the Prolog solution:


    % d(+Var, +Expr, ?Deriv): Deriv is the (unsimplified) derivative of the
    % ground expression Expr with respect to the atom Var.  Handles *, +,
    % binary and unary -, and atomic terms (Var itself gives 1, anything else
    % gives 0).  Unsupported compound terms make d/3 fail instead of wrongly
    % yielding 0.
    %
    % Each clause commits with a cut *before* Deriv is unified.  Binding the
    % output in the head, ahead of a trailing cut, let calls such as
    % d(x, x, 0) or d(x, x*x, 0) fall through to the constant clause and
    % succeed.
    d(X, Y * Z, D) :-
        !,
        d(X, Y, DY),
        d(X, Z, DZ),
        D = Y * DZ + Z * DY.

    d(X, Y + Z, D) :-
        !,
        d(X, Y, DY),
        d(X, Z, DZ),
        D = DY + DZ.

    d(X, Y - Z, D) :-
        !,
        d(X, Y, DY),
        d(X, Z, DZ),
        D = DY - DZ.

    d(X, -Y, D) :-
        !,
        d(X, Y, DY),
        D = -DY.

    d(X, X, D) :-
        !,
        D = 1.

    d(_, Y, 0) :-
        atomic(Y).

    An example:


    ?- d(x, x * (x - 1) + c * (x - 1), D).
    D = x* (1-0)+ (x-1)*1+ (c* (1-0)+ (x-1)*0).

    (Example polynomial taken from Jussi Piitulainen’s example.)

    And here is the C11 solution:


    #include <iso646.h>
    #include <stdbool.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    #include "gc.h"

    /* Expression tree node: an operator with two children, or a leaf. */
    typedef struct tree {
      char *value;   /* operator name (+ or *) or leaf text */
      struct tree *l; /* left operand; NULL for leaves */
      struct tree *r; /* right operand; NULL for leaves */
    } tree_s;

    /* Allocate (with GC_MALLOC) a "+" node whose operands are l and r. */
    tree_s *add(tree_s *l, tree_s *r) {
      tree_s *node = GC_MALLOC(sizeof *node);
      *node = (tree_s){ .value = "+", .l = l, .r = r };
      return node;
    }

    /* Allocate (with GC_MALLOC) a "*" node whose operands are l and r. */
    tree_s *mult(tree_s *l, tree_s *r) {
      tree_s *node = GC_MALLOC(sizeof *node);
      *node = (tree_s){ .value = "*", .l = l, .r = r };
      return node;
    }

    #define $(x) literal(x)
    tree_s *literal(char *value) {
      tree_s *result = GC_MALLOC(sizeof(tree_s));

      result->value = value;
      result->l = NULL;
      result->r = NULL;

      return result;
    }

    /*
     * Differentiate expression tree t with respect to the variable named x.
     * "+" nodes use the sum rule and "*" nodes the product rule.  A leaf equal
     * to x yields "1"; every other leaf yields "0".  The result reuses t's
     * operand nodes directly in the product-rule terms (they are not copied).
     */
    tree_s *d(tree_s *t, char *x) {
      const char *op = t->value;

      if (strcmp(op, "*") == 0) {
        tree_s *left_term = mult(t->l, d(t->r, x));
        tree_s *right_term = mult(t->r, d(t->l, x));
        return add(left_term, right_term);
      }
      if (strcmp(op, "+") == 0) {
        return add(d(t->l, x), d(t->r, x));
      }
      return literal(strcmp(op, x) == 0 ? "1" : "0");
    }

    /*
     * Print expression tree t to stdout in fully parenthesised infix form,
     * e.g. "(x + y)".  Leaves are printed verbatim; no newline is written.
     */
    void print_tree(tree_s *t) {
      const char *op = NULL;

      if (strcmp(t->value, "+") == 0) {
        op = "+";
      } else if (strcmp(t->value, "*") == 0) {
        op = "*";
      }

      if (op == NULL) {
        printf("%s", t->value);
        return;
      }

      printf("(");
      print_tree(t->l);
      printf(" %s ", op);
      print_tree(t->r);
      printf(")");
    }

    /*
     * Structural equality.  True when x and y are both NULL, or both non-NULL
     * with equal value strings and pairwise structurally equal children.
     */
    bool equal(tree_s *x, tree_s *y) {
      if (x == NULL or y == NULL) {
        return x == y;
      }
      if (strcmp(x->value, y->value) != 0) {
        return false;
      }
      return equal(x->l, y->l) and equal(x->r, y->r);
    }

    /*
     * Return a simplified version of expression tree x.  The identities
     *   0 + e = e + 0 = e,  0 * e = e * 0 = 0,  1 * e = e * 1 = e
     * are applied bottom-up.  Returns NULL for a NULL input.
     *
     * The input tree is never modified.  Nodes are shared between trees:
     * d() reuses operand nodes in the derivative, and main() reuses exp
     * several times when building its sample.  Rewriting x->l / x->r in place
     * would therefore silently change every other tree that points at the
     * same node.  A node is copied only when one of its children changed.
     */
    tree_s *simplify(tree_s *x) {
    #define match(s1, s2) (strcmp(s1, s2) == 0)
      if (x == NULL) {
        return NULL;
      }

      tree_s *l = simplify(x->l);
      tree_s *r = simplify(x->r);

      if (match(x->value, "+")) {
        if (match(l->value, "0")) { return r; }
        if (match(r->value, "0")) { return l; }
      } else if (match(x->value, "*")) {
        if (match(l->value, "0") or match(r->value, "0")) { return $("0"); }
        if (match(l->value, "1")) { return r; }
        if (match(r->value, "1")) { return l; }
      }

      if (l == x->l and r == x->r) {
        return x; /* nothing changed below: reuse the original node */
      }

      tree_s *copy = $(x->value); /* fresh node carrying the same operator */
      copy->l = l;
      copy->r = r;
      return copy;
    #undef match
    }

    /* Build a fresh copy of (x + y) * (x * 2 + x * (y * z)). */
    static tree_s *sample_term(void) {
      tree_s *sum = add($("x"), $("y"));
      tree_s *poly = add(mult($("x"), $("2")),
                         mult($("x"), mult($("y"), $("z"))));
      return mult(sum, poly);
    }

    /*
     * Build a sample expression, print it after simplify(), then print its
     * simplified derivative with respect to x.
     */
    int main(int argc, char **argv) {
      GC_INIT();

      char *wrt = "x";
      tree_s *expr = add(sample_term(), sample_term());
      expr = add(expr, mult(expr, expr));

      puts("The derivative of\n");
      print_tree(simplify(expr));
      printf("\n\nwith respect to %s is\n\n", wrt);
      print_tree(simplify(d(expr, wrt)));
      puts("\n");

      return 0;
    }

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. ( Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. ( Log Out /  Change )

Connecting to %s

%d bloggers like this: