Symbolic Differentiation
March 24, 2017
One of the original motivations for the Lisp language was to write a program that performed symbolic differentiation. In this exercise we look at the symbolic differentiator in Section 2.3.2 of SICP, which handles expressions containing only addition and multiplication according to the following rules:
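$\frac{dc}{dx} = 0$, with $c$ a constant or a variable different from $x$,
$\frac{dx}{dx} = 1$,
$\frac{d(u+v)}{dx} = \frac{du}{dx} + \frac{dv}{dx}$, and
$\frac{d(uv)}{dx} = u\,\frac{dv}{dx} + v\,\frac{du}{dx}$.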
Your task is to write a program that performs symbolic differentiation according to the rules given above. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.
I used to know Prolog, and I used to be curious about Mercury, but I digress: the following is in standard Prolog, tested with variations on a polynomial expression in GNU Prolog.
Here’s a Haskell version. Perhaps the most notable thing is making (Expr a) an instance of the Num typeclass, which allows us to create expressions using the standard arithmetic operators. I was playing around with simplifying expressions to clean up the output a bit, e.g. to avoid things like “… 1*(0-x) …”. Since that’s mostly boilerplate, I put it at the end. If I were starting over, I might be tempted to get rid of subtraction, take advantage of associativity and commutativity so that Add and Mul take lists of expressions, and add Pow so that repeated occurrences of a variable v could be written as v^n. Simplification would then convert an expression to a sum of products, sort each product to gather like variables and their powers, sort the terms, combine like terms, etc.
Two solutions: Prolog and C11.
Here is the Prolog solution:
% d(+X, +Expr, ?D): D is the derivative of Expr with respect to the atom X.
%
% Handles sums, differences and products. Any other subterm (a number,
% a constant, or a variable other than X) has derivative 0.
%
% Every clause commits with a cut *before* binding D. This keeps the
% predicate steadfast: a call with D already bound, such as d(x, x, 0),
% fails instead of falling through to the catch-all clause and wrongly
% succeeding.
d(X, Y * Z, D) :-
    !,
    d(X, Y, DY),
    d(X, Z, DZ),
    D = Y * DZ + Z * DY.
d(X, Y + Z, D) :-
    !,
    d(X, Y, DY),
    d(X, Z, DZ),
    D = DY + DZ.
d(X, Y - Z, D) :-
    !,
    d(X, Y, DY),
    d(X, Z, DZ),
    D = DY - DZ.
d(X, X, D) :-
    !,
    D = 1.
d(_, _, 0).
An example:
?- d(x, x * (x - 1) + c * (x - 1), D).
D = x* (1-0)+ (x-1)*1+ (c* (1-0)+ (x-1)*0).
(Example polynomial taken from Jussi Piitulainen’s example.)
And here is the C11 solution:
#include <iso646.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "gc.h"
/*
 * Node of an expression tree.
 *
 * Operator nodes have value "+" or "*" and both children set. Leaf nodes
 * hold a variable name or numeric constant as text and have l == r == NULL.
 * The text is never written through or freed by this program, so value is
 * const-qualified to make that guarantee checkable by the compiler.
 */
typedef struct tree {
    const char *value;
    struct tree *l; /* left operand; NULL for a leaf */
    struct tree *r; /* right operand; NULL for a leaf */
} tree_s;
/*
 * Allocate an addition node representing l + r.
 *
 * Memory comes from GC_MALLOC (gc.h); nothing in this file frees it.
 * The children are stored by pointer, not copied, so one subtree may be
 * shared by several parents. Prints a message to stderr and aborts if the
 * allocator returns NULL, instead of dereferencing a null pointer.
 */
tree_s *add(tree_s *l, tree_s *r) {
    tree_s *result = GC_MALLOC(sizeof *result);
    if (result == NULL) {
        fputs("add: out of memory\n", stderr);
        abort();
    }
    result->value = "+";
    result->l = l;
    result->r = r;
    return result;
}
/*
 * Allocate a multiplication node representing l * r.
 *
 * Memory comes from GC_MALLOC (gc.h); nothing in this file frees it.
 * The children are stored by pointer, not copied, so one subtree may be
 * shared by several parents. Prints a message to stderr and aborts if the
 * allocator returns NULL, instead of dereferencing a null pointer.
 */
tree_s *mult(tree_s *l, tree_s *r) {
    tree_s *result = GC_MALLOC(sizeof *result);
    if (result == NULL) {
        fputs("mult: out of memory\n", stderr);
        abort();
    }
    result->value = "*";
    result->l = l;
    result->r = r;
    return result;
}
/*
 * Shorthand for building leaves: $("x") expands to literal("x").
 * Note: '$' in identifiers is implementation-defined, not portable ISO C.
 * GCC and Clang accept it.
 */
#define $(x) literal(x)

/*
 * Allocate a leaf node holding value (a variable name or numeric constant).
 *
 * value is stored by pointer, not copied, so it must outlive the tree.
 * Every caller in this file passes a string literal. Memory comes from
 * GC_MALLOC (gc.h). Prints a message to stderr and aborts if the allocator
 * returns NULL.
 */
tree_s *literal(char *value) {
    tree_s *result = GC_MALLOC(sizeof *result);
    if (result == NULL) {
        fputs("literal: out of memory\n", stderr);
        abort();
    }
    result->value = value;
    result->l = NULL;
    result->r = NULL;
    return result;
}
/*
 * Differentiate t with respect to the variable named x.
 *
 * Rules: d(u + v) = du + dv; d(u * v) = u * dv + v * du. A leaf equal to x
 * gives "1", and any other leaf gives "0". The result is a freshly built
 * tree, but it refers to the operand subtrees u and v of t directly rather
 * than copying them.
 */
tree_s *d(tree_s *t, char *x) {
    const char *op = t->value;

    if (strcmp(op, "+") == 0) {
        tree_s *du = d(t->l, x);
        tree_s *dv = d(t->r, x);
        return add(du, dv);
    }

    if (strcmp(op, "*") == 0) {
        tree_s *u = t->l;
        tree_s *v = t->r;
        tree_s *u_dv = mult(u, d(v, x));
        tree_s *v_du = mult(v, d(u, x));
        return add(u_dv, v_du);
    }

    /* Leaf: the variable itself differentiates to 1, anything else to 0. */
    return literal(strcmp(op, x) == 0 ? "1" : "0");
}
/*
 * Write t to stdout in fully parenthesised infix form.
 * For example, add(x, mult(y, 2)) prints as "(x + (y * 2))".
 * Leaves are printed as their text. No trailing newline is written.
 */
void print_tree(tree_s *t) {
    bool is_operator = strcmp(t->value, "+") == 0
                       or strcmp(t->value, "*") == 0;

    if (not is_operator) {
        printf("%s", t->value);
        return;
    }

    printf("(");
    print_tree(t->l);
    printf(" %s ", t->value);
    print_tree(t->r);
    printf(")");
}
/*
 * Structural equality: true when x and y have the same shape and the same
 * text at every node. Two NULL trees are equal; a NULL tree never equals a
 * non-NULL one.
 */
bool equal(tree_s *x, tree_s *y) {
    if (x == NULL or y == NULL) {
        /* Equal only if both are NULL. */
        return x == y;
    }
    if (strcmp(x->value, y->value) != 0) {
        return false;
    }
    return equal(x->l, y->l) and equal(x->r, y->r);
}
tree_s *simplify(tree_s *x) {
#define match(s1, s2) (strcmp(s1, s2) == 0)
if (x == NULL) {
return NULL;
}
x->l = simplify(x->l);
x->r = simplify(x->r);
if (match(x->value, "+")) {
if (match(x->l->value, "0")) { return x->r; }
else if (match(x->r->value, "0")) { return x->l; }
} else if (match(x->value, "*")) {
if (match(x->l->value, "0")) { return $("0"); }
else if (match(x->r->value, "0")) { return $("0"); }
else if (match(x->l->value, "1")) { return x->r; }
else if (match(x->r->value, "1")) { return x->l; }
}
return x;
#undef match
}
/* Build a fresh, unshared copy of (x + y) * (x * 2 + x * (y * z)). */
static tree_s *sample_product(void) {
    return mult(add($("x"), $("y")),
                add(mult($("x"), $("2")),
                    mult($("x"), mult($("y"), $("z")))));
}

/*
 * Demo driver. Let P = (x + y) * (x * 2 + x * (y * z)) and E = P + P.
 * The program prints E + E * E (simplified) and its simplified derivative.
 *
 * Usage: prog [VAR]
 * Differentiates with respect to VAR, or "x" when no argument is given.
 */
int main(int argc, char **argv) {
    GC_INIT();
    char *var = argc > 1 ? argv[1] : "x";

    /* Two independent copies of P, as in the original hand-written tree.
       Named expr rather than exp to avoid shadowing the C library exp(). */
    tree_s *expr = add(sample_product(), sample_product());
    /* E + E * E: all three references point at the same subtree. */
    expr = add(expr, mult(expr, expr));

    puts("The derivative of\n");
    print_tree(simplify(expr));
    printf("\n\nwith respect to %s is\n\n", var);
    print_tree(simplify(d(expr, var)));
    puts("\n");
    return 0;
}