diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index c34bcc908..b01361562 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -641,3 +641,185 @@ def _batch_size_from_data(self, data: Mapping[str, any]) -> int: inference variables as present. """ return keras.ops.shape(data["inference_variables"])[0] + + def compositional_sample( + self, + *, + num_samples: int, + conditions: Mapping[str, np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], + split: bool = False, + **kwargs, + ) -> dict[str, np.ndarray]: + """ + Generates compositional samples from the approximator given input conditions. + The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...). + This method handles the extra compositional dimension appropriately. + + Parameters + ---------- + num_samples : int + Number of samples to generate. + conditions : dict[str, np.ndarray] + Dictionary of conditioning variables as NumPy arrays with shape + (n_datasets, n_compositional_conditions, ...). + compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + A function that computes the score of the log prior distribution. + split : bool, default=False + Whether to split the output arrays along the last axis and return one column vector per target variable + samples. + **kwargs : dict + Additional keyword arguments for the adapter and sampling process. + + Returns + ------- + dict[str, np.ndarray] + Dictionary containing generated samples with compositional structure preserved. + """ + original_shapes = {} + flattened_conditions = {} + for key, value in conditions.items(): # Flatten compositional dimensions + original_shapes[key] = value.shape + n_datasets, n_comp = value.shape[:2] + flattened_shape = (n_datasets * n_comp,) + value.shape[2:] + flattened_conditions[key] = value.reshape(flattened_shape) + n_datasets, n_comp = original_shapes[next(iter(original_shapes))][:2] + + # Prepare data using existing method (handles adaptation and standardization) + prepared_conditions = self._prepare_data(flattened_conditions, **kwargs) + + # Remove any superfluous keys, just retain actual conditions + prepared_conditions = {k: v for k, v in prepared_conditions.items() if k in self.CONDITION_KEYS} + + # Prepare prior scores to handle adapter + def compute_prior_score_pre(_samples: Tensor) -> Tensor: + if "inference_variables" in self.standardize: + _samples = self.standardize_layers["inference_variables"](_samples, forward=False) + _samples = keras.tree.map_structure(keras.ops.convert_to_numpy, {"inference_variables": _samples}) + adapted_samples, log_det_jac = self.adapter( + _samples, inverse=True, strict=False, log_det_jac=True, **kwargs + ) + + if len(log_det_jac) > 0: + problematic_keys = [key for key in log_det_jac if log_det_jac[key] != 0.0] + raise NotImplementedError( + f"Cannot use compositional sampling with adapters " + f"that have non-zero log_det_jac. 
Problematic keys: {problematic_keys}" + ) + + prior_score = compute_prior_score(adapted_samples) + for key in adapted_samples: + prior_score[key] = prior_score[key].astype(np.float32) + + prior_score = keras.tree.map_structure(keras.ops.convert_to_tensor, prior_score) + out = keras.ops.concatenate([prior_score[key] for key in adapted_samples], axis=-1) + + if "inference_variables" in self.standardize: + # Apply jacobian correction from standardization + # For standardization T^{-1}(z) = z * std + mean, the jacobian is diagonal with std on diagonal + # The gradient of log|det(J)| w.r.t. z is 0 since log|det(J)| = sum(log(std)) is constant w.r.t. z + # But we need to transform the score: score_z = score_x * std where x = T^{-1}(z) + standardize_layer = self.standardize_layers["inference_variables"] + + # Compute the correct standard deviation for all components + std_components = [] + for idx in range(len(standardize_layer.moving_mean)): + std_val = standardize_layer.moving_std(idx) + std_components.append(std_val) + + # Concatenate std components to match the shape of out + if len(std_components) == 1: + std = std_components[0] + else: + std = keras.ops.concatenate(std_components, axis=-1) + + # Expand std to match batch dimension of out + std_expanded = keras.ops.expand_dims(std, (0, 1)) # Add batch, sample dimensions + std_expanded = keras.ops.tile(std_expanded, [n_datasets, num_samples, 1]) + + # Apply the jacobian: score_z = score_x * std + out = out * std_expanded + return out + + # Test prior score function, useful for debugging + test = self.inference_network.base_distribution.sample((n_datasets, num_samples)) + test = compute_prior_score_pre(test) + if test.shape[:2] != (n_datasets, num_samples): + raise ValueError( + "The provided compute_prior_score function does not return the correct shape. " + f"Expected ({n_datasets}, {num_samples}, ...), got {test.shape}." + ) + + # Sample using compositional sampling + samples = self._compositional_sample( + num_samples=num_samples, + n_datasets=n_datasets, + n_compositional=n_comp, + compute_prior_score=compute_prior_score_pre, + **prepared_conditions, + **kwargs, + ) + + if "inference_variables" in self.standardize: + samples = self.standardize_layers["inference_variables"](samples, forward=False) + + samples = {"inference_variables": samples} + samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples) + + # Back-transform quantities and samples + samples = self.adapter(samples, inverse=True, strict=False, **kwargs) + + if split: + samples = split_arrays(samples, axis=-1) + return samples + + def _compositional_sample( + self, + num_samples: int, + n_datasets: int, + n_compositional: int, + compute_prior_score: Callable[[Tensor], Tensor], + inference_conditions: Tensor = None, + summary_variables: Tensor = None, + **kwargs, + ) -> Tensor: + """ + Internal method for compositional sampling. 
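+
+        Passes any summary variables through the summary network, reshapes the resulting conditions back
+        to (n_datasets, n_compositional, ...), expands them across num_samples, and delegates to the
+        inference network's sampler together with the pre-processed prior score function.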
+ """ + if self.summary_network is None: + if summary_variables is not None: + raise ValueError("Cannot use summary variables without a summary network.") + else: + if summary_variables is None: + raise ValueError("Summary variables are required when a summary network is present.") + + if self.summary_network is not None: + summary_outputs = self.summary_network( + summary_variables, **filter_kwargs(kwargs, self.summary_network.call) + ) + inference_conditions = concatenate_valid([inference_conditions, summary_outputs], axis=-1) + + if inference_conditions is not None: + # Reshape conditions for compositional sampling + # From (n_datasets * n_comp, ...., dims) to (n_datasets, n_comp, ...., dims) + condition_dims = keras.ops.shape(inference_conditions)[1:] + inference_conditions = keras.ops.reshape( + inference_conditions, (n_datasets, n_compositional, *condition_dims) + ) + + # Expand for num_samples: (n_datasets, n_comp, dims) -> (n_datasets, n_comp, num_samples, dims) + inference_conditions = keras.ops.expand_dims(inference_conditions, axis=2) + inference_conditions = keras.ops.broadcast_to( + inference_conditions, (n_datasets, n_compositional, num_samples, *condition_dims) + ) + + batch_shape = (n_datasets, num_samples) + else: + raise ValueError("Cannot perform compositional sampling without inference conditions.") + + return self.inference_network.sample( + batch_shape, + conditions=inference_conditions, + compute_prior_score=compute_prior_score, + **filter_kwargs(kwargs, self.inference_network.sample), + ) diff --git a/bayesflow/networks/__init__.py b/bayesflow/networks/__init__.py index f71d4b536..fb9819445 100644 --- a/bayesflow/networks/__init__.py +++ b/bayesflow/networks/__init__.py @@ -7,7 +7,7 @@ from .consistency_models import ConsistencyModel from .coupling_flow import CouplingFlow from .deep_set import DeepSet -from .diffusion_model import DiffusionModel +from .diffusion_model import DiffusionModel, CompositionalDiffusionModel from .flow_matching import FlowMatching from .inference_network import InferenceNetwork from .point_inference_network import PointInferenceNetwork diff --git a/bayesflow/networks/diffusion_model/__init__.py b/bayesflow/networks/diffusion_model/__init__.py index 341c84c62..ca8aa19be 100644 --- a/bayesflow/networks/diffusion_model/__init__.py +++ b/bayesflow/networks/diffusion_model/__init__.py @@ -1,4 +1,5 @@ from .diffusion_model import DiffusionModel +from .compositional_diffusion_model import CompositionalDiffusionModel from .schedules import CosineNoiseSchedule from .schedules import EDMNoiseSchedule from .schedules import NoiseSchedule diff --git a/bayesflow/networks/diffusion_model/compositional_diffusion_model.py b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py new file mode 100644 index 000000000..171184314 --- /dev/null +++ b/bayesflow/networks/diffusion_model/compositional_diffusion_model.py @@ -0,0 +1,412 @@ +from typing import Literal, Callable + +import keras +import numpy as np +from keras import ops + +from bayesflow.types import Tensor +from bayesflow.utils import ( + expand_right_as, + integrate, + integrate_stochastic, +) +from bayesflow.utils.serialization import serializable +from .diffusion_model import DiffusionModel +from .schedules.noise_schedule import NoiseSchedule + + +# disable module check, use potential module after moving from experimental +@serializable("bayesflow.networks", disable_module_check=True) +class CompositionalDiffusionModel(DiffusionModel): + """Compositional Diffusion Model for Amortized 
Bayesian Inference. Allows to learn a single + diffusion model one single i.i.d simulations that can perform inference for multiple simulations by leveraging a + compositional score function as in [2]. + + [1] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021) + [2] Compositional Score Modeling for Simulation-Based Inference: Geffner et al. (2023) + [3] Compositional amortized inference for large-scale hierarchical Bayesian models: Arruda et al. (2025) + """ + + MLP_DEFAULT_CONFIG = { + "widths": (256, 256, 256, 256, 256), + "activation": "mish", + "kernel_initializer": "he_normal", + "residual": True, + "dropout": 0.0, + "spectral_normalization": False, + } + + INTEGRATE_DEFAULT_CONFIG = { + "method": "euler_maruyama", + "corrector_steps": 1, + "steps": 100, + } + + def __init__( + self, + *, + subnet: str | type | keras.Layer = "mlp", + noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm", + prediction_type: Literal["velocity", "noise", "F", "x"] = "F", + loss_type: Literal["velocity", "noise", "F"] = "noise", + subnet_kwargs: dict[str, any] = None, + schedule_kwargs: dict[str, any] = None, + integrate_kwargs: dict[str, any] = None, + **kwargs, + ): + """ + Initializes a diffusion model with configurable subnet architecture, noise schedule, + and prediction/loss types for amortized Bayesian inference. + + Note, that score-based diffusion is the most sluggish of all available samplers, + so expect slower inference times than flow matching and much slower than normalizing flows. + + Parameters + ---------- + subnet : str, type or keras.Layer, optional + Architecture for the transformation network. Can be "mlp", a custom network class, or + a Layer object, e.g., `bayesflow.networks.MLP(widths=[32, 32])`. Default is "mlp". + noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional + Noise schedule controlling the diffusion dynamics. Can be a string identifier, + a schedule class, or a pre-initialized schedule instance. Default is "edm". + prediction_type : {'velocity', 'noise', 'F', 'x'}, optional + Output format of the model's prediction. Default is "F". + loss_type : {'velocity', 'noise', 'F'}, optional + Loss function used to train the model. Default is "noise". + subnet_kwargs : dict[str, any], optional + Additional keyword arguments passed to the subnet constructor. Default is None. + schedule_kwargs : dict[str, any], optional + Additional keyword arguments passed to the noise schedule constructor. Default is None. + integrate_kwargs : dict[str, any], optional + Configuration dictionary for integration during training or inference. Default is None. + concatenate_subnet_input: bool, optional + Flag for advanced users to control whether all inputs to the subnet should be concatenated + into a single vector or passed as separate arguments. If set to False, the subnet + must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio), + and optional 'conditions'. Default is True. + + **kwargs + Additional keyword arguments passed to the base class and internal components. + """ + super().__init__( + subnet=subnet, + noise_schedule=noise_schedule, + prediction_type=prediction_type, + loss_type=loss_type, + subnet_kwargs=subnet_kwargs, + schedule_kwargs=schedule_kwargs, + integrate_kwargs=integrate_kwargs, + **kwargs, + ) + + def compositional_bridge(self, time: Tensor) -> Tensor: + """ + Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1. 
+ Otherwise, it can be used to scale the compositional score over time. + + Parameters + ---------- + time: Tensor + Time step for the diffusion process. + + Returns + ------- + Tensor + Bridge function value with same shape as time. + + """ + return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time) + + def compositional_velocity( + self, + xz: Tensor, + time: float | Tensor, + stochastic_solver: bool, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional velocity for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + stochastic_solver : bool + Whether to use stochastic (SDE) or deterministic (ODE) formulation + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. + training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + compositional_score = self.compositional_score( + xz=xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + + # Compute velocity using standard drift-diffusion formulation + f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) + + if stochastic_solver: + # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW + velocity = f - g_squared * compositional_score + else: + # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt + velocity = f - 0.5 * g_squared * compositional_score + + return velocity + + def compositional_score( + self, + xz: Tensor, + time: float | Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + mini_batch_size: int | None = None, + training: bool = False, + ) -> Tensor: + """ + Computes the compositional score for multiple datasets using the formula: + s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + + Parameters + ---------- + xz : Tensor + The current state of the latent variable, shape (n_datasets, n_compositional, ...) + time : float or Tensor + Time step for the diffusion process + conditions : Tensor + Conditional inputs with compositional structure (n_datasets, n_compositional, ...) + compute_prior_score: Callable + Function to compute the prior score ∇_θ log p(θ). + mini_batch_size : int or None + Mini batch size for computing individual scores. If None, use all conditions. 
+ training : bool, optional + Whether in training mode + + Returns + ------- + Tensor + Compositional velocity of same shape as input xz + """ + if conditions is None: + raise ValueError("Conditions are required for compositional sampling") + + # Get shapes for compositional structure + n_compositional = ops.shape(conditions)[1] + + # Calculate standard noise schedule components + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + + # Compute individual dataset scores + if mini_batch_size is not None and mini_batch_size < n_compositional: + # sample random indices for mini-batch processing + mini_batch_idx = keras.random.shuffle(ops.arange(n_compositional), seed=self.seed_generator) + mini_batch_idx = mini_batch_idx[:mini_batch_size] + conditions_batch = conditions[:, mini_batch_idx] + else: + conditions_batch = conditions + individual_scores = self._compute_individual_scores(xz, log_snr_t, conditions_batch, training) + + # Compute prior score component + prior_score = compute_prior_score(xz) + weighted_prior_score = (1.0 - n_compositional) * (1.0 - time) * prior_score + + # Sum individual scores across compositional dimensions + summed_individual_scores = n_compositional * ops.mean(individual_scores, axis=1) + + # Combined score using compositional formula: (1-n)(1-t)∇log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ) + time_tensor = ops.cast(time, dtype=ops.dtype(xz)) + compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores) + return compositional_score + + def _compute_individual_scores( + self, + xz: Tensor, + log_snr_t: Tensor, + conditions: Tensor, + training: bool, + ) -> Tensor: + """ + Compute individual dataset scores s_ψ(θ,t,yᵢ) for each compositional condition. + + Returns + ------- + Tensor + Individual scores with shape (n_datasets, n_compositional, ...) 
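+
+        Notes
+        -----
+        The compositional axis is flattened into the batch axis, the standard score method is evaluated
+        once on the flattened inputs, and the result is reshaped back to
+        (n_datasets, n_compositional, num_samples, ...).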
+ """ + # Get shapes + xz_shape = ops.shape(xz) # (n_datasets, num_samples, ..., dims) + conditions_shape = ops.shape(conditions) # (n_datasets, n_compositional, num_samples, ..., dims) + n_datasets, n_compositional = conditions_shape[0], conditions_shape[1] + conditions_dims = tuple(conditions_shape[3:]) + num_samples = xz_shape[1] + dims = tuple(xz_shape[2:]) + + # Expand xz to match compositional structure + xz_expanded = ops.expand_dims(xz, axis=1) # (n_datasets, 1, num_samples, ..., dims) + xz_expanded = ops.broadcast_to(xz_expanded, (n_datasets, n_compositional, num_samples) + dims) + + # Expand log_snr_t to match compositional structure + log_snr_expanded = ops.expand_dims(log_snr_t, axis=1) + log_snr_expanded = ops.broadcast_to(log_snr_expanded, (n_datasets, n_compositional, num_samples, 1)) + + # Flatten for score computation: (n_datasets * n_compositional, num_samples, ..., dims) + xz_flat = ops.reshape(xz_expanded, (n_datasets * n_compositional, num_samples) + dims) + log_snr_flat = ops.reshape(log_snr_expanded, (n_datasets * n_compositional, num_samples, 1)) + conditions_flat = ops.reshape(conditions, (n_datasets * n_compositional, num_samples) + conditions_dims) + + # Use standard score function + scores_flat = self.score(xz_flat, log_snr_t=log_snr_flat, conditions=conditions_flat, training=training) + + # Reshape back to compositional structure + scores = ops.reshape(scores_flat, (n_datasets, n_compositional, num_samples) + dims) + return scores + + def _inverse_compositional( + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + """ + Inverse pass for compositional diffusion sampling. + """ + n_compositional = ops.shape(conditions)[1] + integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} + integrate_kwargs = integrate_kwargs | self.integrate_kwargs + integrate_kwargs = integrate_kwargs | kwargs + if keras.backend.backend() == "jax": + mini_batch_size = integrate_kwargs.pop("mini_batch_size", None) + if mini_batch_size is not None: + raise ValueError( + "Mini batching is not supported with JAX backend. Set mini_batch_size to None " + "or use another backend." 
+ ) + else: + mini_batch_size = max(integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1)), 1) + self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0)) + self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0)) + + # x is sampled from a normal distribution, must be scaled with var 1/n_compositional + scale_latent = n_compositional * self.compositional_bridge(ops.ones(1)) + z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z))) + + if density: + if integrate_kwargs["method"] == "euler_maruyama": + raise ValueError("Stochastic methods are not supported for density computation.") + + def deltas(time, xz): + v = self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + trace = ops.zeros(ops.shape(xz)[:-1] + (1,), dtype=ops.dtype(xz)) + return {"xz": v, "trace": trace} + + state = { + "xz": z, + "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)), + } + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + log_density = self.base_distribution.log_prob(ops.mean(z, axis=1)) - ops.squeeze(state["trace"], axis=-1) + return x, log_density + + state = {"xz": z} + + if integrate_kwargs["method"] == "euler_maruyama": + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=True, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + def diffusion(time, xz): + return {"xz": self.diffusion_term(xz, time=time, training=training)} + + score_fn = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def score_fn(time, xz): + return { + "xz": self.compositional_score( + xz, + time=time, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate_stochastic( + drift_fn=deltas, + diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, + state=state, + seed=self.seed_generator, + **integrate_kwargs, + ) + else: + integrate_kwargs.pop("corrector_steps", None) + + def deltas(time, xz): + return { + "xz": self.compositional_velocity( + xz, + time=time, + stochastic_solver=False, + conditions=conditions, + compute_prior_score=compute_prior_score, + mini_batch_size=mini_batch_size, + training=training, + ) + } + + state = integrate(deltas, state, **integrate_kwargs) + + x = state["xz"] + return x diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py index ca8a634e9..9955c4abc 100644 --- a/bayesflow/networks/diffusion_model/diffusion_model.py +++ b/bayesflow/networks/diffusion_model/diffusion_model.py @@ -243,6 +243,55 @@ def _apply_subnet( else: return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training) + def score( + self, + xz: Tensor, + time: float | Tensor = None, + log_snr_t: Tensor = None, + conditions: Tensor = None, + training: bool = False, + ) -> Tensor: + """ + Computes the score of the target or latent variable `xz`. + + Parameters + ---------- + xz : Tensor + The current state of the latent variable `z`, typically of shape (..., D), + where D is the dimensionality of the latent space. 
+ time : float or Tensor + Scalar or tensor representing the time (or noise level) at which the velocity + should be computed. Will be broadcasted to xz. If None, log_snr_t must be provided. + log_snr_t : Tensor + The log signal-to-noise ratio at time `t`. If None, time must be provided. + conditions : Tensor, optional + Conditional inputs to the network, such as conditioning variables + or encoder outputs. Shape must be broadcastable with `xz`. Default is None. + training : bool, optional + Whether the model is in training mode. Affects behavior of dropout, batch norm, + or other stochastic layers. Default is False. + + Returns + ------- + Tensor + The velocity tensor of the same shape as `xz`, representing the right-hand + side of the SDE or ODE at the given `time`. + """ + if log_snr_t is None: + log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) + log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) + alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + subnet_out = self._apply_subnet( + xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training + ) + pred = self.output_projector(subnet_out, training=training) + + x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) + + score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + return score + def velocity( self, xz: Tensor, @@ -279,19 +328,10 @@ def velocity( The velocity tensor of the same shape as `xz`, representing the right-hand side of the SDE or ODE at the given `time`. """ - # calculate the current noise level and transform into correct shape log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz) log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,)) - alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) - - subnet_out = self._apply_subnet( - xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training - ) - pred = self.output_projector(subnet_out, training=training) - - x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t) - score = (alpha_t * x_pred - xz) / ops.square(sigma_t) + score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training) # compute velocity f, g of the SDE or ODE f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training) @@ -362,6 +402,7 @@ def _forward( conditions: Tensor = None, density: bool = False, training: bool = False, + compositional: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0} @@ -412,6 +453,7 @@ def _inverse( conditions: Tensor = None, density: bool = False, training: bool = False, + compositional: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0} @@ -447,9 +489,25 @@ def deltas(time, xz): def diffusion(time, xz): return {"xz": self.diffusion_term(xz, time=time, training=training)} + score_fn = None + if "corrector_steps" in integrate_kwargs: + if integrate_kwargs["corrector_steps"] > 0: + + def score_fn(time, xz): + return { + "xz": self.score( + xz, + time=time, + conditions=conditions, + training=training, + ) + } + state = integrate_stochastic( drift_fn=deltas, diffusion_fn=diffusion, + score_fn=score_fn, + noise_schedule=self.noise_schedule, state=state, seed=self.seed_generator, 
**integrate_kwargs, diff --git a/bayesflow/networks/inference_network.py b/bayesflow/networks/inference_network.py index b092ce2cb..9488f644d 100644 --- a/bayesflow/networks/inference_network.py +++ b/bayesflow/networks/inference_network.py @@ -1,3 +1,4 @@ +from typing import Callable import keras from bayesflow.types import Shape, Tensor @@ -27,11 +28,30 @@ def call( conditions: Tensor = None, inverse: bool = False, density: bool = False, + compute_prior_score: Callable[[Tensor], Tensor] = None, training: bool = False, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: if inverse: + if compute_prior_score is not None: + return self._inverse_compositional( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) return self._inverse(xz, conditions=conditions, density=density, training=training, **kwargs) + if compute_prior_score is not None: + return self._forward_compositional( + xz, + conditions=conditions, + compute_prior_score=compute_prior_score, + density=density, + training=training, + **kwargs, + ) return self._forward(xz, conditions=conditions, density=density, training=training, **kwargs) def _forward( @@ -44,6 +64,28 @@ def _inverse( ) -> Tensor | tuple[Tensor, Tensor]: raise NotImplementedError + def _forward_compositional( + self, + x: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + + def _inverse_compositional( + self, + z: Tensor, + conditions: Tensor, + compute_prior_score: Callable[[Tensor], Tensor], + density: bool = False, + training: bool = False, + **kwargs, + ) -> Tensor | tuple[Tensor, Tensor]: + raise NotImplementedError + @allow_batch_size def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor: samples = self.base_distribution.sample(batch_shape) diff --git a/bayesflow/simulators/sequential_simulator.py b/bayesflow/simulators/sequential_simulator.py index 21e1542e6..96ab0ead3 100644 --- a/bayesflow/simulators/sequential_simulator.py +++ b/bayesflow/simulators/sequential_simulator.py @@ -88,7 +88,7 @@ def _single_sample(self, batch_shape_ext, **kwargs) -> dict[str, np.ndarray]: return self.sample(batch_shape=(1, *tuple(batch_shape_ext)), **kwargs) def sample_parallel( - self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 0, **kwargs + self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs ) -> dict[str, np.ndarray]: """ Sample in parallel from the sequential simulator. @@ -101,7 +101,7 @@ def sample_parallel( n_jobs : int, optional Number of parallel jobs. -1 uses all available cores. Default is -1. verbose : int, optional - Verbosity level for joblib. Default is 0 (no output). + Verbosity level for joblib. Default is 1 (minimal output). **kwargs Additional keyword arguments passed to each simulator. These may include previously sampled outputs used as inputs for subsequent simulators. 
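# Usage sketch (not part of the patch): one way the new compositional sampling API introduced above could
# be called end-to-end. It assumes a trained workflow or approximator exposing `compositional_sample`, an
# adapter that maps a parameter named "loc" to "inference_variables", and a standard-normal prior over that
# parameter; the names `workflow`, `prior_score`, and `conditions` are illustrative only.
import numpy as np

def prior_score(samples: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    # Score of a standard-normal prior in the original (unstandardized) parameter space:
    # d/d theta log N(theta | 0, 1) = -theta, returned per parameter key.
    return {key: -value for key, value in samples.items()}

# Conditions carry an extra compositional axis: (n_datasets, n_compositional_conditions, ...)
conditions = {"conditions": np.random.normal(size=(3, 5, 3)).astype("float32")}

samples = workflow.compositional_sample(
    num_samples=100,
    conditions=conditions,
    compute_prior_score=prior_score,
)
# samples["loc"] then has shape (n_datasets, num_samples, parameter_dim), matching the shape checked in the new test.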
diff --git a/bayesflow/simulators/simulator.py b/bayesflow/simulators/simulator.py
index 00d3d84f3..53d54e455 100644
--- a/bayesflow/simulators/simulator.py
+++ b/bayesflow/simulators/simulator.py
@@ -95,3 +95,8 @@ def accept_all_predicate(x):
             return np.full((sample_size,), True)
 
         return self.rejection_sample(batch_shape, predicate=accept_all_predicate, sample_size=sample_size, **kwargs)
+
+    def sample_parallel(
+        self, batch_shape: Shape, n_jobs: int = -1, verbose: int = 1, **kwargs
+    ) -> dict[str, np.ndarray]:
+        raise NotImplementedError
diff --git a/bayesflow/utils/integrate.py b/bayesflow/utils/integrate.py
index b197ea975..961015b8f 100644
--- a/bayesflow/utils/integrate.py
+++ b/bayesflow/utils/integrate.py
@@ -401,11 +401,19 @@ def integrate_stochastic(
     steps: int,
     seed: keras.random.SeedGenerator,
     method: str = "euler_maruyama",
+    score_fn: Callable = None,
+    corrector_steps: int = 0,
+    noise_schedule=None,
+    step_size_factor: float = 0.1,
     **kwargs,
 ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
     """
     Integrates a stochastic differential equation from start_time to stop_time.
 
+    When score_fn is provided, performs predictor-corrector sampling where:
+    - Predictor: reverse diffusion SDE solver
+    - Corrector: annealed Langevin dynamics with step size e = 2 * alpha_t * (step_size_factor * ||z|| / ||score||)**2
+
     Args:
         drift_fn: Function that computes the drift term.
         diffusion_fn: Function that computes the diffusion term.
@@ -415,11 +423,15 @@ def integrate_stochastic(
         steps: Number of integration steps.
         seed: Random seed for noise generation.
         method: Integration method to use, e.g., 'euler_maruyama'.
+        score_fn: Optional score function for predictor-corrector sampling.
+            Should take (time, **state) and return a score dict.
+        corrector_steps: Number of corrector steps to take after each predictor step.
+        noise_schedule: Noise schedule object for computing lambda_t and alpha_t in the corrector.
+        step_size_factor: Scaling factor for the corrector step size.
         **kwargs: Additional arguments to pass to the step function.
 
     Returns:
-        If return_noise is False, returns the final state dictionary.
-        If return_noise is True, returns a tuple of (final_state, noise_history).
+        Final state dictionary after integration.
""" if steps <= 0: raise ValueError("Number of steps must be positive.") @@ -438,17 +450,56 @@ def integrate_stochastic( step_size = (stop_time - start_time) / steps sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size)) - # Pre-generate noise history: shape = (steps, *state_shape) + # Pre-generate noise history for predictor: shape = (steps, *state_shape) noise_history = {} for key, val in state.items(): noise_history[key] = ( keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt ) + # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape) + corrector_noise_history = {} + if corrector_steps > 0: + if score_fn is None or noise_schedule is None: + raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") + + for key, val in state.items(): + corrector_noise_history[key] = keras.random.normal( + (steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed + ) + def body(_loop_var, _loop_state): _current_state, _current_time = _loop_state _noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()} + + # Predictor step new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i) + + # Corrector steps: annealed Langevin dynamics if score_fn is provided + if corrector_steps > 0: + for corrector_step in range(corrector_steps): + score = score_fn(new_time, **filter_kwargs(new_state, score_fn)) + _corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()} + + # Compute noise schedule components for corrector step size + log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False) + alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t) + + # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector + # where e = 2*alpha_t * (r * ||z|| / ||score||)**2 + for k in new_state.keys(): + if k in score: + z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True) + score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True) + + # Prevent division by zero + score_norm = keras.ops.maximum(score_norm, 1e-8) + + e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2 + sqrt_2e = keras.ops.sqrt(2.0 * e) + + new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k] + return new_state, new_time final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time)) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 34fa03794..cfa63545b 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -286,6 +286,42 @@ def sample( """ return self.approximator.sample(num_samples=num_samples, conditions=conditions, **kwargs) + def compositional_sample( + self, + *, + num_samples: int, + conditions: Mapping[str, np.ndarray], + compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray], + **kwargs, + ) -> dict[str, np.ndarray]: + """ + Draws `num_samples` samples from the approximator given specified composition conditions. + The `conditions` dictionary should have shape (n_datasets, n_compositional_conditions, ...). + + Parameters + ---------- + num_samples : int + The number of samples to generate. + conditions : dict[str, np.ndarray] + A dictionary where keys represent variable names and values are + NumPy arrays containing the adapted simulated variables. 
Keys used as summary or inference + conditions during training should be present. + Should have shape (n_datasets, n_compositional_conditions, ...). + compute_prior_score : Callable[[Mapping[str, np.ndarray]], np.ndarray] + A function that computes the log probability of samples under the prior distribution. + **kwargs : dict, optional + Additional keyword arguments passed to the approximator's sampling function. + + Returns + ------- + dict[str, np.ndarray] + A dictionary where keys correspond to variable names and + values are arrays containing the generated samples. + """ + return self.approximator.compositional_sample( + num_samples=num_samples, conditions=conditions, compute_prior_score=compute_prior_score, **kwargs + ) + def estimate( self, *, diff --git a/tests/test_approximators/conftest.py b/tests/test_approximators/conftest.py index a56802a3e..befc0da06 100644 --- a/tests/test_approximators/conftest.py +++ b/tests/test_approximators/conftest.py @@ -220,3 +220,71 @@ def approximator_with_summaries(request): ) case _: raise ValueError("Invalid param for approximator class.") + + +@pytest.fixture +def simple_log_simulator(): + """Create a simple simulator for testing.""" + import numpy as np + from bayesflow.simulators import Simulator + from bayesflow.utils.decorators import allow_batch_size + from bayesflow.types import Shape, Tensor + + class SimpleSimulator(Simulator): + """Simple simulator that generates mean and scale parameters.""" + + @allow_batch_size + def sample(self, batch_shape: Shape) -> dict[str, Tensor]: + # Generate parameters in original space + loc = np.random.normal(0.0, 1.0, size=batch_shape + (2,)) # location parameters + scale = np.random.lognormal(0.0, 0.5, size=batch_shape + (2,)) # scale parameters > 0 + + # Generate some dummy conditions + conditions = np.random.normal(0.0, 1.0, size=batch_shape + (3,)) + + return dict( + loc=loc.astype("float32"), scale=scale.astype("float32"), conditions=conditions.astype("float32") + ) + + return SimpleSimulator() + + +@pytest.fixture +def identity_adapter(): + """Create an adapter that applies no transformation to the parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + adapter.concatenate(["loc"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + +@pytest.fixture +def transforming_adapter(): + """Create an adapter that applies log transformation to scale parameters.""" + from bayesflow.adapters import Adapter + + adapter = Adapter() + adapter.to_array() + adapter.convert_dtype("float64", "float32") + + # Apply log transformation to scale parameters (to make them unbounded) + adapter.log(["scale"]) + + adapter.concatenate(["scale", "loc"], into="inference_variables") + adapter.concatenate(["conditions"], into="inference_conditions") + adapter.keep(["inference_variables", "inference_conditions"]) + return adapter + + +@pytest.fixture +def diffusion_network(): + """Create a diffusion network for compositional sampling.""" + from bayesflow.networks import DiffusionModel, MLP + + return DiffusionModel(subnet=MLP(widths=[32, 32])) diff --git a/tests/test_approximators/test_compositional_prior_score.py b/tests/test_approximators/test_compositional_prior_score.py new file mode 100644 index 000000000..02be46c00 --- /dev/null +++ b/tests/test_approximators/test_compositional_prior_score.py @@ -0,0 +1,43 
@@ +"""Tests for compositional sampling and prior score computation with adapters.""" + +import numpy as np + +from bayesflow import ContinuousApproximator + + +def mock_prior_score_original_space(data_dict): + """Mock prior score function that expects data in original space.""" + loc = data_dict["loc"] + + # Simple prior: N(0,1) for loc + loc_score = -loc + return {"loc": loc_score} + + +def test_prior_score_identity_adapter(simple_log_simulator, identity_adapter, diffusion_network): + """Test that prior scores work correctly with transforming adapter (log transformation).""" + + # Create approximator with transforming adapter + approximator = ContinuousApproximator( + adapter=identity_adapter, + inference_network=diffusion_network, + ) + + # Generate test data and adapt it + data = simple_log_simulator.sample((2,)) + adapted_data = identity_adapter(data) + + # Build approximator + approximator.build_from_data(adapted_data) + + # Test compositional sampling + n_datasets, n_compositional = 3, 5 + conditions = {"conditions": np.random.normal(0.0, 1.0, (n_datasets, n_compositional, 3)).astype("float32")} + samples = approximator.compositional_sample( + num_samples=10, + conditions=conditions, + compute_prior_score=mock_prior_score_original_space, + ) + + assert "loc" in samples + assert samples["loc"].shape == (n_datasets, 10, 2) diff --git a/tests/test_networks/test_diffusion_model/conftest.py b/tests/test_networks/test_diffusion_model/conftest.py index b1ee915ae..581b4abde 100644 --- a/tests/test_networks/test_diffusion_model/conftest.py +++ b/tests/test_networks/test_diffusion_model/conftest.py @@ -1,4 +1,5 @@ import pytest +import keras @pytest.fixture() @@ -21,3 +22,49 @@ def edm_noise_schedule(): ) def noise_schedule(request): return request.getfixturevalue(request.param) + + +@pytest.fixture +def simple_diffusion_model(): + """Create a simple diffusion model for testing compositional sampling.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + return DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule="cosine", + prediction_type="noise", + loss_type="noise", + ) + + +@pytest.fixture +def compositional_conditions(): + """Create test conditions for compositional sampling.""" + batch_size = 2 + n_compositional = 3 + n_samples = 4 + condition_dim = 5 + + return keras.random.normal((batch_size, n_compositional, n_samples, condition_dim)) + + +@pytest.fixture +def compositional_state(): + """Create test state for compositional sampling.""" + batch_size = 2 + n_samples = 4 + param_dim = 3 + + return keras.random.normal((batch_size, n_samples, param_dim)) + + +@pytest.fixture +def mock_prior_score(): + """Create a mock prior score function for testing.""" + + def prior_score_fn(theta): + # Simple quadratic prior: -0.5 * ||theta||^2 + return -theta + + return prior_score_fn diff --git a/tests/test_networks/test_diffusion_model/test_compositional_sampling.py b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py new file mode 100644 index 000000000..2757bd28a --- /dev/null +++ b/tests/test_networks/test_diffusion_model/test_compositional_sampling.py @@ -0,0 +1,132 @@ +import keras +import pytest + + +def test_compositional_score_shape( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test that compositional score returns correct shapes.""" + # Build the model + state_shape = keras.ops.shape(compositional_state) + conditions_shape = 
keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + time = 0.5 + + score = simple_diffusion_model.compositional_score( + xz=compositional_state, + time=time, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + training=False, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(score) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def test_compositional_score_no_conditions_raises_error(simple_diffusion_model, compositional_state, mock_prior_score): + """Test that compositional score raises error when conditions is None.""" + simple_diffusion_model.build(keras.ops.shape(compositional_state), None) + + with pytest.raises(ValueError, match="Conditions are required for compositional sampling"): + simple_diffusion_model.compositional_score( + xz=compositional_state, time=0.5, conditions=None, compute_prior_score=mock_prior_score, training=False + ) + + +def test_inverse_compositional_basic( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test basic compositional inverse sampling.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + # Test inverse sampling with ODE method + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler", + steps=5, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +def test_inverse_compositional_euler_maruyama_with_corrector( + simple_diffusion_model, compositional_state, compositional_conditions, mock_prior_score +): + """Test compositional inverse sampling with Euler-Maruyama and corrector steps.""" + state_shape = keras.ops.shape(compositional_state) + conditions_shape = keras.ops.shape(compositional_conditions) + simple_diffusion_model.build(state_shape, conditions_shape) + + result = simple_diffusion_model._inverse_compositional( + z=compositional_state, + conditions=compositional_conditions, + compute_prior_score=mock_prior_score, + density=False, + training=False, + method="euler_maruyama", + steps=5, + corrector_steps=2, + start_time=1.0, + stop_time=0.0, + ) + + expected_shape = keras.ops.shape(compositional_state) + actual_shape = keras.ops.shape(result) + + assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), ( + f"Expected shape {expected_shape}, got {actual_shape}" + ) + + +@pytest.mark.parametrize("noise_schedule_name", ["cosine", "edm"]) +def test_compositional_sampling_with_different_schedules( + noise_schedule_name, compositional_state, compositional_conditions, mock_prior_score +): + """Test compositional sampling with different noise schedules.""" + from bayesflow.networks.diffusion_model import DiffusionModel + from bayesflow.networks import MLP + + diffusion_model = DiffusionModel( + subnet=MLP(widths=[32, 32]), + noise_schedule=noise_schedule_name, + prediction_type="noise", + loss_type="noise", + ) + + state_shape = keras.ops.shape(compositional_state) + 
+    conditions_shape = keras.ops.shape(compositional_conditions)
+    diffusion_model.build(state_shape, conditions_shape)
+
+    score = diffusion_model.compositional_score(
+        xz=compositional_state,
+        time=0.5,
+        conditions=compositional_conditions,
+        compute_prior_score=mock_prior_score,
+        training=False,
+    )
+
+    expected_shape = keras.ops.shape(compositional_state)
+    actual_shape = keras.ops.shape(score)
+
+    assert keras.ops.all(keras.ops.equal(expected_shape, actual_shape)), (
+        f"Expected shape {expected_shape}, got {actual_shape}"
+    )
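# Reference sketch (not part of the patch): the score combination implemented by
# CompositionalDiffusionModel.compositional_score above, written in plain NumPy under the assumption that
# the per-dataset scores and the prior score have already been computed. The bridge factor is 1 whenever
# compositional_bridge_d0 == compositional_bridge_d1.
import numpy as np

def combine_compositional_scores(individual_scores, prior_score, time, bridge=1.0):
    # individual_scores: (n_datasets, n_compositional, num_samples, dim) -- s_psi(theta, t, y_i) per condition
    # prior_score:       (n_datasets, num_samples, dim)                  -- grad_theta log p(theta)
    n_comp = individual_scores.shape[1]
    weighted_prior = (1.0 - n_comp) * (1.0 - time) * prior_score
    summed_scores = n_comp * individual_scores.mean(axis=1)  # mean * n equals the sum over conditions
    return bridge * (weighted_prior + summed_scores)

# Example shapes: 3 datasets, 5 compositional conditions, 10 samples, 2 parameters
combined = combine_compositional_scores(np.zeros((3, 5, 10, 2)), np.zeros((3, 10, 2)), time=0.5)
assert combined.shape == (3, 10, 2)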