Friday, December 8, 2023

# Dynamic programming in Haskell: automatic memoization

This is part 2 of a promised multi-part series on dynamic programming in Haskell. As a reminder, we’re using Zapis as a sample problem. In this problem, we are given a sequence of opening and closing brackets (parens, square brackets, and curly braces) with question marks, and have to compute the number of different ways in which the question marks could be replaced by brackets to create valid, properly nested bracket sequences.

Last time, we developed some code to efficiently solve this problem using a mutually recursive pair of a function and a lookup table represented by a lazy, immutable array. This solution is pretty good, but it leaves a few things to be desired:

• It requires defining both a function and a lazy, immutable array, and coming up with names for them.
• When defining the function, we have to remember to index into the array instead of calling the function recursively, and there is nothing that will warn us if we forget.
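For concreteness, here is a sketch of that two-definition style applied to Fibonacci rather than Zapis. `fibTable` and `fibFun` are illustrative names of our own, not the previous post's code. The comment on the last equation marks the kind of mistake the second bullet describes.

```haskell
import Data.Array

-- A lazy, immutable array holding every result we might need...
fibTable :: Array Int Integer
fibTable = listArray (0, 1000) (map fibFun [0 .. 1000])

-- ...and a function that fills it in. Each definition refers to the other.
fibFun :: Int -> Integer
fibFun 0 = 0
fibFun 1 = 1
-- Recursive uses must index into fibTable. Writing fibFun (n - 1) here
-- would still typecheck, but would silently lose the memoization.
fibFun n = fibTable ! (n - 1) + fibTable ! (n - 2)
```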

## An impossible dream

Wouldn’t it be cool if we could just write the recursive function, and then have some generic machinery make it fast for us by automatically generating a memo table?

In other words, we’d like a magic memoization function, with a type something like this:

```haskell
memo :: (i -> a) -> (i -> a)
```

Then we could just define our slow, recursive function normally, wave our magic `memo` wand over it, and get a fast version for free!

This sounds lovely, of course, but there are a few problems:

• Surely this magic `memo` function won’t be able to work for any type `i`. Well, OK, we can add something like an `Ix i` constraint and/or extra arguments to make sure that values of type `i` can be used as (or converted to) array indices.

• How can `memo` possibly know how big of a table to allocate? One simple way to solve this would be to provide the table size as an extra explicit argument to `memo`. (In my next post we’ll also explore some clever things we can do when we don’t know in advance how big of a table we will need.)

• More fundamentally, though, our dream seems impossible: given a function `i -> a`, the only thing the `memo` function can do is call it on some input of type `i`; if the `i -> a` function is recursive then it will go off and do its recursive thing without ever consulting a memo table, defeating the entire purpose.
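The first two bullets are less mysterious than they may sound. Haskell's `Ix` class (from `Data.Ix`, which `Data.Array` re-exports) is precisely the "usable as an array index" constraint, and a pair of bounds tells us both which indices exist and how big the table must be. A small sketch, using a hypothetical 2-by-3 range of pairs like the one a two-argument memo table would need:

```haskell
import Data.Ix (index, range, rangeSize)

-- Bounds for a 2-by-3 table indexed by pairs of Ints.
bounds2D :: ((Int, Int), (Int, Int))
bounds2D = ((0, 0), (1, 2))

-- Every index in the range, in row-major order:
-- [(0,0),(0,1),(0,2),(1,0),(1,1),(1,2)]
allIndices :: [(Int, Int)]
allIndices = range bounds2D

-- How many cells a table with these bounds needs (6).
tableSize :: Int
tableSize = rangeSize bounds2D

-- Where a given index lives in the flat table, e.g. (1,1) is at offset 4.
offsetOf :: (Int, Int) -> Int
offsetOf = index bounds2D
```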

## … or is it?

For now let’s ignore the fact that our dream seems impossible and think about how we could write `memo`. The idea is to take the given `(i -> a)` function and first turn it into a lookup table storing a value of type `a` for each `i`; then return a new `i -> a` function which works by just doing a table lookup.

From my previous post we already have a function to create a table for a given function:

```haskell
tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)
```

The inverse function, which turns an array back into a function, is just the array indexing operator, with extra parentheses around the `i -> a` to emphasize the shift in perspective:

```haskell
(!) :: Ix i => Array i a -> (i -> a)
```

So we can define `memo` simply as the composition

```haskell
memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng
```
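One property of this definition is worth checking before going further: `Array` from `Data.Array` is lazy in its elements, so `tabulate` builds the table without computing any entry, and each entry is computed at most once, the first time it is looked up. A quick sketch (the name `fragile` is just for illustration) showing that an entry which is never demanded is never evaluated:

```haskell
import Data.Array

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

-- A memoized function whose entry at 1 would crash if it were ever evaluated.
fragile :: Int -> Int
fragile = memo (0, 2) (\i -> if i == 1 then error "never forced" else i * 10)

-- Looking up the other entries is fine: the boxed array stores unevaluated
-- thunks, and (!) forces only the entry it returns.
safeEntries :: [Int]
safeEntries = [fragile 0, fragile 2]   -- [0,20]
```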

This is nifty… but as we already saw, it doesn’t help very much… right? For example, let’s define a recursive (slow!) Fibonacci function, and apply `memo` to it:

```haskell
{-# LANGUAGE LambdaCase #-}

fib :: Int -> Integer
fib = \case
  0 -> 0
  1 -> 1
  n -> fib (n-1) + fib (n-2)

fib' :: Int -> Integer
fib' = memo (0,1000) fib
```

As you can see from the following `ghci` session, calling, say, `fib' 35` is still very slow the first time, since it simply calls `fib 35` which does its usual exponential recursion. However, if we call `fib' 35` a second time, we get the answer instantly:

```
λ> :set +s
λ> fib' 35
9227465
(4.18 secs, 3,822,432,984 bytes)
λ> fib' 35
9227465
(0.00 secs, 94,104 bytes)
```

This is better than nothing, but it’s not really the point. We want it to be fast the first time by looking up intermediate results in the memo table. And trying to call `fib'` on bigger inputs is still going to be completely hopeless.
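To put a number on "hopeless": every call of the naive `fib` on an input n ≥ 2 makes two further calls, so the total number of calls C(n) satisfies C(0) = C(1) = 1 and C(n) = 1 + C(n-1) + C(n-2), which works out to 2·fib(n+1) - 1. A quick sketch that computes C(n) directly (`naiveCalls` is our own name):

```haskell
-- Number of calls the naive, unmemoized fib makes on input n, using
-- C(0) = C(1) = 1 and C(n) = 1 + C(n-1) + C(n-2).
-- We step the pair (C(k), C(k+1)) forward n times, starting from k = 0.
naiveCalls :: Int -> Integer
naiveCalls n = fst (iterate step (1, 1) !! n)
  where
    step (a, b) = (b, 1 + a + b)
```

For n = 35 this comes to 29,860,703 calls, consistent with the multi-second timing above, and the count grows by a factor of roughly 1.6 with each increase in n.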

## The punchline

All might seem hopeless at this point, but we actually have everything we need—all we have to do is just stick the call to `memo` in the definition of `fib` itself!

```haskell
fib :: Int -> Integer
fib = memo (0,1000) $ \case
  0 -> 0
  1 -> 1
  n -> fib (n-1) + fib (n-2)
```

Magically, `fib` is now fast:

```
λ> fib 35
9227465
(0.00 secs, 94,096 bytes)
λ> fib 1000
43466557686937456435688527675040625802564660517371780402481729089536555417949051890403879840079255169295922593080322634775209689623239873322471161642996440906533187938298969649928516003704476137795166849228875
(0.01 secs, 807,560 bytes)
```

This solves all our problems. We only have to write a single definition, which is a directly recursive function, so it’s hard to mess it up. The only thing we have to change is to stick a call to `memo` (with an appropriate index range) on the front; the whole thing is elegant and short.
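Nothing about the trick is specific to Fibonacci or to `Int` indices. As a further illustration of our own (not from the original post), here is a sketch that memoizes binomial coefficients over a range of pairs, the same shape of index the Zapis solution uses. The bound 100 is an arbitrary choice.

```haskell
{-# LANGUAGE LambdaCase #-}

import Data.Array

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

-- Binomial coefficients via Pascal's rule, memoized over (n, k) pairs.
-- Only valid for 0 <= n, k <= 100; anything outside that range makes
-- (!) throw an out-of-bounds error.
choose :: (Int, Int) -> Integer
choose = memo ((0, 0), (100, 100)) $ \case
  (_, 0) -> 1
  (n, k)
    | k > n     -> 0
    | otherwise -> choose (n - 1, k - 1) + choose (n - 1, k)
```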

How does this even work, though? At first glance, it might seem like it will generate a new table with every recursive call to `fib`, which would obviously be a disaster. However, that’s not what happens: there is only a single, top-level definition of `fib`, and it is defined as the function which looks up its input in a certain table. Every time we call `fib` we are calling that same, unique top-level function which is defined in terms of its (unique, top-level) table. So this ends up being equivalent to our previous solution—there is a mutually recursive pair of a function and a lookup table—but written in a much nicer, more compact way that doesn’t require us to explicitly name the table.
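The same reasoning suggests the table does not literally have to be top-level. What matters is that there is a single binding whose recursive calls refer back to itself, which is also what happens with `c` inside the `where` clause of `solve` in the Zapis solution. A sketch in which the caller picks the table size (`fibUpTo` and its extra bound argument are our own invention):

```haskell
{-# LANGUAGE LambdaCase #-}

import Data.Array

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

-- The memoized function `go` is bound once per application of fibUpTo hi.
-- All recursive calls go through that same binding, so they share one
-- table covering 0..hi. Inputs outside that range are an index error.
fibUpTo :: Int -> Int -> Integer
fibUpTo hi = go
  where
    go :: Int -> Integer
    go = memo (0, hi) $ \case
      0 -> 0
      1 -> 1
      n -> go (n - 1) + go (n - 2)
```

By contrast, a version taking the input as an explicit argument, such as `fibUpTo hi n = memo (0, hi) (...) n` with recursive calls to `fibUpTo hi`, can end up building a fresh table on every call; whether GHC's optimizer happens to rescue it is not something to rely on.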

So here’s our final solution for Zapis. As you can see, the extra code we have to write in order to memoize our recurrence boils down to about five lines (two of which are type signatures and could be omitted). This is definitely a technique worth knowing!

```haskell
{-# LANGUAGE LambdaCase #-}

import Control.Arrow
import Data.Array

-- The bracket sequence is on the last line of input.
main :: IO ()
main = interact $ lines >>> last >>> solve >>> format

-- Keep only the last five digits of the count.
format :: Integer -> String
format = show >>> reverse >>> take 5 >>> reverse

tabulate :: Ix i => (i,i) -> (i -> a) -> Array i a
tabulate rng f = listArray rng (map f $ range rng)

memo :: Ix i => (i,i) -> (i -> a) -> (i -> a)
memo rng = (!) . tabulate rng

solve :: String -> Integer
solve str = c (0,n)
  where
    n = length str
    s = listArray (0,n-1) str

    -- c (i,j): number of valid completions of the substring from
    -- index i up to (but not including) index j.
    c :: (Int, Int) -> Integer
    c = memo ((0,0), (n,n)) $ \case
      (i,j)
        | i == j           -> 1
        | even i /= even j -> 0   -- odd length: cannot be balanced
        | otherwise        -> sum
            [ m (s!i) (s!k) * c (i+1,k) * c (k+1, j)
            | k <- [i+1, i+3 .. j-1]
            ]

    -- m a b: number of ways characters a and b can form a matching pair.
    m '(' ')'                = 1
    m '[' ']'                = 1
    m '{' '}'                = 1
    m '?' '?'                = 3
    m b '?' | b `elem` "([{" = 1
    m '?' b | b `elem` ")]}" = 1
    m _ _                    = 0
```
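Since a recurrence like this is easy to get subtly wrong, one way to gain confidence in `solve` is to compare it against a brute-force count on short strings. A sketch of such a checker (`bruteForce` is a hypothetical helper, not part of the original solution), which enumerates every way of filling in the question marks and checks each result with a stack:

```haskell
-- Brute-force counter: try every way of filling in the question marks and
-- count the results that are properly nested. Exponential in the number
-- of '?'s, so only useful for cross-checking on short inputs.
bruteForce :: String -> Integer
bruteForce str = fromIntegral (length (filter valid (mapM fill str)))
  where
    -- Each '?' may become any of the six brackets; other characters stay put.
    fill '?' = "()[]{}"
    fill ch  = [ch]

    -- Standard stack-based check for properly nested brackets.
    valid = go []
      where
        go stack []     = null stack
        go stack (x:xs)
          | x `elem` "([{" = go (x : stack) xs
          | otherwise      = case stack of
              (o:rest) | matches o x -> go rest xs
              _                      -> False

    matches '(' ')' = True
    matches '[' ']' = True
    matches '{' '}' = True
    matches _   _   = False
```

For instance, both `bruteForce` and `solve` give 1 on `()()()` and 18 on `????` (two nesting shapes, times three bracket types for each of the two pairs).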
