24 October 2013

Automatic differentiation is a well-studied technique for computing exact numeric derivatives. Dan Piponi has a great introduction on the subject, but to give you an overview, the idea is that you introduce an algebraic symbol ε such that ε ≠ 0 but ε² = 0. Formulas involving ε are called dual numbers (e.g., 3 + 2ε, 4 + ε) in much the same way as complex numbers are formulas involving the algebraic symbol i, which has the property i² = −1.

Then you teach the computer how to add, subtract, multiply and divide with dual numbers. So for example, (3 + 2ε)(4 + ε) = 12 + 3ε + 8ε + 2ε² = 12 + 11ε (since ε² = 0). The computer keeps everything in “normal form,” i.e., a + bε, as you go along.

In order to find the derivative of some function f at a point a, all you have to do is compute f(a + ε). The answer you get (in normal form) is

    f(a + ε) = f(a) + f′(a)ε

So for example, let f(x) = x² and let’s find f′(5). To do this, we compute f(5 + ε):

    f(5 + ε) = (5 + ε)² = 25 + 10ε + ε² = 25 + 10ε

We expected f(5 + ε) = f(5) + f′(5)ε. Equating these two results, we can conclude that f(5) = 25 and f′(5) = 10. Which is true!

This works not just for simple polynomials, but for any compound or nested formula of any kind. Somehow dual numbers keep track of the derivative during the evaluation of the formula, respecting the chain rule, the product rule and all.
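As a quick worked check that the product rule falls out of the arithmetic (my example, not from the original), take f(x) = x² · (x + 1) and evaluate at a + ε:

    f(a + ε) = (a + ε)²((a + ε) + 1)
             = (a² + 2aε)(a + 1 + ε)
             = a²(a + 1) + (a² + 2a(a + 1))ε

The ε coefficient, a² + 2a(a + 1) = 3a² + 2a, is exactly what the product rule gives for f′(a).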

The reason this works becomes clearer when you consider the Taylor series expansion of a function:

    f(a + x) = f(a) + f′(a)x + (f″(a)/2!)x² + (f‴(a)/3!)x³ + ⋯

When you evaluate f(a + ε), all the higher-order terms drop out (because ε² = 0) and all you’re left with is f(a) + f′(a)ε.

Implementing dual numbers

One interesting way to implement dual numbers is with matrices. The number a + bε can be encoded as the 2 × 2 matrix

    [ a  b ]
    [ 0  a ]

This encoding has the properties ε ≠ 0 and ε² = 0:

    ε = [ 0  1 ]    ε² = [ 0  0 ]
        [ 0  0 ]         [ 0  0 ]

Furthermore, these matrices add, multiply and divide like dual numbers.

This is nice because if you already have a library for doing matrix math, implementing automatic differentiation is trivial.
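To make that concrete, here’s a minimal sketch of the 2 × 2 encoding with a hand-rolled matrix multiply (throwaway names of my own, not the Dual class developed below):

```scala
// Dual numbers as 2x2 matrices: a + b*eps is encoded as [[a, b], [0, a]].
object MatrixDualSketch {
  type M = (Double, Double, Double, Double) // row-major 2x2: (a, b, c, d)

  def mul(x: M, y: M): M = (
    x._1 * y._1 + x._2 * y._3, x._1 * y._2 + x._2 * y._4,
    x._3 * y._1 + x._4 * y._3, x._3 * y._2 + x._4 * y._4
  )

  def dual(a: Double, b: Double): M = (a, b, 0.0, a)

  def main(args: Array[String]): Unit = {
    val x = dual(5.0, 1.0)  // 5 + eps
    val fx = mul(x, x)      // f(x) = x^2
    println(fx)             // (25.0, 10.0, 0.0, 25.0): f(5) = 25, f'(5) = 10
  }
}
```

The value and the derivative come out in the first row, just as in the a + bε normal form.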

Higher-order derivatives

Many of the papers on automatic differentiation point out that this technique generalizes to second derivatives or arbitrary nth derivatives, but I haven’t found a good explanation of how that works. So, here’s my attempt.

To compute second derivatives, we need to carry out the Taylor series expansion one step further. So instead of ε² = 0, we need ε² ≠ 0 and ε³ = 0. Then we get numbers of the form a + bε + cε². When we want to find the first and second derivatives of some function f at a, we compute f(a + ε) as before, but now we get

    f(a + ε) = f(a) + f′(a)ε + (f″(a)/2!)ε²

since the ε³ and higher terms drop out.
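As a quick sanity check (my example), take f(x) = x³ and expand at a:

    f(a + ε) = (a + ε)³ = a³ + 3a²ε + 3aε² + ε³ = a³ + 3a²ε + 3aε²

and indeed the ε² coefficient is f″(a)/2! = 6a/2 = 3a.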

Now all we need to do is find a mathematical object that behaves this way. It turns out there’s a matrix for this too: a 3 × 3 matrix with ones just above the diagonal,

    ε = [ 0  1  0 ]
        [ 0  0  1 ]
        [ 0  0  0 ]

It encodes the properties we want: ε ≠ 0 and ε² ≠ 0, but ε³ = 0:

    ε² = [ 0  0  1 ]    ε³ = [ 0  0  0 ]
         [ 0  0  0 ]         [ 0  0  0 ]
         [ 0  0  0 ]         [ 0  0  0 ]

Pretty neat!

If you want 3rd derivatives, you need a 4 × 4 matrix:

    ε = [ 0  1  0  0 ]
        [ 0  0  1  0 ]
        [ 0  0  0  1 ]
        [ 0  0  0  0 ]

So we have ε² ≠ 0 and ε³ ≠ 0, but ε⁴ = 0, which gives

    f(a + ε) = f(a) + f′(a)ε + (f″(a)/2!)ε² + (f‴(a)/3!)ε³
You can extend this technique to any order derivative you want.

Code

OK, enough math, let’s code it up. The Dual class will represent a dual number. The underlying representation is a square upper triangular diagonal-constant matrix. Unlike typical matrix libraries, I’m not going to store the cell values as a List[List[Double]] or anything. Instead I’m just going to have a method get(r: Int, c: Int): Double that returns the value in the specified cell. Also I’ll memoize it, so yeah, I guess I am storing the cell values somewhere. But this technique makes all the matrix operations easier to write.

abstract class Dual(val rank: Int) {
  self =>

  // Cell value accessor
  protected def get(r: Int, c: Int): Double

  // Memoizing cell value accessor.
  // Since it's a diagonal-constant matrix, we can use r - c as the key.
  def apply(r: Int, c: Int): Double = memo.getOrElseUpdate(r - c, self.get(r, c))

  // The memo table
  private val memo = scala.collection.mutable.HashMap[Int, Double]()
}

Now the usual matrix operations.

abstract class Dual(val rank: Int) {
  // ...

  def +(other: Dual): Dual = new Dual(rank) {
    def get(r: Int, c: Int) = self(r, c) + other(r, c)
  }

  def -(other: Dual): Dual = new Dual(rank) {
    def get(r: Int, c: Int) = self(r, c) - other(r, c)
  }

  def unary_-(): Dual = new Dual(rank) {
    def get(r: Int, c: Int) = -self(r, c)
  }

  def *(other: Dual): Dual = new Dual(rank) {
    def get(r: Int, c: Int) = (1 to rank).map(i => self(r, i) * other(i, c)).sum
  }

  def *(x: Double): Dual = new Dual(rank) {
    def get(r: Int, c: Int) = self(r, c) * x
  }
}

Division is implemented as multiplication by the inverse: x / y = x · y⁻¹. Matrix inverses are annoying to find and may not exist in the general case, but thankfully all we’re dealing with are square upper triangular matrices, which are a lot easier to invert. Actually, we have something even better: all of our matrices are of the form aI + D, where D is a nilpotent matrix, meaning Dᵏ = 0 for some k.

It turns out that for any nilpotent matrix N,

    (I − N)⁻¹ = I + N + N² + N³ + ⋯

(the sum is finite, since the powers of N eventually vanish), à la the familiar algebraic identity

    1/(1 − x) = 1 + x + x² + x³ + ⋯

So now we can find the inverse as follows:

    (aI + D)⁻¹ = (a(I − N))⁻¹ = (1/a)(I + N + N² + ⋯)

where N = −D/a. Here’s the code:

abstract class Dual(val rank: Int) {
  // ...

  def /(other: Dual): Dual = self * other.inv

  def /(x: Double): Dual = new Dual(rank) {
    def get(r: Int, c: Int) = self(r, c) / x
  }

  def inv: Dual = {
    val a = self(1, 1)
    val D = self - I * a
    val N = -D / a
    List.iterate(I, rank)(_ * N).reduce(_ + _) / a
  }

  // An identity matrix of the same rank as this one
  lazy val I: Dual = new Dual(rank) {
    def get(r: Int, c: Int) = if (r == c) 1 else 0
  }
}

Finally, some utility methods:

abstract class Dual(val rank: Int) {
  // ...

  def pow(p: Int): Dual = {
    def helper(b: Dual, e: Int, acc: Dual): Dual = {
      if (e == 0) acc
      else helper(b * b, e / 2, if (e % 2 == 0) acc else acc * b)
    }
    helper(self, p, self.I)
  }

  override def toString = {
    (1 to rank).map(c => self(1, c)).mkString(" ")
  }
}

Now we need concrete classes representing 1 and ε:

class I(override val rank: Int) extends Dual(rank) {
  def get(r: Int, c: Int) = if (r == c) 1 else 0
}

class E(override val rank: Int) extends Dual(rank) {
  def get(r: Int, c: Int) = if (r + 1 == c) 1 else 0
}

Let’s try it out. Suppose we want to find the first 5 derivatives of f(x) = x⁴ at x = 2.

scala> val one = new I(6)
one: I = 1.0 0.0 0.0 0.0 0.0 0.0

scala> val e = new E(6)
e: E = 0.0 1.0 0.0 0.0 0.0 0.0

scala> def f(x: Dual): Dual = x.pow(4)
f: (x: Dual)Dual

scala> f(one*2 + e)
res0: Dual = 16.0 32.0 24.0 8.0 1.0 0.0

This is f(2 + ε) = 16 + 32ε + 24ε² + 8ε³ + ε⁴. The coefficient of εⁿ will be f⁽ⁿ⁾(2)/n!. And it checks out:

    f(2) = 16, f′(2) = 32, f″(2)/2! = 48/2 = 24, f‴(2)/3! = 48/6 = 8, f⁗(2)/4! = 24/24 = 1

How about the first 8 derivatives of g(x) = 4x²/(1 − x)³ at x = 3?

scala> val one = new I(9)
one: I = 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

scala> val e = new E(9)
e: E = 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

scala> def g(x: Dual): Dual = x.pow(2) * 4 / (one - x).pow(3)
g: (x: Dual)Dual

scala> g(one*3 + e)
res1: Dual = -4.5 3.75 -2.75 1.875 -1.21875 0.765625 -0.46875 0.28125 -0.166015625

OK, let’s just check one of these: the ε⁴ coefficient claims that g⁗(3) = 4! × (−1.21875) = −29.25. I’m gonna use Wolfram Alpha for this, because… yeah.

Neat!

Conclusion

This is a pretty amazing technique. Instead of computing a difference quotient (f(x + h) − f(x))/h with tiny values of h, which is prone to all sorts of floating-point rounding errors, you get exact numerical derivatives. In fact you get as many higher-order derivatives as you want, simultaneously. So, you almost never need to do symbolic differentiation.
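Here’s a quick illustration of the rounding problem (my example): finite differences for f(x) = x², where the exact derivative at 5 is exactly 10.

```scala
// Finite differences degrade as h shrinks: truncation error falls,
// but floating-point cancellation error in f(x+h) - f(x) grows.
object FiniteDifferenceDemo {
  def f(x: Double): Double = x * x

  def main(args: Array[String]): Unit = {
    for (h <- List(1e-4, 1e-8, 1e-12)) {
      val approx = (f(5.0 + h) - f(5.0)) / h
      println(s"h = $h: approx = $approx, error = ${math.abs(approx - 10.0)}")
    }
  }
}
```

None of these is exactly 10, and below a certain h the error gets worse, not better; the dual-number approach sidesteps the issue entirely.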

Of course, for this to be really useful, I’d have to implement more than just the standard arithmetic operations on dual numbers. I’ll also want to be able to compute things like sin(x) or eˣ or log(x). There are certainly ways to do this! But maybe it’s a topic for another post.
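For what it’s worth, here’s a sketch of how the first-order case could go (made-up names of my own, not part of the code above). By the chain rule, g(a + bε) = g(a) + b·g′(a)ε, so each function only needs to know its own derivative:

```scala
// First-order dual numbers extended with transcendental functions.
case class SimpleDual(a: Double, b: Double) // represents a + b*eps

object DualFunctions {
  def sin(x: SimpleDual) = SimpleDual(math.sin(x.a), x.b * math.cos(x.a))
  def exp(x: SimpleDual) = SimpleDual(math.exp(x.a), x.b * math.exp(x.a))
  def log(x: SimpleDual) = SimpleDual(math.log(x.a), x.b / x.a)

  def main(args: Array[String]): Unit = {
    // d/dx sin(x) at x = 0 is cos(0) = 1
    println(sin(SimpleDual(0.0, 1.0))) // SimpleDual(0.0, 1.0)
  }
}
```

For the higher-order matrix representation, one natural route is to evaluate each function’s Taylor series at the matrix argument, which terminates because the nilpotent part vanishes.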

All of the code in this post is available in this gist.


