Draft
Changes from all commits · 67 commits
82b3ab4
allow tensor in DiagonalNormal dimension
arrjon Sep 6, 2025
8fbf737
fix sum dims
arrjon Sep 7, 2025
5c27246
fix batch_shape for sample
arrjon Sep 7, 2025
c684bca
dims to tuple
arrjon Sep 7, 2025
0697634
first draft compositional
arrjon Sep 8, 2025
b8e849e
first draft compositional
arrjon Sep 8, 2025
a280af3
first draft compositional
arrjon Sep 8, 2025
b9faf31
first draft compositional
arrjon Sep 8, 2025
9b7eb16
fix shapes
arrjon Sep 8, 2025
e79aac1
fix shapes
arrjon Sep 8, 2025
8a80240
fix shapes
arrjon Sep 8, 2025
00fbc61
fix shapes
arrjon Sep 8, 2025
e6158e7
fix shapes
arrjon Sep 8, 2025
1ac39b2
fix shapes
arrjon Sep 8, 2025
9fd9cf8
add minibatch
arrjon Sep 8, 2025
830e929
add compositional_bridge
arrjon Sep 8, 2025
f97594b
fix mini batch randomness
arrjon Sep 8, 2025
7219a71
fix mini batch randomness
arrjon Sep 8, 2025
a10026a
fix mini batch randomness
arrjon Sep 8, 2025
457eb5d
add prior score
arrjon Sep 8, 2025
7de4736
add prior score
arrjon Sep 8, 2025
1ee0e78
add prior score draft
arrjon Sep 8, 2025
f71359b
add prior score draft
arrjon Sep 8, 2025
6210c07
add prior score draft
arrjon Sep 8, 2025
bcb9f60
add prior score draft
arrjon Sep 8, 2025
455f03c
fix dtype
arrjon Sep 8, 2025
89523a9
fix docstring
arrjon Sep 9, 2025
e55631d
fix batch_shape in sample
arrjon Sep 9, 2025
3eaff24
fix batch_shape for point approximator
arrjon Sep 9, 2025
5601d20
Merge branch 'normal_distribution_dimension' into compositional_sampl…
arrjon Sep 9, 2025
6b9671b
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 10, 2025
e97e375
fix docstring
arrjon Sep 10, 2025
caa2d67
fix float32
arrjon Sep 10, 2025
1ac9bff
reorganize
arrjon Sep 12, 2025
df23f89
add annealed_langevin
arrjon Sep 12, 2025
0a87694
fix annealed_langevin
arrjon Sep 12, 2025
64d4373
add predictor corrector sampling
arrjon Sep 12, 2025
5b42368
add predictor corrector sampling
arrjon Sep 12, 2025
9402941
add predictor corrector sampling
arrjon Sep 12, 2025
e0b3bd5
add predictor corrector sampling
arrjon Sep 12, 2025
89361f7
add predictor corrector sampling
arrjon Sep 12, 2025
5969bd3
robust mean scores
arrjon Sep 12, 2025
e983cf7
add some tests
arrjon Sep 12, 2025
eac9aaf
minor fixes
arrjon Sep 12, 2025
2a9b0e1
minor fixes
arrjon Sep 12, 2025
9a1ba32
add test for compute_prior_score_pre
arrjon Sep 12, 2025
93b59ba
fix order of prior scores
arrjon Sep 12, 2025
922040d
fix prior scores standardize
arrjon Sep 13, 2025
b2991d1
better standard values for compositional
arrjon Sep 13, 2025
d2a36a8
better compositional_bridge
arrjon Sep 13, 2025
0ff960f
fix integrate_kwargs
arrjon Sep 13, 2025
b2ef755
fix integrate_kwargs
arrjon Sep 13, 2025
ca7f3bd
fix kwargs in sample
arrjon Sep 16, 2025
09df093
Merge branch 'dev' into fix_sampling_method_kwargs
arrjon Sep 16, 2025
2c161c6
fix kwargs in set transformer
arrjon Sep 16, 2025
9d4c1a1
fix kwargs in set transformer
arrjon Sep 16, 2025
ea0659d
remove print
arrjon Sep 16, 2025
922412f
Merge branch 'fix_sampling_method_kwargs' into compositional_sampling…
arrjon Sep 22, 2025
9220816
new class for compositional diffusion
arrjon Sep 22, 2025
ee1c320
fix import
arrjon Sep 22, 2025
c977959
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 23, 2025
9fee1d4
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 23, 2025
d3f639d
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 25, 2025
7d15b49
Merge branch 'dev' into compositional_sampling_diffusion
arrjon Sep 25, 2025
e6513c1
add import
arrjon Sep 26, 2025
e87f9d1
fix mini_batch_size
arrjon Sep 26, 2025
983cb8d
fix mini_batch_size
arrjon Sep 26, 2025
182 changes: 182 additions & 0 deletions bayesflow/approximators/continuous_approximator.py
@@ -641,3 +641,185 @@ def _batch_size_from_data(self, data: Mapping[str, any]) -> int:
inference variables as present.
"""
return keras.ops.shape(data["inference_variables"])[0]

def compositional_sample(
self,
*,
num_samples: int,
conditions: Mapping[str, np.ndarray],
compute_prior_score: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]],
split: bool = False,
**kwargs,
) -> dict[str, np.ndarray]:
"""
Generates compositional samples from the approximator given input conditions.
Each array in `conditions` should have shape (n_datasets, n_compositional_conditions, ...);
the extra compositional dimension is flattened for adaptation and restored before sampling.

Parameters
----------
num_samples : int
Number of samples to generate.
conditions : dict[str, np.ndarray]
Dictionary of conditioning variables as NumPy arrays with shape
(n_datasets, n_compositional_conditions, ...).
compute_prior_score : Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]]
A function that computes the score of the log prior distribution, returning one array per
inference variable, keyed like the adapted samples.
split : bool, default=False
Whether to split the output arrays along the last axis and return one column vector per target variable.
**kwargs : dict
Additional keyword arguments for the adapter and sampling process.

Returns
-------
dict[str, np.ndarray]
Dictionary of generated samples, one array of shape (n_datasets, num_samples, ...) per variable.
"""
original_shapes = {}
flattened_conditions = {}
for key, value in conditions.items(): # Flatten compositional dimensions
original_shapes[key] = value.shape
n_datasets, n_comp = value.shape[:2]
flattened_shape = (n_datasets * n_comp,) + value.shape[2:]
flattened_conditions[key] = value.reshape(flattened_shape)
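# All condition arrays are assumed to share the same leading (n_datasets, n_comp) dimensions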
n_datasets, n_comp = original_shapes[next(iter(original_shapes))][:2]

# Prepare data using existing method (handles adaptation and standardization)
prepared_conditions = self._prepare_data(flattened_conditions, **kwargs)

# Remove any superfluous keys, just retain actual conditions
prepared_conditions = {k: v for k, v in prepared_conditions.items() if k in self.CONDITION_KEYS}

# Prepare prior scores to handle adapter
def compute_prior_score_pre(_samples: Tensor) -> Tensor:
if "inference_variables" in self.standardize:
_samples = self.standardize_layers["inference_variables"](_samples, forward=False)
_samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples})
adapted_samples, log_det_jac = self.adapter(
_samples, inverse=True, strict=False, log_det_jac=True, **kwargs
)

problematic_keys = [key for key, ldj in log_det_jac.items() if np.any(ldj != 0.0)]
if problematic_keys:
    raise NotImplementedError(
        "Cannot use compositional sampling with adapters that have a non-zero "
        f"log_det_jac. Problematic keys: {problematic_keys}"
    )

prior_score = compute_prior_score(adapted_samples)
for key in adapted_samples:
prior_score[key] = prior_score[key].astype(np.float32)

prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score)
out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1)

if "inference_variables" in self.standardize:
# Apply jacobian correction from standardization
# For standardization T^{-1}(z) = z * std + mean, the jacobian is diagonal with std on diagonal
# The gradient of log|det(J)| w.r.t. z is 0 since log|det(J)| = sum(log(std)) is constant w.r.t. z
# But we need to transform the score: score_z = score_x * std where x = T^{-1}(z)
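# (chain rule: score_z(z) = d/dz log p(x(z)) = score_x(x) * dx/dz, and dx/dz = std)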
standardize_layer = self.standardize_layers["inference_variables"]

# Compute the correct standard deviation for all components
std_components = []
for idx in range(len(standardize_layer.moving_mean)):
std_val = standardize_layer.moving_std(idx)
std_components.append(std_val)

# Concatenate std components to match the shape of out
if len(std_components) == 1:
std = std_components[0]
else:
std = keras.ops.concatenate(std_components, axis=-1)

# Expand std to match batch dimension of out
std_expanded = keras.ops.expand_dims(std, (0, 1)) # Add batch, sample dimensions
std_expanded = keras.ops.tile(std_expanded, [n_datasets, num_samples, 1])

# Apply the jacobian: score_z = score_x * std
out = out * std_expanded
return out

# Test prior score function, useful for debugging
test = self.inference_network.base_distribution.sample((n_datasets, num_samples))
test = compute_prior_score_pre(test)
if test.shape[:2] != (n_datasets, num_samples):
raise ValueError(
"The provided compute_prior_score function does not return the correct shape. "
f"Expected ({n_datasets}, {num_samples}, ...), got {test.shape}."
)

# Sample using compositional sampling
samples = self._compositional_sample(
num_samples=num_samples,
n_datasets=n_datasets,
n_compositional=n_comp,
compute_prior_score=compute_prior_score_pre,
**prepared_conditions,
**kwargs,
)

if "inference_variables" in self.standardize:
samples = self.standardize_layers["inference_variables"](samples, forward=False)

samples = {"inference_variables": samples}
samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)

# Back-transform quantities and samples
samples = self.adapter(samples, inverse=True, strict=False, **kwargs)

if split:
samples = split_arrays(samples, axis=-1)
return samples

def _compositional_sample(
self,
num_samples: int,
n_datasets: int,
n_compositional: int,
compute_prior_score: Callable[[Tensor], Tensor],
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
**kwargs,
) -> Tensor:
"""
Internal method for compositional sampling.
"""
if self.summary_network is None:
if summary_variables is not None:
raise ValueError("Cannot use summary variables without a summary network.")
else:
if summary_variables is None:
raise ValueError("Summary variables are required when a summary network is present.")

if self.summary_network is not None:
summary_outputs = self.summary_network(
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
)
inference_conditions = concatenate_valid([inference_conditions, summary_outputs], axis=-1)

if inference_conditions is not None:
# Reshape conditions for compositional sampling
# From (n_datasets * n_comp, ...., dims) to (n_datasets, n_comp, ...., dims)
condition_dims = keras.ops.shape(inference_conditions)[1:]
inference_conditions = keras.ops.reshape(
inference_conditions, (n_datasets, n_compositional, *condition_dims)
)

# Expand for num_samples: (n_datasets, n_comp, dims) -> (n_datasets, n_comp, num_samples, dims)
inference_conditions = keras.ops.expand_dims(inference_conditions, axis=2)
inference_conditions = keras.ops.broadcast_to(
inference_conditions, (n_datasets, n_compositional, num_samples, *condition_dims)
)

batch_shape = (n_datasets, num_samples)
else:
raise ValueError("Cannot perform compositional sampling without inference conditions.")

return self.inference_network.sample(
batch_shape,
conditions=inference_conditions,
compute_prior_score=compute_prior_score,
**filter_kwargs(kwargs, self.inference_network.sample),
)
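For context, a minimal usage sketch of the new API (not part of the diff). It assumes a trained `ContinuousApproximator` called `approximator` whose adapter exposes a single inference variable `"theta"`; the variable name, the data shapes, and the standard-normal prior are illustrative assumptions, not fixed by this PR.

import numpy as np

def compute_prior_score(samples: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    # Illustrative standard-normal prior: grad_theta log N(theta; 0, I) = -theta
    return {key: -value for key, value in samples.items()}

# Conditions carry an extra compositional axis: (n_datasets, n_comp, ...),
# here 8 datasets, each with 16 compositional conditions of 10 observations.
conditions = {"observables": np.random.randn(8, 16, 10, 2).astype(np.float32)}

samples = approximator.compositional_sample(
    num_samples=500,
    conditions=conditions,
    compute_prior_score=compute_prior_score,
)
print(samples["theta"].shape)  # (8, 500, ...), one posterior per dataset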
2 changes: 1 addition & 1 deletion bayesflow/networks/__init__.py
@@ -7,7 +7,7 @@
from .consistency_models import ConsistencyModel
from .coupling_flow import CouplingFlow
from .deep_set import DeepSet
from .diffusion_model import DiffusionModel
from .diffusion_model import DiffusionModel, CompositionalDiffusionModel
from .flow_matching import FlowMatching
from .inference_network import InferenceNetwork
from .point_inference_network import PointInferenceNetwork
1 change: 1 addition & 0 deletions bayesflow/networks/diffusion_model/__init__.py
@@ -1,4 +1,5 @@
from .diffusion_model import DiffusionModel
from .compositional_diffusion_model import CompositionalDiffusionModel
from .schedules import CosineNoiseSchedule
from .schedules import EDMNoiseSchedule
from .schedules import NoiseSchedule
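With both `__init__.py` changes, the new class resolves from either level; a quick import check, assuming this branch is installed:

from bayesflow.networks import DiffusionModel, CompositionalDiffusionModel
from bayesflow.networks.diffusion_model import CompositionalDiffusionModel  # same class, subpackage path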