Scala: Hello Memo Fibonacci

by Michael

The previous Fibonacci program works, to a certain degree. We’re doing a lot of repetitive computation that we could avoid by storing results as they become available. The second time we need a result, we just do a look-up and retrieve the previously calculated value. This technique is called memoization and can be applied to pretty much any function that is deterministic and free of side effects, i.e., for a given input x, the function always returns the same y.

Ideas and examples are based on the blog post found here on memoization: http://michid.wordpress.com/2009/02/23/function_mem/

Ideally, we would like a memoization function that takes a function and returns a memoized version of it. We define our function f(x) = y that does all the work. Now we want a function f'(x) = y that stores the results of f(x) and returns them from a look-up table.

f' = Memoize(f)

For example, if f(x) does some computation of complexity O(n!), we can say that:

  • f(x) has O(n!) complexity on every call.
  • f'(x) has O(n!) complexity only the first time it is invoked with a given x. Every later call with the same x costs a single look-up, which is O(1) on average for a hash map.

Memoizing is nothing more than caching the result from a previous computation for instant look-up. Implementations usually look like a key-value store such as a Map.

class Memoize1[-T, +R](f: T => R) extends (T => R) {

  private[this] val memorized = scala.collection.mutable.Map.empty[T, R]

  def apply(x: T): R = {
    memorized.getOrElseUpdate(x, f(x))
  }
}
object Memoize1 {

  def apply[T, R](f: T => R) = new Memoize1(f)

  def Y[T, R](f: (T => R) => T => R): (T => R) = {
    lazy val yf: T => R = Memoize1(f(yf)(_))
    yf
  }
}

Scala is amazing! The second parameter of getOrElseUpdate is a by-name parameter, so the expression f(x) we pass in is evaluated only if the key x is not already in the map. Note that we call the class Memoize1 because it extends Function1, i.e., a function that takes one argument.
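To see the caching in action, here is a minimal sketch that counts how often the wrapped function really runs. The names MemoDemo, Memo, slowSquare, fastSquare and the calls counter are made up for this illustration and are not part of the original code; Memo is a stripped-down stand-in for Memoize1 without the variance annotations.

```scala
import scala.collection.mutable

object MemoDemo {
  // Hypothetical counter, used only to observe how often the wrapped function runs.
  var calls = 0

  // Stand-in for an expensive, deterministic function.
  val slowSquare: Int => Int = { x =>
    calls += 1
    x * x
  }

  // Same idea as Memoize1 (variance annotations omitted): results live in a mutable Map.
  class Memo[T, R](f: T => R) extends (T => R) {
    private val cache = mutable.Map.empty[T, R]

    // The second parameter of getOrElseUpdate is by-name,
    // so f(x) is evaluated only when x is not yet a key.
    def apply(x: T): R = cache.getOrElseUpdate(x, f(x))
  }

  val fastSquare = new Memo[Int, Int](slowSquare)

  def main(args: Array[String]): Unit = {
    println(fastSquare(12)) // cache miss: slowSquare runs
    println(fastSquare(12)) // cache hit: slowSquare is skipped
    println(s"slowSquare ran $calls time(s)")
  }
}
```

Both calls print 144, but the counter only goes up once, because the second call is answered from the map.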

What About Recursive Functions?

When we try to apply Memoize to a recursive function, we have a problem: the function calls itself from within its own body, so the recursive calls use the original function (f), not the memoized function (f'). Using the fib function from the previous post:

def fib(n: Long): Long = n match {
  case 0 | 1 => 1
  case _     => fib(n - 2) + fib(n - 1)
}
val fibMem = Memoize1(fib)

The problem here is that the internal recursive calls do not go through fibMem but call the original fib directly, so only the outermost result is cached. There is no elegant way to fix this. All we can do is change the Fibonacci function to take an extra parameter that tells it which function to use for the recursive calls. This can be achieved by currying: fibCurry takes the function to recurse on in its first parameter list, and we pass it the memoized function.

def fibCurry(f: Long => Long)(n: Long): Long = n match {
  case 0 | 1 => 1
  case _     => f(n - 2) + f(n - 1)
}
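lazy val fibMem: Long => Long = Memoize1(fibCurry(fibMem)(_))

fibMem refers to itself in its own definition, so it is declared lazy to defer its evaluation, and it needs an explicit type because Scala cannot infer the type of a recursive value. The recursive calls f now use the memoized version fibMem. The Y method in the Memoize1 companion object packages this same pattern: Memoize1.Y(fibCurry) builds the self-referencing memoized function for us.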
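The difference between the two approaches can be made visible by counting how often the function body runs. The following is only a sketch: RecursiveMemoDemo, Memo, bodyRuns and the two counters are hypothetical names introduced for this illustration, and fib and fibCurry are re-declared locally (written with if/else instead of match) so that each call to bodyRuns starts with fresh counters and fresh caches.

```scala
import scala.collection.mutable

object RecursiveMemoDemo {
  // Local stand-in for Memoize1 so the sketch is self-contained.
  class Memo[T, R](f: T => R) extends (T => R) {
    private val cache = mutable.Map.empty[T, R]
    def apply(x: T): R = cache.getOrElseUpdate(x, f(x))
  }

  /** Returns (body runs of the plainly wrapped fib, body runs of the curried fib) for index n. */
  def bodyRuns(n: Long): (Int, Int) = {
    var plainRuns = 0
    var curriedRuns = 0

    // Recurses on itself, so the cache is bypassed inside the recursion.
    def fib(k: Long): Long = {
      plainRuns += 1
      if (k <= 1) 1L else fib(k - 2) + fib(k - 1)
    }

    // Recurses on whatever function it is handed.
    def fibCurry(f: Long => Long)(k: Long): Long = {
      curriedRuns += 1
      if (k <= 1) 1L else f(k - 2) + f(k - 1)
    }

    // Only the outermost call goes through the cache.
    val wrapped: Long => Long = new Memo[Long, Long](k => fib(k))

    // Every recursive call goes through the cache.
    lazy val tied: Long => Long = new Memo[Long, Long](k => fibCurry(tied)(k))

    wrapped(n)
    tied(n)
    (plainRuns, curriedRuns)
  }

  def main(args: Array[String]): Unit = {
    println(bodyRuns(20L)) // (21891,21)
  }
}
```

For n = 20 the plainly wrapped version still runs the body 21891 times, while the curried version runs it 21 times, once for each index from 0 to 20.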

All Together

Here’s the source code to play around with, together with some comments to explain things that I didn’t think were very clear. In my runs, the memoized version becomes superior at around the n = 34 mark, beating the original fib function by a few milliseconds.

/**
  * Based on http://michid.wordpress.com/2009/02/23/function_mem/
  */
object FibApp extends App {

  /**
   * Basic fibonacci number generator
   *
   * @param n index of fibonacci number
   * @return the fibonacci number
   */
  def fib(n: Long): Long = n match {
    case 0 | 1 => 1
    case _     => fib(n - 2) + fib(n - 1)
  }

  /**
   * Fibonacci number generator that takes a recursive function as
   * an argument and returns a fibonacci function
   *
   * @param f function used for recursion
   * @param n index of fibonacci number
   * @return the fibonacci number
   */
  def fibCurry(f: Long => Long)(n: Long): Long = n match {
    case 0 | 1 => 1
    case _     => f(n - 1) + f(n - 2)
  }

  /**
   * Runs a block of code and prints the elapsed time in milliseconds.
   * The results seem a bit strange; I sometimes get negative times.
   *
   * @param code code to execute
   * @return result from calling code
   */
  def time[T](code: => T) = {
    val tic = System.nanoTime
    val result = code
    val toc = System.nanoTime

    println("Time: " + ((toc - tic).doubleValue / 1e6) + "ms")
    result
  }

  /**
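   * Class extends Function1, written as (T => R). It is contravariant in
   * its argument type T and covariant in its result type R, e.g., for
   * f(x) = y, x may be of any subtype of T and y may be used as any
   * supertype of R.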
   *
   * @link http://www.hars.de/2009/10/variance-with-horses-and-people.html
   * @link http://en.wikipedia.org/wiki/Covariance_and_contravariance_(computer_science)
   *
   * @param f Function1 function
   */
  class Memoize1[-T, +R](f: T => R) extends (T => R) {

    private[this] val memorized = scala.collection.mutable.Map.empty[T, R]

    // function apply is automatically called when using operator ()
    def apply(x: T): R = {
      memorized.getOrElseUpdate(x, f(x))
    }
  }

  /**
   * declare object (singleton) Memoize1, like a Memoize1 factory
   */
  object Memoize1 {

    /**
     * apply is used when we write Memoize1(f), same as Memoize1.apply(f)
     */
    def apply[T, R](f: T => R) = new Memoize1(f)

    def Y[T, R](f: (T => R) => T => R): (T => R) = {
      lazy val yf: T => R = Memoize1(f(yf)(_))
      yf
    }
  }

  val fibMem = Memoize1.Y(fibCurry)
  val max = 34

  time(println(fibMem(max)))
  println("--------------")
  time(println(fib(max)))
}

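Note that I switched the recursive calls to start with f(n - 1) instead of f(n - 2). Either order ends up computing each index only once, but starting with f(n - 1) walks down through every index on the first descent, so for n ≥ 3 the result of f(n - 2) is already cached by the time it is needed. Starting with f(n - 2) steps down by two and fills in the skipped indices on the way back up.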
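A small sketch can record the order in which indices are actually computed under each ordering. EvalOrderDemo, missOrder and the minusOneFirst flag are hypothetical names used only here, and the cache is a plain mutable Map rather than Memoize1 to keep the example short.

```scala
import scala.collection.mutable

object EvalOrderDemo {
  // Returns the indices in the order their bodies run (i.e., the cache misses).
  def missOrder(n: Long, minusOneFirst: Boolean): List[Long] = {
    val order = mutable.ListBuffer.empty[Long]
    val cache = mutable.Map.empty[Long, Long]

    def fibCurry(f: Long => Long)(k: Long): Long = {
      order += k
      if (k <= 1) 1L
      else if (minusOneFirst) f(k - 1) + f(k - 2)
      else f(k - 2) + f(k - 1)
    }

    // Self-referencing memoized function, same pattern as fibMem.
    lazy val fibMem: Long => Long =
      k => cache.getOrElseUpdate(k, fibCurry(fibMem)(k))

    fibMem(n)
    order.toList
  }

  def main(args: Array[String]): Unit = {
    println(missOrder(6L, minusOneFirst = true))  // List(6, 5, 4, 3, 2, 1, 0)
    println(missOrder(6L, minusOneFirst = false)) // List(6, 4, 2, 0, 1, 3, 5)
  }
}
```

Both lists contain the same seven indices, so the amount of work is the same. Only the order of the descent differs.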
