Skip to content

Commit d3cd641

Browse files
authored
Fix #126, deserialization with postponed type ants (#127)
* Fix #126, deserialization with postponed type ants When `from __future__ import annotations` feature is used, the from_dict method doesn't work correctly, since it can't find the right decoding function. This fixes this bug, by checking if any known types match the given string, and then using those. Signed-off-by: Fabrice Normandin <[email protected]> * Improve warning text when unknown type annotation Signed-off-by: Fabrice Normandin <[email protected]>
1 parent 3e537e9 commit d3cd641

File tree

2 files changed

+70
-11
lines changed

2 files changed

+70
-11
lines changed

simple_parsing/helpers/serialization/decoding.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
""" Functions for decoding dataclass fields from "raw" values (e.g. from json).
22
"""
3+
import inspect
34
import warnings
45
from collections import OrderedDict
56
from dataclasses import Field, fields
67
from functools import lru_cache, partial
78
from logging import getLogger
89
from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set
910

11+
1012
from simple_parsing.utils import (
1113
get_type_arguments,
1214
is_dataclass_type,
@@ -94,6 +96,35 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
9496
# cache_info = get_decoding_fn.cache_info()
9597
# logger.debug(f"called for type {t}! Cache info: {cache_info}")
9698

99+
if isinstance(t, str):
100+
# Type annotation is a string.
101+
# This can happen when the `from __future__ import annotations` feature is used.
102+
potential_keys: List[Type] = []
103+
for key in _decoding_fns:
104+
if inspect.isclass(key):
105+
if key.__qualname__ == t:
106+
# Qualname is more specific, there can't possibly be another match, so break.
107+
potential_keys.append(key)
108+
break
109+
if key.__qualname__ == t:
110+
# For just __name__, there could be more than one match.
111+
potential_keys.append(key)
112+
113+
if not potential_keys:
114+
raise ValueError(
115+
f"Couldn't find a decoding function for the string annotation '{t}'.\n"
116+
f"This is probably a bug. If it is, please make an issue on GitHub so we can get "
117+
f"to work on fixing it."
118+
)
119+
if len(potential_keys) == 1:
120+
t = potential_keys[0]
121+
else:
122+
raise ValueError(
123+
f"Multiple decoding functions registered for a type {t}: {potential_keys} \n"
124+
f"This could be a bug, but try to use different names for each type, or add the "
125+
f"modules they come from as a prefix, perhaps?"
126+
)
127+
97128
if t in _decoding_fns:
98129
# The type has a dedicated decoding function.
99130
return _decoding_fns[t]
@@ -175,8 +206,9 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
175206
# Unknown type.
176207
warnings.warn(
177208
UserWarning(
178-
f"Unable to find a decoding function for type {t}. "
179-
f"Will try to use the type as a constructor."
209+
f"Unable to find a decoding function for the annotation {t} (of type {type(t)}). "
210+
f"Will try to use the type as a constructor. Consider registering a decoding function "
211+
f"using `register_decoding_fn`, or posting an issue on GitHub. "
180212
)
181213
)
182214
return try_constructor(t)

test/test_future_annotations.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,39 @@ class MoreComplex(TestSetup):
194194
vals_tuple: tuple[int | float, bool] = field(default=(1, False))
195195

196196
assert MoreComplex.setup("--vals_list 456 123") == MoreComplex(vals_list=[456, 123])
197-
assert MoreComplex.setup("--vals_list 4.56 1.23") == MoreComplex(
198-
vals_list=[4.56, 1.23]
199-
)
200-
assert MoreComplex.setup("--vals_tuple 456 False") == MoreComplex(
201-
vals_tuple=(456, False)
202-
)
203-
assert MoreComplex.setup("--vals_tuple 4.56 True") == MoreComplex(
204-
vals_tuple=(4.56, True)
205-
)
197+
assert MoreComplex.setup("--vals_list 4.56 1.23") == MoreComplex(vals_list=[4.56, 1.23])
198+
assert MoreComplex.setup("--vals_tuple 456 False") == MoreComplex(vals_tuple=(456, False))
199+
assert MoreComplex.setup("--vals_tuple 4.56 True") == MoreComplex(vals_tuple=(4.56, True))
200+
201+
202+
from dataclasses import dataclass
203+
from simple_parsing.helpers import Serializable
204+
205+
206+
@dataclass
207+
class Opts1(Serializable):
208+
a: int = 64
209+
b: float = 1.0
210+
211+
212+
@dataclass
213+
class Opts2(Serializable):
214+
a: int = 32
215+
b: float = 0.0
216+
217+
218+
@dataclass
219+
class Wrapper(Serializable):
220+
opts1: Opts1 = Opts1()
221+
opts2: Opts2 = Opts2()
222+
223+
224+
def test_serialization_deserialization():
225+
# Show that it's not possible to deserialize nested dataclasses
226+
opts = Wrapper()
227+
assert Wrapper in Serializable.subclasses
228+
assert Opts1 in Serializable.subclasses
229+
assert Opts2 in Serializable.subclasses
230+
assert Wrapper.from_dict(opts.to_dict()) == opts
231+
assert Wrapper.loads_json(opts.dumps_json()) == opts
232+
assert Wrapper.loads_yaml(opts.dumps_yaml()) == opts

0 commit comments

Comments
 (0)