Symbolic Differentiation
March 24, 2017
We copy our solution directly from SICP, changing only the call to the `error` function in `deriv`, whose calling convention in Chez Scheme differs from the MIT Scheme dialect used in SICP:
;; Symbolic differentiation, after SICP section 2.3.2.
;; An expression is a number, a symbol (a variable), or a prefix list
;; (+ a b) or (* a b).

;; Variables are represented as plain symbols.
(define (variable? x) (symbol? x))

;; True when both arguments are variables and they are the same symbol.
(define (same-variable? v1 v2)
  (and (variable? v1)
       (variable? v2)
       (eq? v1 v2)))

;; True when expr is a number numerically equal to num.
(define (=number? expr num)
  (and (number? expr)
       (= expr num)))

;; Build (+ a1 a2), dropping a zero addend and folding two numbers.
(define (make-sum a1 a2)
  (cond ((=number? a1 0) a2)
        ((=number? a2 0) a1)
        ((and (number? a1) (number? a2)) (+ a1 a2))
        (else (list '+ a1 a2))))

;; Build (* m1 m2), collapsing a zero factor to 0, dropping a unit factor,
;; and folding two numbers.
(define (make-product m1 m2)
  (cond ((or (=number? m1 0) (=number? m2 0)) 0)
        ((=number? m1 1) m2)
        ((=number? m2 1) m1)
        ((and (number? m1) (number? m2)) (* m1 m2))
        (else (list '* m1 m2))))

;; Sum selectors.
(define (sum? x)
  (and (pair? x) (eq? (car x) '+)))
(define (addend s) (cadr s))
(define (augend s) (caddr s))

;; Product selectors.
(define (product? x)
  (and (pair? x) (eq? (car x) '*)))
(define (multiplier p) (cadr p))
(define (multiplicand p) (caddr p))

;; Derivative of expr with respect to the symbol var, simplified by
;; make-sum / make-product.  Signals a Chez-style error (who = deriv)
;; for any other kind of expression.
(define (deriv expr var)
  (cond ((number? expr) 0)
        ((variable? expr)
         (if (same-variable? expr var) 1 0))
        ((sum? expr)
         (make-sum (deriv (addend expr) var)
                   (deriv (augend expr) var)))
        ((product? expr)
         ;; Product rule: (u v)' = u v' + u' v
         (let ((u (multiplier expr))
               (v (multiplicand expr)))
           (make-sum (make-product u (deriv v var))
                     (make-product (deriv u var) v))))
        (else (error 'deriv "unknown expression type"))))
This is the version of the code that produces “simplified” results. I would show you my version of the code, but it’s ugly, as I stuffed pretty much everything into a single function.
Here are some examples:
> (deriv '(+ x 3) 'x)
1
> (deriv '(* x y) 'x)
y
> (deriv '(* (* x y) (+ x 3)) 'x)
(+ (* x y) (* y (+ x 3)))
> (deriv '(log x) 'x)
Exception in deriv: unknown expression type
You can run the program at http://ideone.com/Ac80Vb. You might be interested in the exercises at SICP, which extend the solution given above.
I used to know Prolog, and I used to be curious about Mercury, but I digress – the following is in standard Prolog, tested with variations on a polynomial expression in GNU Prolog.
Here’s a Haskell version. Perhaps the most notable thing is making (Expr a) an instance of the Num typeclass, which allows us to create expressions using the standard arithmetic operators. I was playing around with simplifying expressions to clean up the output a bit, e.g. to avoid things like “… 1*(0-x) …”. Since that’s mostly boilerplate, I put it at the end. If I were starting over, I might be tempted to get rid of subtraction, take advantage of associativity and commutativity to have Add and Mul take lists of expressions, and add Pow so that consecutive occurrences of a variable v could be written as v^n. Simplification would then convert an expression to a sum of products, sort each product to gather together like variables and their powers, sort the terms, combine like terms, and so on.
Two solutions: Prolog and C11.
Here is the Prolog solution:
% d(+X, +Expr, ?D): D is the (unsimplified) derivative of Expr with
% respect to the atom X.  Expr is built from atomic terms with the binary
% operators *, + and -.  Any other compound term (sin(x), x^2, x/2, -x, ...)
% makes d/3 fail instead of silently claiming its derivative is 0.
%
% Each cut sits right after the head match, so once a clause's operator
% matches, no later clause is tried.  When the cuts came after the
% recursive calls, a call with D already bound to a wrong value, e.g.
% d(x, x, 0) or d(x, x*x, 0), fell through to the catch-all clause and
% wrongly succeeded.

% Product rule: (Y*Z)' = Y*Z' + Z*Y'.
d(X, Y * Z, Y * DZ + Z * DY) :-
    !,
    d(X, Y, DY),
    d(X, Z, DZ).
% Sum rule.
d(X, Y + Z, DY + DZ) :-
    !,
    d(X, Y, DY),
    d(X, Z, DZ).
% Difference rule.
d(X, Y - Z, DY - DZ) :-
    !,
    d(X, Y, DY),
    d(X, Z, DZ).
% The variable itself.  ==/2 compares without binding anything in Expr.
d(X, Y, 1) :-
    Y == X,
    !.
% Any other atomic term (a constant such as c or 3, or another variable
% such as y) has derivative 0.
d(X, Y, 0) :-
    atomic(Y),
    Y \== X.
An example:
?- d(x, x * (x - 1) + c * (x - 1), D).
D = x* (1-0)+ (x-1)*1+ (c* (1-0)+ (x-1)*0).
(Example polynomial taken from Jussi Piitulainen’s example.)
And here is the C11 solution:
#include <iso646.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "gc.h"
/*
 * Node of an expression tree.  Interior nodes carry an operator label
 * ("+" or "*") in value and both children in l and r.  Leaves carry a
 * variable name or numeric literal in value and have l == r == NULL.
 *
 * value is const-qualified because the constructors in this file store
 * string literals there, and writing through a pointer to a string
 * literal is undefined behaviour.
 */
typedef struct tree {
    const char *value;
    struct tree *l, *r;
} tree_s;
/*
 * Allocate a garbage-collected "+" node whose children are l and r.
 * The children are linked, not copied, so subtrees may be shared.
 *
 * GC_MALLOC can return NULL when the collector cannot satisfy the
 * request.  That is treated as fatal here rather than dereferenced.
 */
tree_s *add(tree_s *l, tree_s *r) {
    tree_s *result = GC_MALLOC(sizeof *result);
    if (result == NULL) {
        fputs("add: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    result->value = "+";
    result->l = l;
    result->r = r;
    return result;
}
/*
 * Allocate a garbage-collected "*" node whose children are l and r.
 * The children are linked, not copied, so subtrees may be shared.
 *
 * GC_MALLOC can return NULL when the collector cannot satisfy the
 * request.  That is treated as fatal here rather than dereferenced.
 */
tree_s *mult(tree_s *l, tree_s *r) {
    tree_s *result = GC_MALLOC(sizeof *result);
    if (result == NULL) {
        fputs("mult: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    result->value = "*";
    result->l = l;
    result->r = r;
    return result;
}
/*
 * $(x): shorthand for literal(x), so expression trees can be written
 * compactly, e.g. add($("x"), $("1")).
 * NOTE: '$' is not in ISO C's basic identifier character set.  Accepting
 * it is implementation-defined (GCC and Clang allow it on most targets),
 * so this macro is not strictly portable.
 */
#define $(x) literal(x)
/*
 * Allocate a garbage-collected leaf node labelled value (a variable name
 * or numeric literal).  The string is stored, not copied, so it must
 * outlive the node.  String literals, as used throughout this file,
 * satisfy that.
 *
 * GC_MALLOC can return NULL when the collector cannot satisfy the
 * request.  That is treated as fatal here rather than dereferenced.
 */
tree_s *literal(char *value) {
    tree_s *result = GC_MALLOC(sizeof *result);
    if (result == NULL) {
        fputs("literal: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    result->value = value;
    result->l = NULL;
    result->r = NULL;
    return result;
}
/*
 * Differentiate the expression tree t with respect to the variable named x:
 *
 *   d(l * r) = l * d(r) + r * d(l)
 *   d(l + r) = d(l) + d(r)
 *   d(x)     = 1
 *   anything else = 0
 *
 * The result is not simplified.  Product terms reuse t->l and t->r
 * directly, so the result shares those subtrees with t.
 */
tree_s *d(tree_s *t, char *x) {
    const char *label = t->value;

    if (strcmp(label, "*") == 0) {
        tree_s *left_term = mult(t->l, d(t->r, x));
        tree_s *right_term = mult(t->r, d(t->l, x));
        return add(left_term, right_term);
    }

    if (strcmp(label, "+") == 0) {
        tree_s *dl = d(t->l, x);
        tree_s *dr = d(t->r, x);
        return add(dl, dr);
    }

    /* Leaf (or unknown label): 1 for the variable itself, 0 otherwise. */
    return literal(strcmp(label, x) == 0 ? "1" : "0");
}
/*
 * Write t to stdout in fully parenthesised infix form, e.g. "(x + (y * 2))".
 * Nodes labelled "+" or "*" print as "(<left> <op> <right>)".  Every other
 * node prints its label verbatim.  No trailing newline is written.
 */
void print_tree(tree_s *t) {
    const char *op = NULL;

    if (strcmp(t->value, "+") == 0) {
        op = "+";
    } else if (strcmp(t->value, "*") == 0) {
        op = "*";
    }

    if (op == NULL) {
        fputs(t->value, stdout);
        return;
    }

    putchar('(');
    print_tree(t->l);
    printf(" %s ", op);
    print_tree(t->r);
    putchar(')');
}
bool equal(tree_s *x, tree_s *y) {
if (x == NULL) {
return y == NULL;
} else if (y == NULL) {
return false;
} else {
return strcmp(x->value, y->value) == 0
and equal(x->l, y->l)
and equal(x->r, y->r);
}
}
/*
 * Return a simplified version of the expression tree x.  The identities
 * below are applied bottom-up (children first, then the node itself):
 *
 *   0 + e -> e        e + 0 -> e
 *   0 * e -> 0        e * 0 -> 0
 *   1 * e -> e        e * 1 -> e
 *
 * Returns NULL for a NULL tree.  No other rewriting (e.g. constant
 * folding) is performed.
 *
 * The input is NOT modified.  d() returns trees that share subtrees with
 * the expression being differentiated, and main() builds exp as a DAG
 * (exp = add(exp, mult(exp, exp))).  Rewriting child pointers in place
 * would silently alter every other tree that shares those nodes.  A node
 * whose children come back unchanged is returned as-is.  Otherwise a
 * fresh copy with the new children is allocated, and the program exits
 * if that allocation fails.
 */
tree_s *simplify(tree_s *x) {
/* True when node t exists and its label is exactly s. */
#define is_label(t, s) ((t) != NULL && strcmp((t)->value, (s)) == 0)
    if (x == NULL) {
        return NULL;
    }

    tree_s *l = simplify(x->l);
    tree_s *r = simplify(x->r);

    /* Only a node with both children is an operator; this also avoids
     * dereferencing a missing child of a malformed "+" or "*" leaf. */
    if (l != NULL && r != NULL) {
        if (is_label(x, "+")) {
            if (is_label(l, "0")) { return r; }
            if (is_label(r, "0")) { return l; }
        } else if (is_label(x, "*")) {
            if (is_label(l, "0") || is_label(r, "0")) { return literal("0"); }
            if (is_label(l, "1")) { return r; }
            if (is_label(r, "1")) { return l; }
        }
    }

    if (l == x->l && r == x->r) {
        return x;               /* nothing below changed: reuse the node */
    }

    tree_s *copy = GC_MALLOC(sizeof *copy);
    if (copy == NULL) {
        fputs("simplify: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    *copy = *x;                 /* keep the label, replace the children */
    copy->l = l;
    copy->r = r;
    return copy;
#undef is_label
}
/* Build a fresh copy of (x + y) * ((x * 2) + (x * (y * z))). */
static tree_s *sample_term(void) {
    tree_s *x_plus_y = add(literal("x"), literal("y"));
    tree_s *x_times_2 = mult(literal("x"), literal("2"));
    tree_s *x_times_yz = mult(literal("x"),
                              mult(literal("y"), literal("z")));
    return mult(x_plus_y, add(x_times_2, x_times_yz));
}

/*
 * Demo driver: build a sample expression in x, y and z, print it
 * (simplified), then print its simplified derivative with respect to x.
 * Command-line arguments are ignored.
 */
int main(int argc, char **argv) {
    (void)argc;
    (void)argv;

    GC_INIT();                  /* must precede any collector allocation */
    char *var = "x";

    /* Two independent copies of the same term, then e + e * e. */
    tree_s *expr = add(sample_term(), sample_term());
    expr = add(expr, mult(expr, expr));

    puts("The derivative of\n");
    print_tree(simplify(expr));
    printf("\n\nwith respect to %s is\n\n", var);
    print_tree(simplify(d(expr, var)));
    puts("\n");
    return 0;
}