Symbolic differentiation in C#

As part of my (sadly neglected) side project to learn Lisp, I have been reading Structure and Interpretation of Computer Programs. I strongly recommend that all software developers read this book, even if you aren’t interested in Lisp, as the concepts it teaches are fundamental to our everyday work, even if we don’t realise it. I’m only about a third of the way through it so far, but it has already given me some excellent insights into the way I design code. OK, that assumes I design code rather than just sitting down and writing it! Maybe I should say that it has encouraged me to think about my code’s design before I start writing it!

One of the sections showed how to do symbolic differentiation in Lisp. For those not familiar with it, differentiation is a mathematical technique that, in its simplest form, lets you find the gradient of a curve at any point. It is a lot more important in maths than you might think, and is a fundamental part of any maths course. For a gentle introduction to differentiation, including an interactive differentiator, see this page.

The Lisp they wrote to differentiate expressions was quite short and elegant. You can see the sample code here. Even if you don’t understand Lisp, it’s pretty impressive how much they achieved with so little code.

Being an incurable tinkerer, I wondered how hard it would be to write such code in C#.

Design

One of the (many) nice things about Lisp is the ease with which you can build data structures. As lists are built in, they provide a very simple way of structuring your data. The authors took advantage of this and represented expressions as nested lists, where each list represents one part of the full expression.

For example, a simple expression like 2x could be represented as the single list '(* 2 x). Note that Lisp uses prefix notation, so the operator comes first, and the whole list is prefixed with a quote (') to indicate that it should not be evaluated, but treated as a literal value.

A more complex expression can be built up by inserting lists where the previous example had a symbol or value. For example, we can represent 3x + 2 as '(+ (* 3 x) 2) and so on. Once you are used to prefix notation, this is pretty easy to read, and is quite close to the mathematical representation.

I wanted to achieve the same in C# if it were possible.

To keep things manageable, I assumed that I would only handle three types of expression…

  1. An integer constant
  2. A symbol that represents a variable, eg the “x” in (2x + 2). Note that we can have multiple symbols in an expression, but will only differentiate with respect to one of them
  3. A combination of two other expressions, using a binary operator, eg (1 + 2) or (exp1 * exp2) where exp1 and exp2 are previously defined expressions

For simplicity, I also assumed that all numbers involved are integers, and so only supported *, + and -, as / could produce non-integers. I later relaxed this slightly, as I realised that I would also need to support ^ to avoid unpleasantly complex-looking expressions.

My initial thought was to have an empty DifExpInterface (DifExp being short for “differentiable expression”), and define three classes that implemented it. This would allow me to write methods that could handle any of the three sub-types, and switch on the type to handle each one differently. This worked fine, but gave some really ugly-looking code when I tried to construct expressions, as I needed to include the types when creating new objects.
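For comparison, here is a rough sketch of what that first, interface-based design might have looked like, assuming C# 9 records. The names IDifExp, ValueExp, SymbolExp, BinaryExp and InterfaceDesignDemo are made up for illustration and aren’t taken from the actual code. Because the operands are typed as the interface, target-typed new() can’t work out which concrete class to create, so every node has to spell out its type:

```csharp
using System;

// Hypothetical reconstruction of the interface-based design (illustrative names only).
// The interface is an empty marker; each kind of expression gets its own record.
public interface IDifExp { }

public sealed record ValueExp(int Value) : IDifExp {
    public override string ToString() => Value.ToString();
}

public sealed record SymbolExp(string Name) : IDifExp {
    public override string ToString() => Name;
}

public sealed record BinaryExp(IDifExp E1, string Op, IDifExp E2) : IDifExp {
    public override string ToString() => $"({E1} {Op} {E2})";
}

public static class InterfaceDesignDemo {
    // Every node must name its concrete type, because new() cannot be
    // target-typed to an interface. This is what made construction noisy.
    public static IDifExp TwoXPlusTen() =>
        new BinaryExp(new BinaryExp(new ValueExp(2), "*", new SymbolExp("x")), "+", new ValueExp(10));

    // Consumer code switches on the runtime type to handle each kind differently.
    public static string Describe(IDifExp e) => e switch {
        ValueExp v => $"constant {v.Value}",
        SymbolExp s => $"symbol {s.Name}",
        BinaryExp b => $"'{b.Op}' expression",
        _ => throw new ArgumentException("Unknown expression type", nameof(e))
    };
}
```

Describe() shows that the “switch on the type” side stays fairly tidy with pattern matching; it is the construction side that gets noisy.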

I then switched to a single DifExp class with a Type property that specified the kind of expression. I also needed properties for the value (for an integer), the symbol name (for a symbol) or the three parts of the expression (first operand, operator and second operand). This didn’t feel as elegant, but made construction much neater (see below).

The second design decision I didn’t really like was to use hard-coded strings for operators, instead of an enum. Again, this was to enhance readability when constructing an expression.

The basic DifExp class

Adding in an override for ToString(), the basic class looked like this…

using System;
using System.Linq;

class DifExp {
  private DifExpType Type { init; get; }
  private int Value { init; get; }
  private string Symbol { init; get; } = "";
  private (DifExp E1, string Op, DifExp E2) Expression { init; get; }

  public DifExp(int value) {
    Type = DifExpType.Value;
    Value = value;
  }
  public DifExp(string symbol) {
    Type = DifExpType.Symbol;
    Symbol = symbol;
  }
  public DifExp(string symbol, int power) {
    // This ctor is used for raising a symbol to an integer power
    Type = DifExpType.Expresssion;
    Expression = (new(symbol), "^", new(power));
  }
  public DifExp(DifExp e1, string op, DifExp e2) {
    if (!new[] { "+", "*", "-", "^" }.Contains(op)) {
      throw new ArgumentException($"Unsupported operator: {op}");
    }
    Type = DifExpType.Expresssion;
    Expression = (e1, op, e2);
  }

  public override string ToString() =>
    Type switch {
      DifExpType.Value => Value.ToString(),
      DifExpType.Symbol => Symbol,
      DifExpType.Expresssion => $"({Expression.E1} {Expression.Op} {Expression.E2})"
    };

  public enum DifExpType {
    Value,
    Symbol,
    Expresssion
  }
}

The astute reader may notice that creating a symbol expression requires naming the symbol, eg new("x"), rather than assuming a single fixed variable. This is because there is no need for an expression to be over a single symbol; it would be perfectly reasonable to allow the code to handle expressions over multiple symbols, such as x²y² + xy. If so, when differentiating you need to specify which symbol to differentiate with respect to. This allows for partial derivatives.
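As a quick worked example, here are the two partial derivatives of that expression, each treating the other symbol as a constant:

```latex
\frac{\partial}{\partial x}\left(x^2 y^2 + xy\right) = 2xy^2 + y,
\qquad
\frac{\partial}{\partial y}\left(x^2 y^2 + xy\right) = 2x^2 y + x
```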

As you can see, I added constructors for each of the three basic types (value, symbol and expression), as well as a later constructor for xⁿ. To keep things simple, I only allowed a symbol to be raised to an integer power.

This allowed me to construct expressions as shown below. The comments show the result of calling ToString()…

DifExp de1 = new(new(1), "*", new(2)); // (1 * 2)
DifExp de2 = new(new("x"), "+", de1); // (x + (1 * 2))
DifExp de3 = new(new(2), "*", new("x")); // (2 * x)
DifExp de4 = new(de3, "+", new(10)); // ((2 * x) + 10)
DifExp de5 = new(new("x"), 2); // (x ^ 2)

As you can see, if you mentally remove the uses of new, then all but the last line read almost like valid mathematical expressions. That benefit overrode my desire to have separate subclasses and an enum for the operators.

The first two expressions could be simplified, but that was left as an exercise for later.

Differentiating

This turned out to be a lot simpler than I expected. As the basic rules for differentiation aren’t actually that difficult, the code ended up being a single switch expression…

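  public DifExp Differentiate(string symbol) =>
    Type switch {
      // The derivative of a constant is zero
      DifExpType.Value => new(0),
      // The symbol itself has derivative 1; any other symbol is treated as a constant, so 0
      DifExpType.Symbol => symbol == Symbol ? new(1) : new(0),
      DifExpType.Expresssion =>
        Expression.Op switch {
          // Sum and difference rules
          "+" => new(Expression.E1.Differentiate(symbol), "+", Expression.E2.Differentiate(symbol)),
          "-" => new(Expression.E1.Differentiate(symbol), "-", Expression.E2.Differentiate(symbol)),
          // Product rule: (uv)' = u'v + uv'
          "*" => new(new(Expression.E1.Differentiate(symbol), "*", Expression.E2),
                     "+",
                     new(Expression.E1, "*", Expression.E2.Differentiate(symbol))),
          // Power rule for symbol ^ integer: n * x^(n - 1), or 0 if it is a different symbol
          "^" => Expression.E1.Symbol == symbol
                   ? new(new(Expression.E2.Value), "*", new(new(symbol), "^", new(Expression.E2.Value - 1)))
                   : new(0),
          _ => throw new InvalidOperationException($"Unsupported operator: {Expression.Op}")
        },
      _ => throw new InvalidOperationException($"Unknown expression type: {Type}")
    };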
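For reference, these are the standard rules that the arms of the switch are meant to encode, where c is an integer constant, y is any symbol other than the one we are differentiating with respect to, u and v are sub-expressions and n is an integer:

```latex
\begin{aligned}
\frac{d}{dx}\,c &= 0 && \text{(integer constant)} \\
\frac{d}{dx}\,x &= 1, \quad \frac{\partial}{\partial x}\,y = 0 && \text{(symbols)} \\
\frac{d}{dx}(u \pm v) &= \frac{du}{dx} \pm \frac{dv}{dx} && \text{(sum and difference)} \\
\frac{d}{dx}(uv) &= \frac{du}{dx}\,v + u\,\frac{dv}{dx} && \text{(product)} \\
\frac{d}{dx}\,x^n &= n\,x^{n-1} && \text{(power, integer } n\text{)}
\end{aligned}
```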

At this stage, I had a working differentiator. There are obviously a lot of things that could be done to expand this, such as adding more derivation rules and allowing a wider variety of expressions, but given that this was only an exercise, and I had already extended the functionality beyond the code shown in the book, I decided to address the one remaining niggle…

Simplifying the expressions

The first expression I created above was simply the product of two integers…

DifExp de1 = new(new(1), "*", new(2));

However, this is not simplified, and is stored as (1 * 2) rather than 2. Similarly, when differentiating, expressions would often end up including terms that could be removed or simplified. For example, creating and differentiating as follows…

DifExp de2 = new(new("x"), "+", new(new(1), "*", new(2)));
DifExp de2diff = de2.Differentiate("x");

…results in the differentiated expression being represented as (1 + ((0 * 2) + (1 * 0))). This can obviously be simplified to just 1.
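Working through it by hand shows where each term comes from: the sum rule splits the expression in two, and the product rule turns (1 * 2) into (0 * 2) + (1 * 0), as the derivative of each constant is 0:

```latex
\frac{d}{dx}\bigl(x + (1 \cdot 2)\bigr)
  = 1 + \bigl((0 \cdot 2) + (1 \cdot 0)\bigr)
  = 1 + (0 + 0)
  = 1
```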

The book showed a method of simplifying expressions, but as I was too lazy to get out of my chair and fetch the book from the other room, I decided to roll my own (😎).

I decided to simplify when the expression was created, so that we always store the simplest possible representation. Clearly, the only constructor that needed to simplify was the one that took two expressions and an operator. I could see four obvious simplifications to start with…

  1. If both expressions are values, we can combine them with the operator to produce a single value (eg 1 + 2 is simply 3)
  2. If we are multiplying, and one expression is a value of zero, then we can simplify to a value of zero
  3. If we are multiplying, and one expression is a value of one, then we can simplify to the other expression
  4. If we are adding, and one expression is a value of zero, then we can simplify to the other expression
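If the tree were made of immutable C# 9 records instead of a single class, those four rules could be sketched with pattern matching roughly as follows. This is an illustration under that assumption, not the code I ended up with; Expr, Num, Sym, Bin and Simplifier are hypothetical names. Simplifying the operands before applying the rules to the current node means one bottom-up pass is enough for these particular rules:

```csharp
// Minimal sketch of the four simplification rules, assuming a record-based tree
// rather than the DifExp class. All type names here are hypothetical.
public abstract record Expr;

public sealed record Num(int Value) : Expr {
    public override string ToString() => Value.ToString();
}

public sealed record Sym(string Name) : Expr {
    public override string ToString() => Name;
}

public sealed record Bin(Expr Left, string Op, Expr Right) : Expr {
    public override string ToString() => $"({Left} {Op} {Right})";
}

public static class Simplifier {
    // Simplify the operands first (bottom-up), then apply the rules to this node
    public static Expr Simplify(Expr e) => e switch {
        Bin b => Combine(Simplify(b.Left), b.Op, Simplify(b.Right)),
        _ => e
    };

    private static Expr Combine(Expr l, string op, Expr r) => (l, op, r) switch {
        // Rule 1: two constants fold into a single constant
        (Num a, "+", Num b) => new Num(a.Value + b.Value),
        (Num a, "-", Num b) => new Num(a.Value - b.Value),
        (Num a, "*", Num b) => new Num(a.Value * b.Value),
        // Rule 2: multiplying by zero gives zero
        (Num { Value: 0 }, "*", _) or (_, "*", Num { Value: 0 }) => new Num(0),
        // Rule 3: multiplying by one gives the other operand
        (Num { Value: 1 }, "*", var other) => other,
        (var other, "*", Num { Value: 1 }) => other,
        // Rule 4: adding zero gives the other operand
        (Num { Value: 0 }, "+", var other) => other,
        (var other, "+", Num { Value: 0 }) => other,
        // No rule applies, so rebuild the node from the simplified operands
        _ => new Bin(l, op, r)
    };
}
```

Fed the unsimplified derivative (1 + ((0 * 2) + (1 * 0))) from earlier, Simplifier.Simplify returns the constant 1.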

There are other simplifications that could be added, such as converting x + x to 2 * x and so on. I didn’t have time to add those as well, but realised that most of them are actually redundant if you set up your original expressions in the simplest way possible. For example, if you do new(new("x"), 2) instead of new(new("x"), "*", new("x")), then (as far as I can see) you’ll never end up with x + x in the first place.

Coding these up wasn’t too hard, and worked fine. However, I realised that when differentiating, you sometimes needed to simplify more than once, as the first simplification could still result in expressions that could be further simplified. No problem, just make the simplify method recursive.

Ha, that’s where it all went wrong! As I was simplifying in the constructor, the simplify method got called every time I created a new expression, and since simplifying itself creates new expressions, this led to an infinite loop when creating even the simplest expression. I won’t bore you with the sorry tale of how many hours I spent racking my brains over this, but it took longer than the rest of the project combined!

So, after much tearing of hair and gnashing of teeth, I ended up with the following…

  public static DifExp Simplify(DifExp e) {
    if (e.Type != DifExpType.Expresssion) {
      return e;
    }
    return Simplify(e.Expression.E1, e.Expression.Op, e.Expression.E2);
  }

  public static DifExp Simplify(DifExp e1, string op, DifExp e2) {
    DifExp e1s = Simplify(e1);
    DifExp e2s = Simplify(e2);
    if (e1s.Type == DifExpType.Value && e2s.Type == DifExpType.Value) {
      // Both expressions are values, so combine them with the operator specified
      return new(op switch {
        "+" => e1s.Value + e2s.Value,
        "*" => e1s.Value * e2s.Value,
        "-" => e1s.Value - e2s.Value,
        // Integer power; Math.Pow works on doubles, so cast the result back to int
        "^" => (int)Math.Pow(e1s.Value, e2s.Value),
        _ => throw new ArgumentException($"Unsupported operator: {op}")
      });
    } else if (op == "*" && ((e1s.Type == DifExpType.Value && e1s.Value == 0)
                          || (e2s.Type == DifExpType.Value && e2s.Value == 0))) {
      // We are multiplying, and one expression is a value of zero, so we are a value of zero.
      // The || is parenthesised because && binds tighter than ||; without the parentheses,
      // any expression whose second operand is zero (eg x + 0) would wrongly become zero
      return new(0);
    } else if (op == "*" && e1s.Type == DifExpType.Value && e1s.Value == 1) {
      // We are multiplying, and the first expression is a value of one, so we are the other expression
      return e2s;
    } else if (op == "*" && e2s.Type == DifExpType.Value && e2s.Value == 1) {
      // Ditto the previous comment for the second expression
      return e1s;
    } else if (op == "+" && e1s.Type == DifExpType.Value && e1s.Value == 0) {
      // We are adding, and the first expression is a value of zero, so we are the other expression
      return e2s;
    } else if (op == "+" && e2s.Type == DifExpType.Value && e2s.Value == 0) {
      // Ditto the previous comment for the second expression
      return e1s;
    }
    // Nothing to simplify, so return a new expression without simplifying it again
    return new(e1s, op, e2s, false);
  }

This covers all of the cases listed above, although it could probably be simplified if I had a clear enough head to think about it!

There are two overloads, as this allowed me to simplify a single previously-existing expression, or simplify one being constructed from two expressions and an operator. It’s possible that I could get away without this after I added the optional constructor parameter (see next paragraph), but I had had enough by this point. I had achieved more than I originally set out to do, and really needed to do something more productive with my time!

This was called from the constructor that took two expressions and an operator…

  public DifExp(DifExp e1, string op, DifExp e2, bool simplify = true) {
    if (!new[] { "+", "*", "-", "^" }.Contains(op)) {
      throw new ArgumentException($"Unsupported operator: {op}");
    }
    if (simplify) {
      DifExp exp = Simplify(e1, op, e2);
      Type = exp.Type;
      Value = exp.Value;
      Symbol = exp.Symbol;
      Expression = exp.Expression;
    } else {
      Type = DifExpType.Expresssion;
      Expression = (e1, op, e2);
    }
  }

As I mentioned, I had a problem with infinite loops, and needed some way of avoiding them. I ended up modifying the constructor above to take an optional parameter specifying whether to simplify the expression. We generally want to do this, and only want to avoid it in the final return statement of the Simplify() method, where the operands have already been simplified.
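A common alternative for avoiding that kind of loop, and not the approach used here, is a “smart constructor”: keep the real constructor private and free of logic, and do all the simplifying in a static factory method. The sketch below assumes that every expression is built through the factory, so the operands handed to it are already simplified; Node, Num, Sym and Make are hypothetical names:

```csharp
#nullable enable
using System;

// Hypothetical sketch of a smart constructor (static factory), not the DifExp code.
public sealed class Node {
    public int? Value { get; }
    public string? Symbol { get; }
    public Node? Left { get; }
    public string? Op { get; }
    public Node? Right { get; }

    // The only constructor is private and does no simplifying, so it cannot recurse
    private Node(int? value = null, string? symbol = null,
                 Node? left = null, string? op = null, Node? right = null) {
        Value = value; Symbol = symbol; Left = left; Op = op; Right = right;
    }

    public static Node Num(int value) => new(value: value);
    public static Node Sym(string name) => new(symbol: name);

    // All simplification lives here; results are built with the plain constructor,
    // so Make never ends up calling itself
    public static Node Make(Node left, string op, Node right) {
        if (op is not ("+" or "-" or "*" or "^"))
            throw new ArgumentException($"Unsupported operator: {op}", nameof(op));
        // Rule 1: fold two constants (powers are left alone in this sketch)
        if (left.Value is int a && right.Value is int b && op != "^")
            return Num(op switch { "+" => a + b, "-" => a - b, _ => a * b });
        // Rules 2 to 4: zero and one for * and +
        if (op == "*" && (left.Value == 0 || right.Value == 0)) return Num(0);
        if (op == "*" && left.Value == 1) return right;
        if (op == "*" && right.Value == 1) return left;
        if (op == "+" && left.Value == 0) return right;
        if (op == "+" && right.Value == 0) return left;
        return new(left: left, op: op, right: right);
    }

    public override string ToString() =>
        Value?.ToString() ?? Symbol ?? $"({Left} {Op} {Right})";
}
```

Because each rule returns either a constant or an operand that was already simplified, a single check per node covers these four rules, with no second pass. For example, Node.Make(Node.Sym("x"), "+", Node.Num(0)) gives back the x node itself.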

Once I had that working, I could differentiate quite complex expressions, and get a neat representation at the end.

The full code sample with some examples can be found in this gist.

Comparison with the Lisp code

Due to the static type system in C#, the code above is more complex than the Lisp and, I feel, slightly harder to read. Given that this sort of symbolic manipulation is one of Lisp’s superpowers, I guess that’s not surprising.

Compare the representations of 2x² + 3x + 4 in the two languages…

// C#
new(new(new(new(2), "*", new(new("x"), 2)), "+", new(new(3), "*", new("x"))), "+", new(4))
; Lisp
'(+ (+ (* 2 (* x x)) (* 3 x)) 4)

Bit of a difference in clarity!
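One way to narrow that gap, which I haven’t tried in the code above, would be operator overloading plus an implicit conversion from int. This is a sketch assuming a small record-based tree; Term, Lit, Variable and BinOp are hypothetical names. Powers go through a Pow() method rather than ^, because C#’s ^ is the XOR operator and binds more loosely than + and *:

```csharp
// Hypothetical sketch: overloaded operators build the expression tree, so the
// C# source reads more like the maths. Not part of the DifExp code.
public abstract record Term {
    public static Term operator +(Term a, Term b) => new BinOp(a, "+", b);
    public static Term operator -(Term a, Term b) => new BinOp(a, "-", b);
    public static Term operator *(Term a, Term b) => new BinOp(a, "*", b);

    // Lets plain integers appear directly, as in 3 * x
    public static implicit operator Term(int n) => new Lit(n);
}

public sealed record Lit(int Value) : Term {
    public override string ToString() => Value.ToString();
}

public sealed record Variable(string Name) : Term {
    // A method rather than ^: C#'s ^ (XOR) has lower precedence than + and *,
    // so x ^ 2 + 1 would parse as x ^ (2 + 1)
    public Term Pow(int n) => new BinOp(this, "^", new Lit(n));
    public override string ToString() => Name;
}

public sealed record BinOp(Term Left, string Op, Term Right) : Term {
    public override string ToString() => $"({Left} {Op} {Right})";
}
```

With x declared as new Variable("x"), the expression 2 * x.Pow(2) + 3 * x + 4 builds a tree of the same shape as the C# line above, and its ToString() gives (((2 * (x ^ 2)) + (3 * x)) + 4).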
