Kotlin: Function Memoization [IN PRACTICE]

Iliyan Germanov
8 min read · Mar 10, 2023


Function memoization (often misspelled as “memorization”, which is a different word) is simply remembering the result of a function for a particular input, i.e. caching it. It’s used for optimization purposes. Let’s illustrate that with an example.

fun factorial(n: Int): Int = when (n) { // note: Int overflows for n > 12
    0, 1 -> 1
    // [BUG 1] the line below is problematic! Do you know why?
    else -> n * factorial(n - 1)
}

fun main() {
    val cache = mutableMapOf<Int, Int>() // 1
    while (true) {
        val n = readln().toInt()
        val result = cache.getOrPut(n) { // 2
            factorial(n) // 3
        }
        println(result)
    }
}

Psst… Can you spot the bug in the code snippet above?

  1. val cache = mutableMapOf<Int, Int>(): creates a mutable map, cache, that will hold the results of factorial(n).
  2. cache.getOrPut(n) { … }: checks whether the result of factorial(n) is cached and, if so, returns it.
  3. factorial(n): if the result isn’t cached, calculates n factorial by executing the function, puts the result in the cache, and then returns it.

The memoization algorithm is pretty straightforward: create a key-value structure where you can store the output (value) of the function for each input (key). When calling the function, first check whether the input (key) is already stored; if it is, return the stored value, otherwise execute the function and then store its result.
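The algorithm above fits in a few lines. Here is a minimal sketch of a generic helper that wraps any single-argument function with a map-backed cache. The names `memoize`, `slowSquare` and `fastSquare` are our own illustrations, not part of any library.

```kotlin
// A generic memoizer for single-argument functions (illustrative sketch).
fun <A, R> memoize(f: (A) -> R): (A) -> R {
    val cache = mutableMapOf<A, R>() // input (key) -> output (value)
    // Look up the input first; on a miss, compute, store and return the result.
    return { a -> cache.getOrPut(a) { f(a) } }
}

// Hypothetical "expensive" function that counts how often its body runs.
var slowSquareCalls = 0
fun slowSquare(x: Int): Int {
    slowSquareCalls++
    return x * x
}

val fastSquare = memoize(::slowSquare)
```

Calling `fastSquare(4)` twice runs `slowSquare` only once. Note that such a wrapper only caches the outermost call: a recursive function that calls itself directly bypasses the cache, which is the same issue as [BUG 1].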

Now that we know what memoization is, let’s see when it’s worth doing it.


When NOT to memoize?

More observant readers may have already noticed that we can’t always memoize a function. Here’s when we can’t:

#1 The function is NOT deterministic

A function is non-deterministic if the same input can produce different results, i.e. you can’t be sure what it’ll return just by substituting a given input and executing the function’s body in your head.

import kotlin.random.Random
import kotlin.random.nextInt

// Examples: non-determinism
fun rollADice(): Int = Random.nextInt(1..6)
fun rollNDice(n: Int): List<Int> = (1..n).map { rollADice() }

suspend fun getUserStatus(userId: String): String {
    val statusResponse = userStatusRequest(userId) // call to server (not shown)
    return "Status: $statusResponse"
}

Obviously, we can’t memoize rollNDice(n), because that would make the function useless. No one wants to play a dice game where every throw is the same.

Another example is getUserStatus(userId), where memoization would lead to incorrect behavior: the program would keep returning the first cached result for each user ID, even after the user’s status changes on the server.
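To make the failure concrete, here is a small simulation. The server is replaced by a hypothetical in-memory map (`fakeServer`), and `getUserStatus` is a non-suspending stand-in for the function above; `getUserStatusMemoized` is our own naive memoized wrapper.

```kotlin
// Hypothetical stand-in for the server: user ID -> current status.
val fakeServer = mutableMapOf("alice" to "online")

// From the caller's point of view this is non-deterministic: the result depends on server state.
fun getUserStatus(userId: String): String = "Status: ${fakeServer[userId]}"

// Naive memoization with a plain map: the first result is kept forever.
val statusCache = mutableMapOf<String, String>()
fun getUserStatusMemoized(userId: String): String =
    statusCache.getOrPut(userId) { getUserStatus(userId) }
```

If "alice" goes offline after the first call, `getUserStatus("alice")` correctly reports `Status: offline`, while `getUserStatusMemoized("alice")` keeps returning the stale `Status: online`.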

#2 The function produces side-effects

Memoization also won’t work if the function produces a side effect that we need, for example:

// Example: Side-effect
fun calculateTaxesAndPayThem(revenue: Double): Double {
    val taxes = revenue * 42
    payTaxes(taxes) // side-effect (not shown)
    return taxes
}

If we memoize calculateTaxesAndPayThem(revenue), our government won’t be happy: if we happen to make the same revenue again, the payTaxes(taxes) side effect won’t be executed and we’ll just return the cached value.
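A quick way to see the lost side effect is to replace the real payment with a counter. In this sketch `payTaxes` is a hypothetical stub that only counts calls, and `calculateTaxesAndPayThemMemoized` is our own naive memoized wrapper; the tax formula is copied from the example above.

```kotlin
// Hypothetical stub: counts payments instead of talking to a bank.
var paymentsMade = 0
fun payTaxes(taxes: Double) {
    paymentsMade++
}

fun calculateTaxesAndPayThem(revenue: Double): Double {
    val taxes = revenue * 42
    payTaxes(taxes) // side-effect
    return taxes
}

// Naive memoization: the side-effect only happens on a cache miss.
val taxCache = mutableMapOf<Double, Double>()
fun calculateTaxesAndPayThemMemoized(revenue: Double): Double =
    taxCache.getOrPut(revenue) { calculateTaxesAndPayThem(revenue) }
```

Calling the memoized version twice with the same revenue returns the same amount both times, but `paymentsMade` is only 1.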

A function can be memoized only if it’s pure (also called referentially transparent). A pure function is a function that: 1) always returns the same output for the same input; 2) produces no side effects; 3) is defined for all input values, i.e. it doesn’t throw an exception the way fun divide(a: Int, b: Int) = a / b does when b = 0.
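For point 3, one common way to make a function like divide total is to encode the failure in the return type instead of throwing. This is a minimal sketch; `safeDivide` is a hypothetical name, and returning `null` is just one design choice (`Result` or a sealed class would work too).

```kotlin
// Partial: throws ArithmeticException when b == 0.
fun divide(a: Int, b: Int): Int = a / b

// Total: defined for every input; "no result" is represented by null.
fun safeDivide(a: Int, b: Int): Int? = if (b == 0) null else a / b
```

One caveat if you later memoize such a function with Kotlin’s `getOrPut`: a stored `null` is treated the same as a missing key, so inputs whose result is `null` are recomputed on every call.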

When to memoize a function in Kotlin?

If you’ve read the sections above, it should be clear that you can only memoize pure functions. Once a function is pure, the harder part is deciding whether memoization is worth it.

As a rule of thumb, consider memoizing calculations (yet another name for pure functions) that are computationally heavy and are often called with the same input.

fun factorial(n: Int): Int // memoize it
fun doubleIt(a: Int) = a * 2 // not worth it
fun calculateTax(x: Double) = x * 0.1 // not worth it
fun fibonacci(n: Int): Int // memoize it
fun heavyComputation(
    n: Int // often called with the same input
): Int // memoize it
fun heavyComputation2(
    n: Int = Random.nextInt()
): Int // DON'T

[BUG 1] The problem in our “factorial” example

If you found the bug, congrats! It’s a bit tricky. Let’s look at the example again (slightly changed for simplicity).

fun factorial(n: Int): Int = when (n) {
    0, 1 -> 1
    else -> n * factorial(n - 1) // <-- this is the problem
}

val cache = mutableMapOf<Int, Int>()

fun run(n: Int) {
    val result = cache.getOrPut(n) {
        factorial(n)
    }
    println(result)
}

Attempt #0: run(5)

  • The cache is empty. 5 isn’t present in the cache.
  • factorial(5) called → 5 * factorial(4) → 5 * (4 * factorial(3)) → … → 5 * 4 * ... * 2 * factorial(1) = 120 → 120 returned.
  • cache[5] = 120 → 120 printed on the screen.

Attempt #1: run(6)

  • We have cache = { 5: 120 }. 6 isn’t present in the cache.
  • factorial(6) called → 6 * factorial(5) → 6 * (5 * factorial(4)) → … → 6 * 5 * ... * 2 * factorial(1) = 720 → 720 returned.
  • cache = { 5: 120, 6: 720 } → 720 printed on the screen.

Do you see the problem now? When calling factorial(6), even though we had cache = { 5: 120 }, the recursive call to factorial(5) had to redo all the computations down to factorial(1) = 1. Let’s fix that!

val cache = mutableMapOf<Int, Int>()

fun factorial(n: Int): Int = cache.getOrPut(n) { // <-- this is the fix
    when (n) {
        0, 1 -> 1
        else -> n * factorial(n - 1)
    }
}

fun run(n: Int) {
    println(factorial(n))
}
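To see the difference in numbers, one can count how many times the function body actually runs. The sketch below replays the two attempts from above (5, then 6) against both versions; the names `naiveFactorial`, `fixedFactorial` and the counters are our own illustrations.

```kotlin
// Counters for how many times each function body actually runs.
var naiveBodyRuns = 0
var fixedBodyRuns = 0

// Before the fix: only the outermost call goes through the cache.
val naiveCache = mutableMapOf<Int, Int>()
fun naiveFactorial(n: Int): Int {
    naiveBodyRuns++
    return if (n <= 1) 1 else n * naiveFactorial(n - 1)
}
fun naiveRun(n: Int): Int = naiveCache.getOrPut(n) { naiveFactorial(n) }

// After the fix: every recursive call also goes through the cache.
val fixedCache = mutableMapOf<Int, Int>()
fun fixedFactorial(n: Int): Int = fixedCache.getOrPut(n) {
    fixedBodyRuns++
    if (n <= 1) 1 else n * fixedFactorial(n - 1)
}
```

Computing 5! and then 6!, the naive version executes its body 5 + 6 = 11 times, while the fixed version executes it 5 + 1 = 6 times, because the recursive call for 5 hits the cache.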

After the fix, the recursive call factorial(n - 1) checks the cache again before doing all the computational heavy lifting. This is a common pattern for memoizing functions in Kotlin. However, there are still more problems:

  • A global mutable cache variable.
  • What if we want to memoize two functions in the same file? Create cache1 and cache2?
  • What if we want to memoize a function with more than one argument?
  • This is boilerplate. Should we always write it?

The final section gives a solution to these problems.

[SOLUTION] Function Memoization in Kotlin

Memoization in Kotlin can be generalized to arbitrary functions with 1 to N parameters, where N is a reasonably small number. Before we explore how the solution works under the hood, let’s first use it.

// [SOLUTION] Usage
fun Memo1<Int, Int>.factorial(n: Int): Int = when (n) {
    0, 1 -> 1
    else -> n * recurse(n - 1) // from Memo1<Int, Int>: calls factorial(n - 1)
}
val factorialMemoized = Memo1<Int, Int>::factorial.memoize()

fun run(n: Int) {
    println(factorialMemoized(n))
}

Now, this might seem like magic (or a scam) to you. Don’t worry: we’ll explore its implementation and give you code you can copy-paste into your project if you like it.

Let’s begin. Our new code uses three unknown things: Memo1<Int, Int>, recurse(n - 1), and Memo1<Int, Int>::factorial.memoize(). All three must be implemented somehow.

// [SOLUTION] Implementation
interface Memo1<A, R> { // 1
    fun recurse(a: A): R
}

// 2
fun <A, R> (Memo1<A, R>.(A) -> R).memoize(): (A) -> R {
    val memoized = object : Memoized1<A, R>() { // 3
        override fun Memo1<A, R>.function(a: A): R = this@memoize(a)
    }
    return { a -> // 4
        memoized.execute(a)
    }
}

abstract class Memoized1<A, R> { // 5
    private val cache = mutableMapOf<A, R>()
    private val memo = object : Memo1<A, R> {
        override fun recurse(a: A): R = cache.getOrPut(a) { function(a) }
    }

    protected abstract fun Memo1<A, R>.function(a: A): R

    fun execute(a: A): R = memo.recurse(a)
}

It’s not a lot of code, but a lot is happening here. Let’s explore:

  • 1: An abstraction providing a call recurse(a: A): R for a function accepting A and producing R (recurse :: A → R). That’s how our new Memo1<Int, Int>.factorial(n) can calculate (n - 1) factorial while still using the cache.
  • 2: An extension function that turns our Memo1<Int, Int>.factorial(n) into a memoized lambda (Int) -> Int that uses caching.
  • 3: An anonymous object extending the abstract class Memoized1<A, R>, whose function(a) simply calls the original function (this@memoize). Memoized1 does the heavy lifting, which we’ll explore in a moment.
  • 4: We return a lambda that, when called, always uses the same Memoized1 instance, and therefore the same cache.
  • 5: The magic happens here. A mutable map cache is created, and it is checked before the function runs, whether the call comes from execute(a) or from recurse(a) (execute simply delegates to memo.recurse). Each time the abstract fun function(a: A): R is invoked, its result is stored in the cache.
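To check that recursive calls really go through the cache, here is a sketch that memoizes Fibonacci with the same Memo1 / Memoized1 / memoize() code (repeated so the block compiles on its own). Two things are our own choices: the body is written as a lambda with receiver (`fibBody`) instead of a named extension function, and `fibBodyRuns` is a hypothetical counter added only for illustration.

```kotlin
// Memo1 / Memoized1 / memoize() from the implementation above.
interface Memo1<A, R> {
    fun recurse(a: A): R
}

abstract class Memoized1<A, R> {
    private val cache = mutableMapOf<A, R>()
    private val memo = object : Memo1<A, R> {
        override fun recurse(a: A): R = cache.getOrPut(a) { function(a) }
    }

    protected abstract fun Memo1<A, R>.function(a: A): R

    fun execute(a: A): R = memo.recurse(a)
}

fun <A, R> (Memo1<A, R>.(A) -> R).memoize(): (A) -> R {
    val memoized = object : Memoized1<A, R>() {
        override fun Memo1<A, R>.function(a: A): R = this@memoize(a)
    }
    return { a -> memoized.execute(a) }
}

// Counts how many times the Fibonacci body runs (for illustration only).
var fibBodyRuns = 0

// recurse(...) resolves against the Memo1 receiver, so every recursive call is cached.
val fibBody: Memo1<Int, Long>.(Int) -> Long = { n ->
    fibBodyRuns++
    if (n < 2) n.toLong() else recurse(n - 1) + recurse(n - 2)
}

val fibonacci = fibBody.memoize()
```

`fibonacci(50)` returns 12586269025 after running the body only 51 times (once for each n from 0 to 50); a second `fibonacci(50)` call runs it zero more times.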

That’s how our out-of-the-box memoization solution works. A natural question is: can we make it work for functions with more than one argument?

Memoizing a Kotlin function with N arguments

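Here’s an implementation for memoizing a function with 2 input parameters. It’s almost the same as the solution above; the only real difference is that both arguments are wrapped in a data class Input<A, B>, whose generated equals() and hashCode() make it usable as the cache key.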

// region Implementation
interface Memo2<A, B, R> {
    fun recurse(a: A, b: B): R
}

abstract class Memoized2<A, B, R> {
    private data class Input<A, B>(
        val a: A,
        val b: B
    )

    private val cache = mutableMapOf<Input<A, B>, R>()
    private val memo = object : Memo2<A, B, R> {
        override fun recurse(a: A, b: B): R =
            cache.getOrPut(Input(a, b)) { function(a, b) }
    }

    protected abstract fun Memo2<A, B, R>.function(a: A, b: B): R

    fun execute(a: A, b: B): R = memo.recurse(a, b)
}

fun <A, B, R> (Memo2<A, B, R>.(A, B) -> R).memoize(): (A, B) -> R {
    val memoized = object : Memoized2<A, B, R>() {
        override fun Memo2<A, B, R>.function(a: A, b: B): R = this@memoize(a, b)
    }
    return { a, b ->
        memoized.execute(a, b)
    }
}
// endregion

// region Usage
fun Memo2<String, Boolean, Long>.myAwesomeFun(a: String, b: Boolean): Long = TODO()
val myAwesomeFun = Memo2<String, Boolean, Long>::myAwesomeFun.memoize()
// endregion
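Since myAwesomeFun above is only a TODO(), here is a sketch of a real two-argument use: binomial coefficients via Pascal’s rule, C(n, k) = C(n - 1, k - 1) + C(n - 1, k). The Memo2 code is repeated so the block compiles alone; `binomialBody` and `binomial` are hypothetical names, and the body is written as a lambda with receiver rather than a named extension function.

```kotlin
// Memo2 / Memoized2 / memoize() from above.
interface Memo2<A, B, R> {
    fun recurse(a: A, b: B): R
}

abstract class Memoized2<A, B, R> {
    private data class Input<A, B>(val a: A, val b: B)

    private val cache = mutableMapOf<Input<A, B>, R>()
    private val memo = object : Memo2<A, B, R> {
        override fun recurse(a: A, b: B): R =
            cache.getOrPut(Input(a, b)) { function(a, b) }
    }

    protected abstract fun Memo2<A, B, R>.function(a: A, b: B): R

    fun execute(a: A, b: B): R = memo.recurse(a, b)
}

fun <A, B, R> (Memo2<A, B, R>.(A, B) -> R).memoize(): (A, B) -> R {
    val memoized = object : Memoized2<A, B, R>() {
        override fun Memo2<A, B, R>.function(a: A, b: B): R = this@memoize(a, b)
    }
    return { a, b -> memoized.execute(a, b) }
}

// Pascal's rule; assumes 0 <= k <= n.
val binomialBody: Memo2<Int, Int, Long>.(Int, Int) -> Long = { n, k ->
    if (k == 0 || k == n) 1L else recurse(n - 1, k - 1) + recurse(n - 1, k)
}

val binomial = binomialBody.memoize()
```

Without the cache, the same recursion would recompute overlapping (n, k) pairs many times; with it, each pair is computed once.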

Okay. So far, so good! But what if you want to memoize a function with 3, 4, or more parameters? I’ve got you covered: check out my IvyPack repo, which has copy-paste-ready snippets that can save you the trouble.
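If you’d rather see how the pattern extends, here is a hypothetical three-parameter sketch written for illustration (it is not the IvyPack code). It follows the Memo2 structure exactly, but uses Kotlin’s built-in Triple as the cache key instead of a custom data class. The example function counts monotone lattice paths from (x, y, z) down to (0, 0, 0); `pathsBody` and `latticePaths` are our own names.

```kotlin
// Three-parameter variant of the same pattern; Triple serves as the cache key.
interface Memo3<A, B, C, R> {
    fun recurse(a: A, b: B, c: C): R
}

abstract class Memoized3<A, B, C, R> {
    private val cache = mutableMapOf<Triple<A, B, C>, R>()
    private val memo = object : Memo3<A, B, C, R> {
        override fun recurse(a: A, b: B, c: C): R =
            cache.getOrPut(Triple(a, b, c)) { function(a, b, c) }
    }

    protected abstract fun Memo3<A, B, C, R>.function(a: A, b: B, c: C): R

    fun execute(a: A, b: B, c: C): R = memo.recurse(a, b, c)
}

fun <A, B, C, R> (Memo3<A, B, C, R>.(A, B, C) -> R).memoize(): (A, B, C) -> R {
    val memoized = object : Memoized3<A, B, C, R>() {
        override fun Memo3<A, B, C, R>.function(a: A, b: B, c: C): R = this@memoize(a, b, c)
    }
    return { a, b, c -> memoized.execute(a, b, c) }
}

// Each step decrements exactly one positive coordinate; assumes x, y, z >= 0.
val pathsBody: Memo3<Int, Int, Int, Long>.(Int, Int, Int) -> Long = { x, y, z ->
    if (x == 0 && y == 0 && z == 0) 1L
    else (if (x > 0) recurse(x - 1, y, z) else 0L) +
        (if (y > 0) recurse(x, y - 1, z) else 0L) +
        (if (z > 0) recurse(x, y, z - 1) else 0L)
}

val latticePaths = pathsBody.memoize()
```

The same mechanical steps (a MemoN interface, a MemoizedN class keyed by a tuple-like data class, and a memoize() extension) work for 4 or more parameters.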

Conclusion (things to consider)

  • Memoization is an optimization technique used primarily to speed up computer programs by storing the results of expensive function calls and returning the cached result when the same inputs occur again.
  • Memoization trades space (RAM) for execution time (CPU).
  • Only pure functions can be memoized.
  • Further optimization: you can pre-compute the most common (or even all) inputs for your function and hard-code them in your memoization cache map.
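  • Tip: If the function’s input parameters are large, consider hashing them and using the hash as the memoization key, which saves RAM. Beware of hash collisions, though: two different inputs with the same hash would share one cached result, so use a strong hash function.
  • Bonus: Try rewriting your recursive algorithm using bottom-up dynamic programming.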
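For the bonus point, here is what a bottom-up dynamic-programming rewrite can look like for Fibonacci, as a sketch (`fibonacciBottomUp` is our own name). Instead of caching the results of recursive calls top-down, it builds the answer from the smallest input upward and keeps only the last two values, so it needs no map and no recursion.

```kotlin
// Bottom-up (tabulated) Fibonacci: iterates from F(0), F(1) up to F(n).
fun fibonacciBottomUp(n: Int): Long {
    require(n >= 0) { "n must be non-negative" }
    if (n == 0) return 0L
    var prev = 0L // F(i - 1), starting with i = 1
    var curr = 1L // F(i), starting with i = 1
    repeat(n - 1) {
        val next = prev + curr // F(i + 1)
        prev = curr
        curr = next
    }
    return curr // F(n)
}
```

The trade-off compared with memoization: the bottom-up version uses constant extra memory, but it computes every smaller input whether or not it is needed, and it only works when you know the order in which subproblems must be solved.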

That’s it for today! Thank you for reading my entire article! If you want more Kotlin/Android posts like this one → follow me on Medium and clap 50 times to show me that it’s worth spending my time writing more.

P.S. If you enjoyed this article which tbh is kinda boring, check out my other one — [Android/Multiplatform] Kotlin Flows + Ktor = Flawless HTTP requests (- ArrowKt) which is far more interesting and helpful for Android Devs.


Iliyan Germanov

Software Engineer: Android | Functional Programming - Kotlin & Haskell. Excited by innovation and science.