In-built memoization in Python with @functools.lru_cache()
Memoization, in computer programming, is an optimization technique in which the return values of a function are stored in a cache so that they don't have to be computed again when they are needed.
Memoization makes your code faster and, hence, more efficient, because the target function is re-executed only if the required value is not already present in the cache.
Python is already considered a very developer-friendly language, and its built-in support for memoization adds to the options it gives developers for writing efficient functions.
The functools module
This Python module provides functions that work on other functions (the functions provided as part of the module are referred to as higher-order functions) or callable objects. These higher-order functions allow us to use the intended functions or objects, and extend their functionality without having to rewrite them completely.
Some of the higher-order functions defined by the functools module include:
@functools.cache()
@functools.cached_property()
@functools.lru_cache() -> the one we are going to talk about in this blog
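As a quick aside, functools.cache (added in Python 3.9) is simply an unbounded variant of lru_cache, equivalent to lru_cache(maxsize=None). For example:

import functools

@functools.cache          # same as @functools.lru_cache(maxsize=None)
def square(n):
    return n * n

square(4)    # computed and stored in the cache
square(4)    # served straight from the cache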
The @functools.lru_cache() wrapper
To better understand how to use higher-order functions, let’s take a quick look at an example snippet:
from functools import lru_cache

@lru_cache(maxsize=2)
def fib(n):
    if n == 0:
        return 0
    if n == 1:
        return 1
    return fib(n - 2) + fib(n - 1)
The @lru_cache decorator validates a few argument cases and then uses the _lru_cache_wrapper() function to wrap the target function. This wrapper function contains the logic for adding items to the cache and the LRU replacement strategy, i.e., adding items to a circular queue and removing them when the cache is full. The internal implementation of the lru_cache() decorator looks like this:
def lru_cache(maxsize=128, typed=False):
    ...
    if isinstance(maxsize, int):
        # Negative maxsize is treated as 0
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # The user_function was passed in directly via the maxsize argument
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            'Expected first argument to be an integer, a callable, or None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)

    return decorating_function
lru_cache() normalizes maxsize (negative values are treated as 0), attaches the CacheInfo details, wraps the user function with _lru_cache_wrapper(), and finally updates the wrapper with the decorated function's docs and other metadata. You can check the cache status at any point in time by calling f.cache_info(), where f is the function decorated with @functools.lru_cache().
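For instance, with the fib function defined above you can inspect (and reset) the cache at any point; the exact hit and miss counts depend on the call pattern:

fib(10)                      # populate the cache
print(fib.cache_info())      # e.g. CacheInfo(hits=..., misses=..., maxsize=2, currsize=2)
fib.cache_clear()            # reset the cache and its statistics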
The _lru_cache_wrapper() function
The _lru_cache_wrapper() function used internally by the lru_cache() implementation is a little more complex. At a high level, it maintains the following variables for book-keeping purposes:
sentinel = object() # unique object used to signal cache misses
make_key = _make_key # build a key from the function arguments
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
cache = {}
hits = misses = 0
full = False
cache_get = cache.get # bound method to lookup a key or return None
cache_len = cache.__len__ # get cache size without calling len()
lock = RLock() # because linkedlist updates aren't threadsafe
root = [] # root of the circular doubly linked list
root[:] = [root, root, None, None] # initialize by pointing to self
The wrapper acquires a lock before performing any update so that operations on the linked list are thread-safe. The root list anchors the circular doubly linked list that holds at most maxsize items. The important concept to remember is that root initially references itself (root[:] = [root, root, None, None]) in both the previous (index 0) and next (index 1) positions.
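A tiny standalone snippet (an illustration, not the CPython code itself) makes this self-reference concrete:

PREV, NEXT, KEY, RESULT = 0, 1, 2, 3

root = []
root[:] = [root, root, None, None]   # empty circular queue: root points to itself

print(root)                # [[...], [...], None, None]
print(root[PREV] is root)  # True
print(root[NEXT] is root)  # True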
There are three high-level cases depending on maxsize:
The first case: when maxsize is 0, there is no cache functionality; the wrapper simply wraps the user function without any caching ability, increments the cache-miss count, and returns the result.

def wrapper(*args, **kwds):
    # No caching -- just a statistics update
    nonlocal misses
    misses += 1
    result = user_function(*args, **kwds)
    return result
The second case: when maxsize is None, there is no limit on the number of elements stored in the cache, so the wrapper checks for the key in the cache (a dictionary). When the key is present, the wrapper returns the value and updates the cache-hit info. When the key is missing, the wrapper calls the user function with the passed-in arguments, stores the result in the cache, updates the cache-miss info, and returns the result.

def wrapper(*args, **kwds):
    # Simple caching without ordering or size limit
    nonlocal hits, misses
    key = make_key(args, kwds, typed)
    result = cache_get(key, sentinel)
    if result is not sentinel:
        hits += 1
        return result
    misses += 1
    result = user_function(*args, **kwds)
    cache[key] = result
    return result
The third case: when maxsize is the default value (128) or a user-passed integer. This is the actual LRU cache implementation. The entire wrapper is implemented in a thread-safe way: before performing any operation on the cache (read, write, or delete), the wrapper acquires the RLock.
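Before diving into the internals, here is a small self-contained demonstration of the LRU behaviour this case implements (the function and variable names are only for illustration):

from functools import lru_cache

calls = []

@lru_cache(maxsize=2)
def identity(n):
    calls.append(n)               # record how often the function really runs
    return n

identity(1); identity(2)          # two misses, both results cached
identity(1)                       # hit: no recomputation, key 1 becomes most recent
identity(3)                       # miss: cache is full, least recently used key 2 is evicted
identity(2)                       # miss again, because 2 was evicted
print(calls)                      # [1, 2, 3, 2]
print(identity.cache_info())      # CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)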
LRU Cache Replacement Implementation
Each value in the cache is stored as a list of four items (remember root). The first item is a reference to the previous link, the second is a reference to the next link, the third is the key for the particular function call, and the fourth is the result. Here is the actual root value after caching the Fibonacci result for argument 1: [[[...], [...], 1, 1], [[...], [...], 1, 1], None, None]. [...] denotes a reference back to the list itself (a cycle in the structure).
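You can reproduce this shape with a few lines of standalone Python (again, an illustration rather than the CPython internals):

PREV, NEXT, KEY, RESULT = 0, 1, 2, 3

root = []
root[:] = [root, root, None, None]        # empty circular queue

# insert a link for key=1, result=1 at the front of the queue (just before root)
last = root[PREV]
link = [last, root, 1, 1]
last[NEXT] = root[PREV] = link

print(root)  # [[[...], [...], 1, 1], [[...], [...], 1, 1], None, None]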
The first check is for a cache hit. On a hit, the value retrieved from the cache is the four-item link described above.
nonlocal root, hits, misses, full
key = make_key(args, kwds, typed)
with lock:
    link = cache_get(key)
    if link is not None:
        # Move the link to the front of the circular queue
        print(f'Cache hit for {key}, {root}')
        link_prev, link_next, _key, result = link
        link_prev[NEXT] = link_next
        link_next[PREV] = link_prev
        last = root[PREV]
        last[NEXT] = root[PREV] = link
        link[PREV] = last
        link[NEXT] = root
        hits += 1
        return result
When the item is already in the cache, there is no need to check whether the circular queue is full or to evict anything; the code only changes the positions of the items in the circular queue. Since the most recently used item always sits at the front (just before root), the code unlinks the hit entry from its current position and re-links it at the front: last[NEXT] = root[PREV] = link, link[PREV] = last, and link[NEXT] = root. NEXT and PREV are the index constants defined at the top (PREV, NEXT, KEY, RESULT = 0, 1, 2, 3) that name the link fields. Finally, the code increments the cache-hit count and returns the result.
On a cache miss, the code updates the miss count, calls the user function, and then handles one of three cases. All three are handled after re-acquiring the RLock. In the source they appear in this order: the key turned up in the cache while the lock was released, the cache is full, and the cache still has room for new items. For demonstration, let's follow a different order: the cache is not full, the cache is full, and the key appears in the cache after acquiring the lock.
When the cache is not full
...
else:
    # Put result in a new link at the front of the queue.
    last = root[PREV]
    link = [last, root, key, result]
    last[NEXT] = root[PREV] = cache[key] = link
    # Use the cache_len bound method instead of the len() function
    # which could potentially be wrapped in an lru_cache itself.
    full = (cache_len() >= maxsize)
When the cache is not full, the code builds a new link for the recent result (link = [last, root, key, result]) containing the root's previous link, root itself, the key, and the computed result. It then points the front of the circular queue to the new link (root[PREV] = link), points the previous front item's next reference to it (last[NEXT] = link), and adds it to the cache dictionary (cache[key] = link).
Finally, it checks whether the cache is now full (cache_len() >= maxsize, where cache_len = cache.__len__ is declared at the top) and sets the full flag accordingly.
For the fib example, when the function first receives the value 1, the queue is empty and root is [[...], [...], None, None]. After adding the result to the circular queue, root becomes [[[...], [...], 1, 1], [[...], [...], 1, 1], None, None]: both the previous and next references point to key 1's link. For the next value, 0, root after insertion is [[[[...], [...], 1, 1], [...], 0, 0], [[...], [[...], [...], 0, 0], 1, 1], None, None]. Its previous link is [[[[...], [...], None, None], [...], 1, 1], [[...], [[...], [...], 1, 1], None, None], 0, 0] and its next link is [[[[...], [...], 0, 0], [...], None, None], [[...], [[...], [...], None, None], 0, 0], 1, 1].
When the cache is full
...
elif full:
    # Use the old root to store the new key and result.
    oldroot = root
    oldroot[KEY] = key
    oldroot[RESULT] = result
    # Empty the oldest link and make it the new root.
    # Keep a reference to the old key and old result to
    # prevent their ref counts from going to zero during the
    # update. That will prevent potentially arbitrary object
    # clean-up code (i.e. __del__) from running while we're
    # still adjusting the links.
    root = oldroot[NEXT]
    oldkey = root[KEY]
    oldresult = root[RESULT]
    root[KEY] = root[RESULT] = None
    # Now update the cache dictionary.
    del cache[oldkey]
    # Save the potentially reentrant cache[key] assignment
    # for last, after the root and links have been put in
    # a consistent state.
    cache[key] = oldroot
When the cache is full, the code reuses the current root as oldroot (oldroot = root) and stores the new key and result in it.
It then makes oldroot's next item, the oldest entry, the new root (root = oldroot[NEXT]) and saves that entry's key and result (oldkey = root[KEY] and oldresult = root[RESULT]) so their reference counts don't drop to zero mid-update.
The new root's key and result are set to None (root[KEY] = root[RESULT] = None).
Finally, the oldest key's item is deleted from the cache (del cache[oldkey]) and the newly calculated result is added (cache[key] = oldroot).
For the Fibonacci example, when the cache is full and the new key is 2, root before the update is [[[[...], [...], 1, 1], [...], 0, 0], [[...], [[...], [...], 0, 0], 1, 1], None, None], and the new root at the end of the block is [[[[...], [...], 0, 0], [...], 2, 1], [[...], [[...], [...], 2, 1], 0, 0], None, None]. As you can see, key 1 has been evicted and replaced by key 2.
When the key appears in cache after acquiring the lock
if key in cache:
    # Getting here means that this same key was added to the
    # cache while the lock was released.  Since the link
    # update is already done, we need only return the
    # computed result and update the count of misses.
    pass
If the key appears in the cache after the lock is re-acquired, another thread has already enqueued the value while the lock was released, so there is nothing more to do; the wrapper simply returns the freshly computed result. Note that before executing the cache-miss path, the code has already updated the miss count and built the lookup key with the make_key function.