diff --git a/test/test_distributions.py b/test/test_distributions.py
index e47c16d2f0347a..28a6bbed921523 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -821,7 +821,6 @@ def test_binomial(self):
         self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
         self._gradcheck_log_prob(lambda p: Binomial(total_count, None, p.log()), [p])
         self.assertRaises(NotImplementedError, Binomial(10, p).rsample)
-        self.assertRaises(NotImplementedError, Binomial(10, p).entropy)
 
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     def test_binomial_log_prob(self):
@@ -2757,6 +2756,7 @@ def __init__(self, probs):
                                             [0.33, 0.33, 0.34],
                                             [0.2, 0.2, 0.4]])
         exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
         gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
+        geometric = pairwise(Geometric, [0.1, 0.2, 0.6, 0.9])
         gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
         halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
@@ -2792,6 +2792,8 @@ def __init__(self, probs):
         (beta, gamma),
         (beta, normal),
         (binomial30, binomial30),
+        (binomial30, geometric),
+        (binomial30, poisson),
         (binomial_vectorized_count, binomial_vectorized_count),
         (categorical, categorical),
         (chi2, chi2),
@@ -2849,6 +2851,7 @@ def __init__(self, probs):
         (Exponential(1), Beta(2, 3)),
         (Exponential(1), Pareto(2, 3)),
         (Exponential(1), Uniform(-2, 3)),
+        (Geometric(0.3), Binomial(10, 0.2)),
         (Gamma(1, 2), Beta(3, 4)),
         (Gamma(1, 2), Pareto(3, 4)),
         (Gamma(1, 2), Uniform(-3, 4)),
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 28756d405383f5..ad228a8f542571 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -5,6 +5,35 @@
 from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
 
 
+def _log1pmtensor(logit_tensor):
+    """
+    Computes (-probs).log1p(), i.e. log(1 - probs), from logit_tensor = probs.log() - (-probs).log1p().
+    Useful for distributions with extreme probs.
+    Note that: (-probs).log1p() = max_val - (logits + 2 * max_val).exp().log1p(), with max_val = (-logits).clamp(min=0)
+    """
+    max_val = (-logit_tensor).clamp(min=0.0)
+    return max_val - torch.log1p((logit_tensor + 2 * max_val).exp())
+
+
+def _Elnchoosek(p):
+    """
+    Returns E[log C(n, k)] together with (E[log n!], E[log k!], E[log (n-k)!])
+    for k ~ p, where p is a Binomial distribution.
+    """
+    s = p.enumerate_support()
+    s[0] = 1  # so that s[0].log() == 0, since log(0!) = 0
+    # x is the log-factorial matrix, i.e. x[k, ...] = log(k!)
+    x = torch.cumsum(s.log(), dim=0)
+    s[0] = 0
+    lnchoosek = x[-1] - x - x.flip(0)
+    elognfac = x[-1]
+    elogkfac = ((lnchoosek + s * p.logits + p.total_count * _log1pmtensor(p.logits)).exp() *
+                x).sum(dim=0)
+    elognmkfac = ((lnchoosek + s * p.logits + p.total_count * _log1pmtensor(p.logits)).exp() *
+                  x.flip(0)).sum(dim=0)
+    return elognfac - elogkfac - elognmkfac, (elognfac, elogkfac, elognmkfac)
+
+
 class Binomial(Distribution):
     r"""
     Creates a Binomial distribution parameterized by `total_count` and
@@ -94,11 +123,8 @@ def log_prob(self, value):
         log_factorial_n = torch.lgamma(self.total_count + 1)
         log_factorial_k = torch.lgamma(value + 1)
         log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
-        max_val = (-self.logits).clamp(min=0.0)
-        # Note that: torch.log1p(-self.probs)) = max_val - torch.log1p((self.logits + 2 * max_val).exp()))
         return (log_factorial_n - log_factorial_k - log_factorial_nmk +
-                value * self.logits + self.total_count * max_val -
-                self.total_count * torch.log1p((self.logits + 2 * max_val).exp()))
+                value * self.logits + self.total_count * _log1pmtensor(self.logits))
 
     def enumerate_support(self):
         total_count = int(self.total_count.max())
@@ -109,3 +135,8 @@ def enumerate_support(self):
         values = values.view((-1,) + (1,) * len(self._batch_shape))
         values = values.expand((-1,) + self._batch_shape)
         return values
+
+    def entropy(self):
+        # H = -E[log C(n, k)] - E[k] * logits - total_count * log(1 - probs)
+        elnchoosek, _ = _Elnchoosek(self)
+        return -elnchoosek - self.mean * self.logits - self.total_count * _log1pmtensor(self.logits)
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 2ae67fc28ccbcd..50f8c6cae834c0 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -6,7 +6,7 @@
 
 from .bernoulli import Bernoulli
 from .beta import Beta
-from .binomial import Binomial
+from .binomial import Binomial, _log1pmtensor, _Elnchoosek
 from .categorical import Categorical
 from .dirichlet import Dirichlet
 from .distribution import Distribution
@@ -199,7 +199,7 @@ def _kl_binomial_binomial(p, q):
     #     kullback-leibler-divergence-for-binomial-distributions-p-and-q
     if (p.total_count < q.total_count).any():
         raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
-    kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
+    kl = p.total_count * (p.probs * (p.logits - q.logits) + _log1pmtensor(p.logits) - _log1pmtensor(q.logits))
     inf_idxs = p.total_count > q.total_count
     kl[inf_idxs] = _infinite_like(kl[inf_idxs])
     return kl
@@ -396,6 +396,25 @@ def _kl_beta_uniform(p, q):
     return result
 
 
+@register_kl(Binomial, Poisson)
+def _kl_binomial_poisson(p, q):
+    # Uses E[log C(n, k)] + E[log k!] = E[log n!] - E[log (n - k)!] = e1 - e3
+    _, (e1, _, e3) = _Elnchoosek(p)
+    return (e1 - e3 +
+            p.mean * (p.logits - q.rate.log()) +
+            p.total_count * _log1pmtensor(p.logits) +
+            q.rate)
+
+
+@register_kl(Binomial, Geometric)
+def _kl_binomial_geometric(p, q):
+    # E[log p(k) - log q(k)] with log q(k) = k * (1 - q.probs).log() + q.probs.log()
+    elnchoosek, _ = _Elnchoosek(p)
+    return (elnchoosek +
+            (p.logits - (-q.probs).log1p()) * p.mean +
+            p.total_count * _log1pmtensor(p.logits) - q.probs.log())
+
+
 @register_kl(Exponential, Beta)
 @register_kl(Exponential, Pareto)
 @register_kl(Exponential, Uniform)
@@ -468,6 +487,12 @@ def _kl_gamma_normal(p, q):
     return t1 + (p.concentration - 1) * p.concentration.digamma() + (t2 - t3 + t4) / var_normal
 
 
+@register_kl(Geometric, Binomial)
+def _kl_geometric_infinity(p, q):
+    # Geometric support {0, 1, 2, ...} extends past total_count, so the KL is infinite
+    return _infinite_like(p.probs)
+
+
 @register_kl(Gumbel, Beta)
 @register_kl(Gumbel, Exponential)
 @register_kl(Gumbel, Gamma)
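Reviewer note: since `Binomial` has finite support, the new `entropy()` can be cross-checked exactly by enumeration. A minimal sanity sketch, assuming this patch is applied; the parameter values below (including the near-0/1 probs that exercise `_log1pmtensor`) are arbitrary:

```python
import torch
from torch.distributions import Binomial

# Brute-force reference: H = -sum_k P(k) * log P(k) over the finite support.
b = Binomial(10, torch.tensor([1e-4, 0.2, 0.5, 0.9999]))
log_pmf = b.log_prob(b.enumerate_support())  # shape: (total_count + 1,) + batch_shape
brute_force = -(log_pmf.exp() * log_pmf).sum(0)
print(torch.allclose(b.entropy(), brute_force, atol=1e-4))  # expected: True
```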
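The new KL registrations can be checked the same way, and the reverse `Geometric -> Binomial` direction should be infinite because the Geometric places mass beyond `total_count`. Same assumptions as above, with arbitrary rates and probs:

```python
import torch
from torch.distributions import Binomial, Geometric, Poisson
from torch.distributions.kl import kl_divergence

p = Binomial(10, torch.tensor([0.2, 0.5, 0.9]))
support = p.enumerate_support()
log_pmf = p.log_prob(support)

# KL(p || q) = sum_k P(k) * (log P(k) - log Q(k)), exact for finite support.
for q in (Poisson(torch.tensor([2.0, 5.0, 9.0])),
          Geometric(torch.tensor([0.3, 0.5, 0.7]))):
    brute_force = (log_pmf.exp() * (log_pmf - q.log_prob(support))).sum(0)
    print(torch.allclose(kl_divergence(p, q), brute_force, atol=1e-4))  # expected: True

# Geometric mass at k > 10 hits q(k) = 0, so the KL diverges.
print(kl_divergence(Geometric(0.3), Binomial(10, 0.2)))  # expected: inf
```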