diff --git a/tests/test_data_container.py b/tests/test_data_container.py index 978574fb..716e8bc0 100644 --- a/tests/test_data_container.py +++ b/tests/test_data_container.py @@ -4,11 +4,15 @@ from topgg.errors import TopGGException +class _Int(int): + ... + + @pytest.fixture def data_container() -> DataContainerMixin: dc = DataContainerMixin() dc.set_data("TEXT") - dc.set_data(200) + dc.set_data(_Int(200)) dc.set_data({"a": "b"}) return dc @@ -49,8 +53,8 @@ def test_data_container_raises_data_already_exists(data_container: DataContainer @pytest.mark.asyncio -async def test_data_container_raises_key_error(data_container: DataContainerMixin): - with pytest.raises(KeyError): +async def test_data_container_raises_lookup_error(data_container: DataContainerMixin): + with pytest.raises(LookupError): await data_container._invoke_callback(_invalid_callback) @@ -58,3 +62,17 @@ def test_data_container_get_data(data_container: DataContainerMixin): assert data_container.get_data(str) == "TEXT" assert data_container.get_data(float) is None assert isinstance(data_container.get_data(set, set()), set) + assert data_container.get_data(_Int) == 200 + + +def test_data_container_get_data_by_subclass(data_container: DataContainerMixin): + data = data_container.get_data(int) + assert isinstance(data, _Int) + assert data == 200 + assert int in data_container._lookup_cache + + +def test_data_container_override_lookup_cache(data_container: DataContainerMixin): + assert data_container.get_data(int) is not None + data_container.set_data(_Int(300), override=True) + assert data_container.get_data(int) == 300 diff --git a/topgg/data.py b/topgg/data.py index 7126d3bf..704ac91d 100644 --- a/topgg/data.py +++ b/topgg/data.py @@ -31,6 +31,17 @@ DataContainerT = t.TypeVar("DataContainerT", bound="DataContainerMixin") +# this is meant to be a singleton, +# but we don't care if it's instantiated more than once +# we'll only use the one we instantiated here: _UNSET. +class _UnsetType: + def __bool__(self) -> bool: + return False + + +_UNSET = _UnsetType() + + def data(type_: t.Type[T]) -> T: """ Represents the injected data. This should be set as the parameter's default value. @@ -74,10 +85,11 @@ class DataContainerMixin: as arguments in your functions. """ - __slots__ = ("_data",) + __slots__ = ("_data", "_lookup_cache") def __init__(self) -> None: self._data: t.Dict[t.Type, t.Any] = {type(self): self} + self._lookup_cache: t.Dict[t.Type, t.Any] = {} def set_data( self: DataContainerT, data_: t.Any, *, override: bool = False @@ -101,6 +113,11 @@ def set_data( f"{type_} already exists. If you wish to override it, pass True into the override parameter." ) + # exclude the type itself and object + for sup in type_.mro()[1:-1]: + if sup in self._lookup_cache: + self._lookup_cache[sup] = data_ + self._data[type_] = data_ return self @@ -114,7 +131,10 @@ def get_data(self, type_: t.Type[T], default: t.Any = None) -> t.Any: def get_data(self, type_: t.Any, default: t.Any = None) -> t.Any: """Gets the injected data.""" - return self._data.get(type_, default) + try: + return self._get_data(type_) + except LookupError: + return default async def _invoke_callback( self, callback: t.Callable[..., T], *args: t.Any, **kwargs: t.Any @@ -133,7 +153,7 @@ async def _invoke_callback( } for k, v in signatures.items(): - signatures[k] = self._resolve_data(v.type) + signatures[k] = self._get_data(v.type) res = callback(*args, **{**signatures, **kwargs}) if inspect.isawaitable(res): @@ -141,5 +161,30 @@ async def _invoke_callback( return res - def _resolve_data(self, type_: t.Type[T]) -> T: - return self._data[type_] + def _resolve_data(self, type_: t.Type[T]) -> t.Union[_UnsetType, t.Tuple[bool, T]]: + maybe_data = self._data.get(type_, _UNSET) + if maybe_data is not _UNSET: + return False, maybe_data + + cache = self._lookup_cache.get(type_, _UNSET) + if cache is not _UNSET: + return False, cache + + for subclass in type_.__subclasses__(): + maybe_data = self._resolve_data(subclass) + if maybe_data is not _UNSET: + return True, maybe_data[1] + + return _UNSET + + def _get_data(self, type_: t.Type[T]) -> T: + maybe_data = self._resolve_data(type_) + if maybe_data is _UNSET: + raise LookupError(f"data of type {type_} can't be found.") + + assert isinstance(maybe_data, tuple) + is_subclass, data = maybe_data + if is_subclass: + self._lookup_cache[type_] = data + + return data