Skip to content

Handle nesting for ConvertDType, ToArray, adapt concatenate dispatch #503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ def concatenate(self, keys: str | Sequence[str], *, into: str, axis: int = -1):
axis : int, optional
Along which axis to concatenate the keys. The last axis is used by default.
"""
if isinstance(keys, Sequence) and len(keys) == 1:
# unpack string if only one key is supplied, so that Rename is used below
keys = keys[0]
if isinstance(keys, str):
transform = Rename(keys, to_key=into)
else:
Expand Down
5 changes: 3 additions & 2 deletions bayesflow/adapters/transforms/convert_dtype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from keras.tree import map_structure

from bayesflow.utils.serialization import serializable, serialize

Expand Down Expand Up @@ -32,7 +33,7 @@ def get_config(self) -> dict:
return serialize(config)

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
    """Cast ``data`` to ``self.to_dtype``.

    ``data`` may be a single array or a nested structure (dict/list/tuple)
    of arrays; ``map_structure`` applies the cast to every leaf.
    ``copy=False`` avoids copying arrays that already have the target dtype.
    """
    return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
    """Cast ``data`` back to ``self.from_dtype``.

    Mirror of ``forward``: handles a single array or a nested structure of
    arrays via ``map_structure``; ``copy=False`` skips redundant copies.
    """
    return map_structure(lambda d: d.astype(self.from_dtype, copy=False), data)
28 changes: 26 additions & 2 deletions bayesflow/adapters/transforms/to_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from bayesflow.utils.tree import map_dict, get_value_at_path, map_dict_with_path
from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform
Expand Down Expand Up @@ -35,13 +36,36 @@

def forward(self, data: any, **kwargs) -> np.ndarray:
    """Convert ``data`` (array-like or a nested dict of array-likes) to numpy arrays.

    On the first call, the original type — or, for dict input, a dict of
    per-leaf types — is recorded so that ``inverse`` can restore it later.
    """
    if self.original_type is None:
        # Remember exactly what we were given so inverse() can round-trip it.
        self.original_type = map_dict(type, data) if isinstance(data, dict) else type(data)

    # Branch on self.original_type rather than on `data` itself so the
    # decision survives (de)serialization of the transform.
    if isinstance(self.original_type, dict):
        return map_dict(np.asarray, data)
    return np.asarray(data)

def inverse(self, data: np.ndarray, **kwargs) -> any:
def inverse(self, data: np.ndarray | dict, **kwargs) -> any:
if self.original_type is None:
raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.")
if isinstance(self.original_type, dict):
# use self.original_type in check to preserve serializablitiy

def restore_original_type(path, value):
try:
original_type = get_value_at_path(self.original_type, path)
return original_type(value)
except KeyError:
pass

Check warning on line 60 in bayesflow/adapters/transforms/to_array.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/to_array.py#L60

Added line #L60 was not covered by tests
except TypeError:
pass
except ValueError:

Check warning on line 63 in bayesflow/adapters/transforms/to_array.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/to_array.py#L63

Added line #L63 was not covered by tests
# separate statements, as optree does not allow (KeyError | TypeError | ValueError)
pass

Check warning on line 65 in bayesflow/adapters/transforms/to_array.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/to_array.py#L65

Added line #L65 was not covered by tests
return value

return map_dict_with_path(restore_original_type, data)

if issubclass(self.original_type, Number):
try:
Expand Down
3 changes: 2 additions & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
logging,
numpy_utils,
serialization,
tree,
)

from .callbacks import detailed_loss_callback
Expand Down Expand Up @@ -104,4 +105,4 @@

from ._docs import _add_imports_to_all

_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"])
_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization", "tree"])
69 changes: 69 additions & 0 deletions bayesflow/utils/tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import optree


def flatten_shape(structure):
    """Flatten a nested structure into a flat list of shape tuples.

    A leaf is any list or tuple whose entries are all ints or ``None`` —
    i.e. anything that looks like a tensor shape. ``None`` nodes are also
    treated as leaves (``none_is_leaf=True``).
    """

    def _looks_like_shape(node):
        if not isinstance(node, (list, tuple)):
            return False
        return all(isinstance(entry, (int, type(None))) for entry in node)

    leaves, _ = optree.tree_flatten(
        structure,
        is_leaf=_looks_like_shape,
        none_is_leaf=True,
        namespace="keras",
    )
    return leaves

Check warning on line 14 in bayesflow/utils/tree.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/tree.py#L14

Added line #L14 was not covered by tests


def map_dict(func, *structures):
    """Apply ``func`` to the leaves of one or more nested dicts.

    Every value that is not itself a dict counts as a leaf. When several
    structures are supplied they must share the same nesting; otherwise a
    ``ValueError`` is raised (optree alone would silently map over the
    shallowest structure).
    """
    if not structures:
        raise ValueError("Must provide at least one structure")

    def _is_leaf(node):
        return not isinstance(node, dict)

    def _checked(*leaves):
        # If any argument is still a dict here, the structures did not
        # line up — optree handed us a subtree instead of a leaf.
        for leaf in leaves:
            if not optree.tree_is_leaf(leaf, is_leaf=_is_leaf, none_is_leaf=True, namespace="keras"):
                raise ValueError("Structures don't have the same nested structure.")
        return func(*leaves)

    mapper = func if len(structures) == 1 else _checked

    return optree.tree_map(
        mapper,
        *structures,
        is_leaf=_is_leaf,
        none_is_leaf=True,
        namespace="keras",
    )


def map_dict_with_path(func, *structures):
    """Like ``map_dict``, but ``func`` additionally receives the path to each leaf.

    Every value that is not itself a dict counts as a leaf. When several
    structures are supplied they must share the same nesting; otherwise a
    ``ValueError`` is raised (optree alone would silently map over the
    shallowest structure).
    """
    if not structures:
        raise ValueError("Must provide at least one structure")

    def _is_leaf(node):
        return not isinstance(node, dict)

    def _checked(*args):
        # If any argument is still a dict here, the structures did not
        # line up — optree handed us a subtree instead of a leaf.
        for arg in args:
            if not optree.tree_is_leaf(arg, is_leaf=_is_leaf, none_is_leaf=True, namespace="keras"):
                raise ValueError("Structures don't have the same nested structure.")
        return func(*args)

    mapper = func if len(structures) == 1 else _checked

    return optree.tree_map_with_path(
        mapper,
        *structures,
        is_leaf=_is_leaf,
        none_is_leaf=True,
        namespace="keras",
    )


def get_value_at_path(structure, path):
    """Resolve ``path`` (a sequence of keys/indices) against nested ``structure``.

    Indexes one accessor at a time; returns ``structure`` itself when
    ``path`` is empty. Raises whatever the underlying ``[]`` lookup raises
    (e.g. ``KeyError`` or ``IndexError``) for a missing entry.
    """
    node = structure
    for accessor in path:
        node = node[accessor]
    return node
4 changes: 2 additions & 2 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def serializable_fn(x):

return (
Adapter()
.group(["p1", "p2"], into="ps", prefix="p")
.to_array()
.ungroup("ps", prefix="p")
.as_set(["s1", "s2"])
.broadcast("t1", to="t2")
.as_time_series(["t1", "t2"])
Expand All @@ -37,8 +39,6 @@ def serializable_fn(x):
.rename("o1", "o2")
.random_subsample("s3", sample_size=33, axis=0)
.take("s3", indices=np.arange(0, 32), axis=0)
.group(["p1", "p2"], into="ps", prefix="p")
.ungroup("ps", prefix="p")
)


Expand Down
13 changes: 13 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,16 @@ def test_nnpe(random_data):
# Both should assign noise to high-variance dimension
assert std_dim[1] > 0
assert std_glob[1] > 0


def test_single_concatenate_to_rename():
    # A concatenate over a single key should collapse into a Rename transform;
    # genuine multi-key concatenation must still produce a Concatenate.
    from bayesflow import Adapter
    from bayesflow.adapters.transforms import Rename, Concatenate

    assert isinstance(Adapter().concatenate("a", into="b")[0], Rename)
    assert isinstance(Adapter().concatenate(["a"], into="b")[0], Rename)
    assert isinstance(Adapter().concatenate(["a", "b"], into="c")[0], Concatenate)
9 changes: 1 addition & 8 deletions tests/test_workflows/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,6 @@ def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Ten

x = mean[:, None] + noise

return dict(mean=mean, a=x, b=x)
return dict(mean=mean, observables=dict(a=x, b=x))

return FusionSimulator()


@pytest.fixture
def fusion_adapter():
from bayesflow import Adapter

return Adapter.create_default(["mean"]).group(["a", "b"], "summary_variables")
7 changes: 3 additions & 4 deletions tests/test_workflows/test_basic_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@ def test_basic_workflow(tmp_path, inference_network, summary_network):
assert samples["parameters"].shape == (5, 3, 2)


def test_basic_workflow_fusion(
tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator, fusion_adapter
):
def test_basic_workflow_fusion(tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator):
workflow = bf.BasicWorkflow(
adapter=fusion_adapter,
inference_network=fusion_inference_network,
summary_network=fusion_summary_network,
simulator=fusion_simulator,
inference_variables=["mean"],
summary_variables=["observables"],
checkpoint_filepath=str(tmp_path),
)

Expand Down