Skip to content

Allow getting data by subclasses #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
from topgg.errors import TopGGException


class _Int(int):
...


@pytest.fixture
def data_container() -> DataContainerMixin:
dc = DataContainerMixin()
dc.set_data("TEXT")
dc.set_data(200)
dc.set_data(_Int(200))
dc.set_data({"a": "b"})
return dc

Expand Down Expand Up @@ -49,12 +53,26 @@ def test_data_container_raises_data_already_exists(data_container: DataContainer


@pytest.mark.asyncio
async def test_data_container_raises_key_error(data_container: DataContainerMixin):
with pytest.raises(KeyError):
async def test_data_container_raises_lookup_error(data_container: DataContainerMixin):
with pytest.raises(LookupError):
await data_container._invoke_callback(_invalid_callback)


def test_data_container_get_data(data_container: DataContainerMixin):
assert data_container.get_data(str) == "TEXT"
assert data_container.get_data(float) is None
assert isinstance(data_container.get_data(set, set()), set)
assert data_container.get_data(_Int) == 200


def test_data_container_get_data_by_subclass(data_container: DataContainerMixin):
data = data_container.get_data(int)
assert isinstance(data, _Int)
assert data == 200
assert int in data_container._lookup_cache


def test_data_container_override_lookup_cache(data_container: DataContainerMixin):
assert data_container.get_data(int) is not None
data_container.set_data(_Int(300), override=True)
assert data_container.get_data(int) == 300
55 changes: 50 additions & 5 deletions topgg/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
DataContainerT = t.TypeVar("DataContainerT", bound="DataContainerMixin")


# this is meant to be a singleton,
# but we don't care if it's instantiated more than once
# we'll only use the one we instantiated here: _UNSET.
class _UnsetType:
def __bool__(self) -> bool:
return False


_UNSET = _UnsetType()


def data(type_: t.Type[T]) -> T:
"""
Represents the injected data. This should be set as the parameter's default value.
Expand Down Expand Up @@ -74,10 +85,11 @@ class DataContainerMixin:
as arguments in your functions.
"""

__slots__ = ("_data",)
__slots__ = ("_data", "_lookup_cache")

def __init__(self) -> None:
self._data: t.Dict[t.Type, t.Any] = {type(self): self}
self._lookup_cache: t.Dict[t.Type, t.Any] = {}

def set_data(
self: DataContainerT, data_: t.Any, *, override: bool = False
Expand All @@ -101,6 +113,11 @@ def set_data(
f"{type_} already exists. If you wish to override it, pass True into the override parameter."
)

# exclude the type itself and object
for sup in type_.mro()[1:-1]:
if sup in self._lookup_cache:
self._lookup_cache[sup] = data_

self._data[type_] = data_
return self

Expand All @@ -114,7 +131,10 @@ def get_data(self, type_: t.Type[T], default: t.Any = None) -> t.Any:

def get_data(self, type_: t.Any, default: t.Any = None) -> t.Any:
"""Gets the injected data."""
return self._data.get(type_, default)
try:
return self._get_data(type_)
except LookupError:
return default

async def _invoke_callback(
self, callback: t.Callable[..., T], *args: t.Any, **kwargs: t.Any
Expand All @@ -133,13 +153,38 @@ async def _invoke_callback(
}

for k, v in signatures.items():
signatures[k] = self._resolve_data(v.type)
signatures[k] = self._get_data(v.type)

res = callback(*args, **{**signatures, **kwargs})
if inspect.isawaitable(res):
return await res

return res

def _resolve_data(self, type_: t.Type[T]) -> T:
return self._data[type_]
def _resolve_data(self, type_: t.Type[T]) -> t.Union[_UnsetType, t.Tuple[bool, T]]:
maybe_data = self._data.get(type_, _UNSET)
if maybe_data is not _UNSET:
return False, maybe_data

cache = self._lookup_cache.get(type_, _UNSET)
if cache is not _UNSET:
return False, cache

for subclass in type_.__subclasses__():
maybe_data = self._resolve_data(subclass)
if maybe_data is not _UNSET:
return True, maybe_data[1]

return _UNSET

def _get_data(self, type_: t.Type[T]) -> T:
maybe_data = self._resolve_data(type_)
if maybe_data is _UNSET:
raise LookupError(f"data of type {type_} can't be found.")

assert isinstance(maybe_data, tuple)
is_subclass, data = maybe_data
if is_subclass:
self._lookup_cache[type_] = data

return data