From 9e90faec7d9e209095e3968bdb732aee386f1f10 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sat, 4 Oct 2025 16:57:08 +0200 Subject: [PATCH 1/2] Adding Dirichlet and SimplexTransform to pymc.dims --- pymc/dims/distributions/transforms.py | 28 ++++++++++++ pymc/dims/distributions/vector.py | 57 ++++++++++++++++++++++++- tests/dims/distributions/test_vector.py | 23 +++++++++- 3 files changed, 106 insertions(+), 2 deletions(-) diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py index 8f49d2a16..b5f84eff5 100644 --- a/pymc/dims/distributions/transforms.py +++ b/pymc/dims/distributions/transforms.py @@ -54,6 +54,34 @@ def log_jac_det(self, value, *inputs): log_odds_transform = LogOddsTransform() +class SimplexTransform(DimTransform): + name = "simplex" + + def __init__(self, dim: str): + self.core_dim = dim + + def forward(self, value, *inputs): + log_value = ptx.math.log(value) + N = value.sizes[self.core_dim].astype(value.dtype) + shift = log_value.sum(self.core_dim) / N + return log_value.isel({self.core_dim: slice(None, -1)}) - shift + + def backward(self, value, *inputs): + value = ptx.concat([value, -value.sum(self.core_dim)], dim=self.core_dim) + exp_value_max = ptx.math.exp(value - value.max(self.core_dim)) + return exp_value_max / exp_value_max.sum(self.core_dim) + + def log_jac_det(self, value, *inputs): + N = value.sizes[self.core_dim] + 1 + N = N.astype(value.dtype) + sum_value = value.sum(self.core_dim) + value_sum_expanded = value + sum_value + value_sum_expanded = ptx.concat([value_sum_expanded, 0], dim=self.core_dim) + logsumexp_value_expanded = ptx.math.logsumexp(value_sum_expanded, dim=self.core_dim) + res = ptx.math.log(N) + (N * sum_value) - (N * logsumexp_value_expanded) + return res + + class ZeroSumTransform(DimTransform): name = "zerosum" diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py index 0ad834c8a..776990d0d 100644 --- a/pymc/dims/distributions/vector.py +++ 
b/pymc/dims/distributions/vector.py @@ -19,7 +19,7 @@ from pytensor.xtensor import random as pxr from pymc.dims.distributions.core import VectorDimDistribution -from pymc.dims.distributions.transforms import ZeroSumTransform +from pymc.dims.distributions.transforms import SimplexTransform, ZeroSumTransform from pymc.distributions.multivariate import ZeroSumNormalRV from pymc.util import UNSET @@ -63,6 +63,61 @@ def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs): return super().dist([p], core_dims=core_dims, **kwargs) +class Dirichlet(VectorDimDistribution): + """Dirichlet distribution. + + Parameters + ---------- + a : xtensor_like + Concentration parameters of the distribution (a > 0), defined along the core dimension. + core_dims : str + The core dimension of the distribution, which represents the categories. + The dimension must be present in `a`. + **kwargs + Other keyword arguments used to define the distribution. + + Returns + ------- + XTensorVariable + An xtensor variable representing the Dirichlet distribution. + The output contains the core dimension, with draws summing to 1 along it.
+ + + """ + + xrv_op = ptxr.dirichlet + + @classmethod + def __new__( + cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs + ): + if core_dims is not None: + if isinstance(core_dims, tuple | list): + [core_dims] = core_dims + + # Create default_transform + if observed is None and default_transform is UNSET: + default_transform = SimplexTransform(dim=core_dims) + + # If the user didn't specify dims, take it from core_dims + # We need them to be forwarded to dist in the `dim_lenghts` argument + # if dims is None and core_dims is not None: + # dims = (..., *core_dims) + + return super().__new__( + *args, + core_dims=core_dims, + dims=dims, + default_transform=default_transform, + observed=observed, + **kwargs, + ) + + @classmethod + def dist(cls, a, *, core_dims=None, **kwargs): + return super().dist([a], core_dims=core_dims, **kwargs) + + class MvNormal(VectorDimDistribution): """Multivariate Normal distribution. diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py index 3a57453b4..c0a276d7a 100644 --- a/tests/dims/distributions/test_vector.py +++ b/tests/dims/distributions/test_vector.py @@ -19,7 +19,7 @@ import pymc.distributions as regular_distributions from pymc import Model -from pymc.dims import Categorical, MvNormal, ZeroSumNormal +from pymc.dims import Categorical, Dirichlet, MvNormal, ZeroSumNormal from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph @@ -40,6 +40,27 @@ def test_categorical(): assert_equivalent_logp_graph(model, reference_model) +def test_dirichlet(): + coords = {"a": range(3), "b": range(2)} + alpha = pt.as_tensor([1, 2, 3]) + + alpha_xr = as_xtensor(alpha, dims=("b",)) + + with Model(coords=coords) as model: + Dirichlet("x", a=alpha_xr, core_dims="b", dims=("a", "b")) + + with Model(coords=coords) as reference_model: + regular_distributions.Dirichlet("x", a=alpha, dims=("a", "b")) + + assert_equivalent_random_graph(model, 
reference_model) + + # logp graphs end up different, but they mean the same thing + np.testing.assert_allclose( + model.compile_logp()(model.initial_point()), + reference_model.compile_logp()(reference_model.initial_point()), + ) + + def test_mvnormal(): coords = {"a": range(3), "b": range(2)} mu = pt.as_tensor([1, 2]) From e8f0025ded51a015efb175bc6c251fbcd976aeaf Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 5 Oct 2025 01:04:32 +0200 Subject: [PATCH 2/2] Ruff fix --- pymc/dims/distributions/transforms.py | 2 +- pymc/dims/distributions/vector.py | 2 +- tests/dims/distributions/test_vector.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py index b5f84eff5..3425fbe69 100644 --- a/pymc/dims/distributions/transforms.py +++ b/pymc/dims/distributions/transforms.py @@ -59,7 +59,7 @@ class SimplexTransform(DimTransform): def __init__(self, dim: str): self.core_dim = dim - + def forward(self, value, *inputs): log_value = ptx.math.log(value) N = value.sizes[self.core_dim].astype(value.dtype) diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py index 776990d0d..7107712e6 100644 --- a/pymc/dims/distributions/vector.py +++ b/pymc/dims/distributions/vector.py @@ -117,7 +117,7 @@ def __new__( def dist(cls, a, *, core_dims=None, **kwargs): return super().dist([a], core_dims=core_dims, **kwargs) - + class MvNormal(VectorDimDistribution): """Multivariate Normal distribution. diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py index c0a276d7a..8cfdadb37 100644 --- a/tests/dims/distributions/test_vector.py +++ b/tests/dims/distributions/test_vector.py @@ -60,7 +60,7 @@ def test_dirichlet(): reference_model.compile_logp()(reference_model.initial_point()), ) - + def test_mvnormal(): coords = {"a": range(3), "b": range(2)} mu = pt.as_tensor([1, 2])