Skip to content

Commit 4a366ea

Browse files
committed
Add memoize.RECURSIVE
Refs #858.
1 parent 35ff35a commit 4a366ea

File tree

7 files changed

+72
-13
lines changed

7 files changed

+72
-13
lines changed

DOCS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3116,6 +3116,8 @@ _Note: Passing `--strict` disables deprecated features._
31163116

31173117
Coconut provides `functools.lru_cache` as a built-in under the name `memoize` with the modification that the _maxsize_ parameter is set to `None` by default. `memoize` makes the use case of optimizing recursive functions easier, as a _maxsize_ of `None` is usually what is desired in that case.
31183118

3119+
`memoize` also supports a special `maxsize=memoize.RECURSIVE` argument, which will allow the cache to grow without bound within a single call to the top-level function, but clear the cache after the top-level call returns.
3120+
31193121
Use of `memoize` requires `functools.lru_cache`, which exists in the Python 3 standard library, but under Python 2 will require `pip install backports.functools_lru_cache` to function. Additionally, if on Python 2 and `backports.functools_lru_cache` is present, Coconut will patch `functools` such that `functools.lru_cache = backports.functools_lru_cache.lru_cache`.
31203122

31213123
Note that, if the function to be memoized is a generator or otherwise returns an iterator, [`recursive_generator`](#recursive_generator) can also be used to achieve a similar effect, the use of which is required for recursive generators.

__coconut__/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ _coconut_zip = zip
206206

207207
zip_longest = _coconut.zip_longest
208208
memoize = _lru_cache
209+
memoize.RECURSIVE = None # type: ignore
209210
reduce = _coconut.functools.reduce
210211
takewhile = _coconut.itertools.takewhile
211212
dropwhile = _coconut.itertools.dropwhile

coconut/compiler/templates/header.py_template

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,17 +1643,46 @@ def fmap(func, obj, **kwargs):
16431643
else:
16441644
mapped_obj = _coconut_map(func, obj)
16451645
return _coconut_base_makedata(obj.__class__, mapped_obj, from_fmap=True, fallback_to_init=fallback_to_init)
1646-
def _coconut_memoize_helper(maxsize=None, typed=False):
1647-
return maxsize, typed
16481646
def memoize(*args, **kwargs):
16491647
"""Decorator that memoizes a function, preventing it from being recomputed
16501648
if it is called multiple times with the same arguments."""
16511649
if not kwargs and _coconut.len(args) == 1 and _coconut.callable(args[0]):
1652-
return _coconut.functools.lru_cache(maxsize=None)(args[0])
1650+
return _coconut_memoize_helper()(args[0])
16531651
if _coconut.len(kwargs) == 1 and "user_function" in kwargs and _coconut.callable(kwargs["user_function"]):
1654-
return _coconut.functools.lru_cache(maxsize=None)(kwargs["user_function"])
1655-
maxsize, typed = _coconut_memoize_helper(*args, **kwargs)
1656-
return _coconut.functools.lru_cache(maxsize, typed)
1652+
return _coconut_memoize_helper()(kwargs["user_function"])
1653+
return _coconut_memoize_helper(*args, **kwargs)
1654+
memoize.RECURSIVE = _coconut_Sentinel()
1655+
def _coconut_memoize_helper(maxsize=None, typed=False):
1656+
if maxsize is memoize.RECURSIVE:
1657+
def memoizer(func):
1658+
"""memoize(...)"""
1659+
inside = [False]
1660+
cache = {empty_dict}
1661+
@_coconut_wraps(func)
1662+
def memoized_func(*args, **kwargs):
1663+
if typed:
1664+
key = (_coconut.tuple((x, _coconut.type(x)) for x in args), _coconut.tuple((k, _coconut.type(k), v, _coconut.type(v)) for k, v in kwargs.items()))
1665+
else:
1666+
key = (args, _coconut.tuple(kwargs.items()))
1667+
got = cache.get(key, _coconut_sentinel)
1668+
if got is not _coconut_sentinel:
1669+
return got
1670+
outer_inside, inside[0] = inside[0], True
1671+
try:
1672+
got = func(*args, **kwargs)
1673+
cache[key] = got
1674+
return got
1675+
finally:
1676+
inside[0] = outer_inside
1677+
if not inside[0]:
1678+
cache.clear()
1679+
memoized_func.__module__ = _coconut.getattr(func, "__module__", None)
1680+
memoized_func.__name__ = _coconut.getattr(func, "__name__", None)
1681+
memoized_func.__qualname__ = _coconut.getattr(func, "__qualname__", None)
1682+
return memoized_func
1683+
return memoizer
1684+
else:
1685+
return _coconut.functools.lru_cache(maxsize, typed)
16571686
{def_call_set_names}
16581687
class override(_coconut_baseclass):
16591688
"""Declare a method in a subclass as an override of a parent class method.

coconut/root.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
VERSION = "3.1.2"
2727
VERSION_NAME = None
2828
# False for release, int >= 1 for develop
29-
DEVELOP = 3
29+
DEVELOP = 4
3030
ALPHA = False # for pre releases rather than post releases
3131

3232
assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"

coconut/tests/src/cocotest/agnostic/primary_2.coco

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,20 @@ def primary_test_2() -> bool:
491491
assert reduce(function=(+), iterable=range(5), initial=-1) == 9 # type: ignore
492492
assert takewhile(predicate=ident, iterable=[1, 2, 1, 0, 1]) |> list == [1, 2, 1] # type: ignore
493493
assert dropwhile(predicate=(not), iterable=range(5)) |> list == [1, 2, 3, 4] # type: ignore
494+
@memoize(typed=True)
495+
def typed_memoized_func(x):
496+
if x is 1:
497+
return None
498+
else:
499+
return (x, typed_memoized_func(1))
500+
assert typed_memoized_func(1.0) == (1.0, None)
501+
assert typed_memoized_func(1.0)[0] |> type == float
502+
@memoize()
503+
def untyped_memoized_func(x=None):
504+
if x is None:
505+
return (untyped_memoized_func(1), untyped_memoized_func(1.0))
506+
return x
507+
assert untyped_memoized_func() |> map$(type) |> tuple == (int, float)
494508

495509
with process_map.multiple_sequential_calls(): # type: ignore
496510
assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore

coconut/tests/src/cocotest/agnostic/suite.coco

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,10 @@ def suite_test() -> bool:
562562
assert plus1sqsum_all(1, 2) == 13 == plus1sqsum_all_(1, 2) # type: ignore
563563
assert sum_list_range(10) == 45
564564
assert sum2([3, 4]) == 7
565-
assert ridiculously_recursive(300) == 201666561657114122540576123152528437944095370972927688812965354745141489205495516550423117825 == ridiculously_recursive_(300)
565+
with process_map.multiple_sequential_calls():
566+
for ridiculously_recursive in (ridiculously_recursive1, ridiculously_recursive2, ridiculously_recursive3):
567+
got = process_map(ridiculously_recursive, [300])
568+
assert got == (201666561657114122540576123152528437944095370972927688812965354745141489205495516550423117825,), got
566569
assert [fib(n) for n in range(16)] == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] == [fib_(n) for n in range(16)]
567570
assert [fib_alt1(n) for n in range(16)] == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] == [fib_alt2(n) for n in range(16)]
568571
assert fib.cache_info().hits == 28

coconut/tests/src/cocotest/agnostic/util.coco

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,24 +1346,34 @@ def sum2(ab) = a + b where:
13461346
# Memoization
13471347
import functools
13481348

1349-
@memoize()
1350-
def ridiculously_recursive(n):
1349+
@memoize(None)
1350+
def ridiculously_recursive1(n):
13511351
"""Requires maxsize=None when called on large numbers."""
13521352
if n <= 0:
13531353
return 1
13541354
result = 0
13551355
for i in range(1, 200):
1356-
result += ridiculously_recursive(n-i)
1356+
result += ridiculously_recursive1(n-i)
13571357
return result
13581358

13591359
@functools.lru_cache(maxsize=None) # type: ignore
1360-
def ridiculously_recursive_(n):
1360+
def ridiculously_recursive2(n):
13611361
"""Requires maxsize=None when called on large numbers."""
13621362
if n <= 0:
13631363
return 1
13641364
result = 0
13651365
for i in range(1, 200):
1366-
result += ridiculously_recursive_(n-i)
1366+
result += ridiculously_recursive2(n-i)
1367+
return result
1368+
1369+
@memoize(memoize.RECURSIVE) # type: ignore
1370+
def ridiculously_recursive3(n):
1371+
"""Requires maxsize=None when called on large numbers."""
1372+
if n <= 0:
1373+
return 1
1374+
result = 0
1375+
for i in range(1, 200):
1376+
result += ridiculously_recursive3(n-i)
13671377
return result
13681378

13691379
def fib(n if n < 2) = n

0 commit comments

Comments
 (0)