### Scala: Hello Memo Fibonacci

#### by Michael

The previous Fibonacci program works, to a certain degree. We’re doing a lot of repetitive computation that we could avoid by storing results as they become available. The second time we need a result, we just do a look-up and retrieve the previously calculated value. This technique is called memoization and can be applied to pretty much any function that is deterministic, i.e., for a given input *x*, the function always returns the same *y*.

The ideas and examples are based on this blog post on memoization: http://michid.wordpress.com/2009/02/23/function_mem/

Ideally, we would like to be able to apply a memoization function that will return a function that uses memoization. We define our function *f(x) = y* that does all the work. Now we want a function *f'(x) = y* that stores results for *f(x)* and returns the result from a look-up table.

*f' = Memoize(f)*

As an example, take a function *f(x)* that does some computation of complexity *O(n!)*. Then we can say that:

*f(x)* has *O(n!)* complexity on every call.

*f'(x)* has *O(n!)* complexity only the first time it is invoked with a given argument, and roughly *O(1)* complexity (a map look-up) on every later invocation with that same argument.

Memoizing is nothing more than caching the result from a previous computation for instant look-up. Implementations usually look like a key-value store such as a Map.
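To make the idea concrete, here is a minimal sketch of such a map-backed cache, kept separate from the `Memoize1` class used in this post. The names `MemoSketch`, `slowSquare`, `memoize` and `fastSquare` are made up for illustration, and the `calls` counter exists only so that cache hits can be observed.

```scala
import scala.collection.mutable

object MemoSketch {
  // Hypothetical counter, only here to make cache hits observable.
  var calls = 0

  // Stand-in for an expensive, deterministic function.
  def slowSquare(x: Int): Int = {
    calls += 1
    x * x
  }

  // Wrap any one-argument function with a look-up table.
  def memoize[A, B](f: A => B): A => B = {
    val cache = mutable.Map.empty[A, B]
    x => cache.getOrElseUpdate(x, f(x))
  }

  val fastSquare: Int => Int = memoize((x: Int) => slowSquare(x))
}
```

Calling `MemoSketch.fastSquare(12)` twice returns 144 both times, but `calls` ends up at 1, because the second call is answered from the map.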

```scala
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val memorized = scala.collection.mutable.Map.empty[T, R]

  def apply(x: T): R = {
    memorized.getOrElseUpdate(x, f(x))
  }
}

object Memoize1 {
  def apply[T, R](f: T => R) = new Memoize1(f)

  def Y[T, R](f: (T => R) => T => R): (T => R) = {
    lazy val yf: T => R = Memoize1(f(yf)(_))
    yf
  }
}
```

Scala is amazing! The second argument of *getOrElseUpdate* is a by-name parameter, so the expression *f(x)* is evaluated only if the key *x* is not already in the map. Note that we call this class Memoize1 because it extends Function1, i.e., a function that takes one argument.
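The lazy evaluation of that second argument can be seen in isolation with a small sketch. Everything here (`ByNameSketch`, `expensive`, `cachedOrElse`, the `evaluations` counter) is a hypothetical stand-in, not part of the Scala library; it only mimics how a by-name parameter behaves.

```scala
object ByNameSketch {
  // Hypothetical counter that records how often the expensive code really runs.
  var evaluations = 0

  def expensive(): Int = {
    evaluations += 1
    42
  }

  // `fallback: => Int` is a by-name parameter: the argument expression is
  // run only when `fallback` is referenced in the body, never before.
  def cachedOrElse(cached: Option[Int], fallback: => Int): Int =
    cached match {
      case Some(value) => value    // hit: `fallback` is never evaluated
      case None        => fallback // miss: the expression runs now
    }
}
```

With a cached value, `ByNameSketch.cachedOrElse(Some(7), ByNameSketch.expensive())` returns 7 and leaves `evaluations` at 0; only a call with `None` actually runs `expensive()`.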

## What About Recursive Functions?

When we try to apply Memoize to a recursive function, we have a problem, since we call the function from within the function itself. The recursive calls use the original function (*f*), not the memoized function (*f'*). Using the fib function from the previous post:

```scala
def fib(n: Long): Long = n match {
  case 0 | 1 => 1
  case _ => fib(n - 2) + fib(n - 1)
}

val fibMem = Memoize1(fib)
```
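A quick way to check what this wrapper actually caches is to count how often the body of *fib* runs. The sketch below assumes a simplified stand-in for `Memoize1` (a plain map and a lambda) and a hypothetical `calls` counter; `NaiveWrapSketch` is an illustrative name.

```scala
import scala.collection.mutable

object NaiveWrapSketch {
  // Hypothetical counter: incremented every time the body of fib runs.
  var calls = 0L

  def fib(n: Long): Long = {
    calls += 1
    n match {
      case 0L | 1L => 1L
      case _       => fib(n - 2) + fib(n - 1) // recursion bypasses the cache
    }
  }

  // Simplified stand-in for Memoize1(fib): only top-level results are stored.
  private val cache = mutable.Map.empty[Long, Long]
  val fibMem: Long => Long = n => cache.getOrElseUpdate(n, fib(n))
}
```

The first `fibMem(10)` runs the body 177 times and a repeated `fibMem(10)` adds none, but a following `fibMem(9)` adds another 109 calls: the intermediate results were never stored.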

The problem here is that the internal recursive calls do not go through *fibMem*, but through the original *fib* function. There is no elegant way to fix this; all we can do is change the Fibonacci function *fib* to take an extra parameter that tells it which function to use for the recursive calls. This can be achieved by currying *fib*, so that supplying the memoized function yields a new function whose recursive calls go through it.

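```scala
def fibCurry(f: Long => Long)(n: Long): Long = n match {
  case 0 | 1 => 1
  case _ => f(n - 2) + f(n - 1)
}

// A recursive value needs an explicit type annotation
lazy val fibMem: Long => Long = Memoize1(fibCurry(fibMem)(_))
```

`lazy` is required to defer the evaluation of *fibMem*, which refers to itself in its own definition; for the same reason the compiler needs the explicit type annotation. The recursive calls through *f* now use the memoized version *fibMem*.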
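The `Memoize1.Y` helper in the listings packages this lazy-value trick so it does not have to be repeated for every recursive function. The following is a sketch of how that can be checked, under a few assumptions: the `Memoize1` here is simplified (variance annotations dropped, the lambda written out explicitly), and `YSketch` and the `calls` counter are hypothetical names added for the demonstration.

```scala
import scala.collection.mutable

object YSketch {
  // Simplified Memoize1: variance annotations dropped to keep the sketch short.
  class Memoize1[T, R](f: T => R) extends (T => R) {
    private val cache = mutable.Map.empty[T, R]
    def apply(x: T): R = cache.getOrElseUpdate(x, f(x))
  }

  object Memoize1 {
    // Ties the knot: `yf` is handed to `f` as the function to recurse through,
    // and `yf` itself is the memoized wrapper around that very call.
    def Y[T, R](f: (T => R) => T => R): T => R = {
      lazy val yf: T => R = new Memoize1[T, R](x => f(yf)(x))
      yf
    }
  }

  // Hypothetical counter: incremented every time the body of fibCurry runs.
  var calls = 0

  def fibCurry(f: Long => Long)(n: Long): Long = {
    calls += 1
    n match {
      case 0L | 1L => 1L
      case _       => f(n - 2) + f(n - 1)
    }
  }

  val fibMem: Long => Long = Memoize1.Y[Long, Long](fibCurry)
}
```

Here `YSketch.fibMem(34)` returns 9227465 while the body of `fibCurry` runs only 35 times, once for each index from 0 to 34.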

## All Together

Here’s the source code to play around with, together with some comments explaining the parts that I didn’t think were very clear. I noticed that the memoized version starts to win at around the *n = 34* mark, beating the original *fib* function by a few milliseconds.

```scala
/**
 * Based on http://michid.wordpress.com/2009/02/23/function_mem/
 */
object FibApp extends App {

  /**
   * Basic fibonacci number generator
   *
   * @param n index of fibonacci number
   * @return the fibonacci number
   */
  def fib(n: Long): Long = n match {
    case 0 | 1 => 1
    case _ => fib(n - 2) + fib(n - 1)
  }

  /**
   * Fibonacci number generator that takes a recursive function as
   * an argument and returns a fibonacci function
   *
   * @param f function used for recursion
   * @param n index of fibonacci number
   * @return the fibonacci number
   */
  def fibCurry(f: Long => Long)(n: Long): Long = n match {
    case 0 | 1 => 1
    case _ => f(n - 1) + f(n - 2)
  }

  /**
   * This piece of code seems a bit strange, sometimes get negative times
   *
   * @param code code to execute
   * @return result from calling code
   */
  def time[T](code: => T) = {
    val tic = System.nanoTime
    val result = code
    val toc = System.nanoTime
    println("Time: " + ((toc - tic).doubleValue / 1e6) + "ms")
    result
  }

  /**
   * class extends Function1, written as (T => R)
   * e.g., f(x) = y where x is in subset of T and y in superset of R
   *
   * @link http://www.hars.de/2009/10/variance-with-horses-and-people.html
   * @link http://en.wikipedia.org/wiki/Covariance_and_contravariance_(computer_science)
   *
   * @param f Function1 function
   */
  class Memoize1[-T, +R](f: T => R) extends (T => R) {
    private[this] val memorized = scala.collection.mutable.Map.empty[T, R]

    // function apply is automatically called when using operator ()
    def apply(x: T): R = {
      memorized.getOrElseUpdate(x, f(x))
    }
  }

  /**
   * declare object (singleton) Memoize1, like a Memoize1 factory
   */
  object Memoize1 {
    /**
     * apply is used when we write Memoize1(f), same as Memoize1.apply(f)
     */
    def apply[T, R](f: T => R) = new Memoize1(f)

    def Y[T, R](f: (T => R) => T => R): (T => R) = {
      lazy val yf: T => R = Memoize1(f(yf)(_))
      yf
    }
  }

  val fibMem = Memoize1.Y(fibCurry)

  val max = 34
  time(println(fibMem(max)))
  println("--------------")
  time(println(fib(max)))
}
```

Note that I switched the recursive calls to start with *f(n - 1)* instead of *f(n - 2)*. This way the first descent visits, and memoizes, every index in turn, instead of stepping by two as it does when *f(n - 2)* comes first. Either way, each value ends up being computed only once.
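The difference between the two orderings can be made visible by recording the order in which cache misses occur. This is only a sketch: `OrderSketch`, `missOrder`, `minusOneFirst` and `minusTwoFirst` are hypothetical names, and the tracing memoizer is a simplified stand-in for `Memoize1.Y`.

```scala
import scala.collection.mutable

object OrderSketch {
  // Hypothetical tracing helper: memoizes `step` and records every cache
  // miss in the order in which it happens.
  def missOrder(step: (Long => Long) => Long => Long, n: Long): List[Long] = {
    val cache  = mutable.Map.empty[Long, Long]
    val misses = mutable.ListBuffer.empty[Long]
    lazy val memo: Long => Long = x =>
      cache.getOrElseUpdate(x, { misses += x; step(memo)(x) })
    memo(n)
    misses.toList
  }

  // Same recursion as fibCurry, with f(n - 1) evaluated first.
  def minusOneFirst(f: Long => Long)(n: Long): Long = n match {
    case 0L | 1L => 1L
    case _       => f(n - 1) + f(n - 2)
  }

  // Same recursion, with f(n - 2) evaluated first.
  def minusTwoFirst(f: Long => Long)(n: Long): Long = n match {
    case 0L | 1L => 1L
    case _       => f(n - 2) + f(n - 1)
  }
}
```

For *n = 6*, `missOrder(minusOneFirst, 6)` visits the indices as 6, 5, 4, 3, 2, 1, 0, while `missOrder(minusTwoFirst, 6)` visits them as 6, 4, 2, 0, 1, 3, 5. Both lists contain each index exactly once, so the amount of work is the same and only the order changes.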

Thank you very much!

I was solving knapsack problem in Scala, and your post helped me write an elegant solution with recursion and memoization. Can’t share the code, unfortunately, because it was a homework for online course.