
Add Cumulative Distribution Function, Inverse CDF methods to Distributions #122

Closed
wants to merge 9 commits
85 changes: 73 additions & 12 deletions test/test_distributions.py
@@ -183,8 +183,8 @@ def is_all_nan(tensor):
'scale': Variable(torch.randn(1).abs(), requires_grad=True),
},
{
'loc': Variable(torch.Tensor([1.0, 0.0])),
'scale': Variable(torch.Tensor([1e-5, 1e-5])),
'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
},
]),
Example(LogNormal, [
@@ -197,8 +197,8 @@ def is_all_nan(tensor):
'scale': Variable(torch.randn(1).abs(), requires_grad=True),
},
{
'loc': torch.Tensor([1.0, 0.0]),
'scale': torch.Tensor([1e-5, 1e-5]),
'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
},
]),
Example(Normal, [
@@ -211,8 +211,8 @@ def is_all_nan(tensor):
'scale': Variable(torch.randn(1).abs(), requires_grad=True),
},
{
'loc': Variable(torch.Tensor([1.0, 0.0])),
'scale': Variable(torch.Tensor([1e-5, 1e-5])),
'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
},
]),
Example(OneHotCategorical, [
@@ -942,10 +942,10 @@ def test_gamma_sample(self):
def test_pareto(self):
scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
alpha = Variable(torch.randn(2, 3).abs(), requires_grad=True)
scale_1d = torch.randn(1).abs()
alpha_1d = torch.randn(1).abs()
self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).mean, float('inf'), allow_inf=True)
self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).variance, float('inf'), allow_inf=True)
scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
alpha_1d = Variable(torch.randn(1).abs(), requires_grad=True)
self.assertEqual(Pareto(scale_1d, 0.5).mean, float('inf'), allow_inf=True)
self.assertEqual(Pareto(scale_1d, 0.5).variance, float('inf'), allow_inf=True)
self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
@@ -973,8 +973,8 @@ def test_pareto_sample(self):
def test_gumbel(self):
loc = Variable(torch.randn(2, 3), requires_grad=True)
scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
loc_1d = torch.randn(1)
scale_1d = torch.randn(1).abs()
loc_1d = Variable(torch.randn(1), requires_grad=True)
scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
@@ -1161,6 +1161,39 @@ def test_beta_sample(self):
x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

def test_cdf_icdf_inverse(self):
# Tests the invertibility property icdf(cdf(x)) == x on the distributions
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
dist = Dist(**param)
samples = dist.sample(sample_shape=(20,))
try:
cdf = dist.cdf(samples)
actual = dist.icdf(cdf)
except NotImplementedError:
continue
self.assertEqual(actual, samples,
message='{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)))

def test_cdf_log_prob(self):
# Tests that differentiating the CDF gives the PDF at the sampled values
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
dist = Dist(**param)
samples = dist.sample(sample_shape=(20,))
if not samples.requires_grad:
continue
try:
cdfs = dist.cdf(samples)
pdfs = dist.log_prob(samples).exp()
except NotImplementedError:
continue
cdfs_derivative = grad(cdfs.sum(), [samples])[0]
self.assertEqual(cdfs_derivative, pdfs,
message='{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1,
len(params)))

def test_valid_parameter_broadcasting(self):
# Test correct broadcasting of parameter sizes for distributions that have multiple
# parameters.
@@ -2293,6 +2326,10 @@ def setUp(self):
Binomial(10, simplex_tensor),
scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor)
),
(
Cauchy(random_var, positive_var),
scipy.stats.cauchy(loc=random_var, scale=positive_var)
),
(
Dirichlet(positive_var),
scipy.stats.dirichlet(positive_var)
@@ -2358,17 +2395,41 @@ def setUp(self):

def test_mean(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
if isinstance(pytorch_dist, Cauchy):
continue  # the mean of the Cauchy distribution is undefined
self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)

def test_variance_stddev(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
if isinstance(pytorch_dist, Cauchy):
continue  # the variance and stddev of the Cauchy distribution are undefined
if isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist)
else:
self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)

def test_cdf(self):
Review comment:

It would be nice to have an additional test that did not rely on scipy, e.g.

class TestDistributions(TestCase):
    def test_cdf_icdf(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                samples = dist.sample(sample_shape=(20,))
                try:
                    cdf = dist.cdf(samples)
                    actual = dist.icdf(cdf)
                except NotImplementedError:
                    continue
                self.assertEqual(actual, samples, message='{} example {}/{}, icdf(cdf(x)) != x')

or you could get even fancier by using grad() like

x = dist.sample(sample_shape=(20,))
expected_pdf = dist.log_prob(x).exp()
actual_pdf = grad(dist.cdf(x).sum(), [x])[0]
self.assertEqual(actual_pdf, expected_pdf)

set_rng_seed(0) # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
samples = pytorch_dist.sample((5,))
try:
Review comment:

It's safer to enclose as little as needed in a try-except. Could you refactor to

try:
    cdf = pytorch_dist.cdf(samples)
except NotImplementedError:
    continue
self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)

Author reply:

Ah, yes. I saw the discussion in TruncatedNormal. I will modify it accordingly.

cdf = pytorch_dist.cdf(samples)
except NotImplementedError:
continue
self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)

def test_icdf(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
try:
Review comment:

ditto, enclose as little as possible in try-except

Author reply:

Sure.

icdf = pytorch_dist.icdf(samples)
except NotImplementedError:
continue
self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)


class TestTransforms(TestCase):
def setUp(self):
8 changes: 8 additions & 0 deletions torch/distributions/cauchy.py
@@ -53,5 +53,13 @@ def log_prob(self, value):
self._validate_log_prob_arg(value)
return -math.log(math.pi) - self.scale.log() - (1 + ((value - self.loc) / self.scale)**2).log()

def cdf(self, value):
self._validate_log_prob_arg(value)
return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5

def icdf(self, value):
self._validate_log_prob_arg(value)
return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc

def entropy(self):
return math.log(4 * math.pi) + self.scale.log()
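
The closed-form pair added here is F(x) = atan((x - loc)/scale)/pi + 1/2 and its inverse F^-1(u) = tan(pi * (u - 1/2)) * scale + loc. A minimal standalone round-trip check (a sketch using plain tensors rather than the Variables in this diff):

import math
import torch

loc, scale = 0.0, 2.0
x = torch.tensor([-3.0, 0.0, 5.0])

# CDF: F(x) = atan((x - loc) / scale) / pi + 1/2
u = torch.atan((x - loc) / scale) / math.pi + 0.5
# Inverse CDF: F^-1(u) = tan(pi * (u - 1/2)) * scale + loc
x_back = torch.tan(math.pi * (u - 0.5)) * scale + loc

print(torch.allclose(x_back, x, atol=1e-4))  # True: icdf(cdf(x)) == x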
20 changes: 20 additions & 0 deletions torch/distributions/distribution.py
@@ -101,6 +101,26 @@ def log_prob(self, value):
"""
raise NotImplementedError

def cdf(self, value):
"""
Returns the cumulative distribution function evaluated at
`value`.

Args:
value (Tensor or Variable): the value at which to evaluate the CDF
"""
raise NotImplementedError

def icdf(self, value):
"""
Returns the inverse cumulative distribution function evaluated at
`value`.

Args:
value (Tensor or Variable): the probability at which to evaluate the inverse CDF
"""
raise NotImplementedError

def enumerate_support(self):
"""
Returns tensor containing all values supported by a discrete
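
Since the base class deliberately raises NotImplementedError, callers that want to work across all distributions can probe for support, mirroring the try/except pattern in the tests above. A sketch (safe_cdf is an illustrative helper, not part of this PR):

import torch
from torch.distributions import Normal

def safe_cdf(dist, value):
    """Return dist.cdf(value), or None when the distribution defines no CDF."""
    try:
        return dist.cdf(value)
    except NotImplementedError:
        return None

print(safe_cdf(Normal(0.0, 1.0), torch.tensor([0.0])))  # tensor([0.5000])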
8 changes: 8 additions & 0 deletions torch/distributions/exponential.py
@@ -50,6 +50,14 @@ def log_prob(self, value):
self._validate_log_prob_arg(value)
return self.rate.log() - self.rate * value

def cdf(self, value):
self._validate_log_prob_arg(value)
return 1 - torch.exp(-self.rate * value)

def icdf(self, value):
self._validate_log_prob_arg(value)
return -torch.log(1 - value) / self.rate

def entropy(self):
return 1.0 - torch.log(self.rate)

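The icdf added here is exactly the inverse-transform sampler: if U ~ Uniform(0, 1), then -log(1 - U)/rate ~ Exponential(rate). A quick empirical check (a sketch, not part of the diff):

import torch

rate = 2.0
u = torch.rand(100000)                # U ~ Uniform[0, 1)
samples = -torch.log(1 - u) / rate    # icdf(U) ~ Exponential(rate)

print(samples.mean().item())          # close to 1 / rate = 0.5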
24 changes: 8 additions & 16 deletions torch/distributions/gumbel.py
@@ -2,13 +2,15 @@
import math
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.uniform import Uniform
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.utils import _finfo, broadcast_all

euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant


class Gumbel(Distribution):
class Gumbel(TransformedDistribution):
r"""
Samples from a Gumbel Distribution.

@@ -23,7 +25,6 @@ class Gumbel(Distribution):
loc (float or Tensor or Variable): Location parameter of the distribution
scale (float or Tensor or Variable): Scale parameter of the distribution
"""
has_rsample = True
params = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real

@@ -33,19 +34,10 @@ def __init__(self, loc, scale):
batch_shape = torch.Size()
else:
batch_shape = self.scale.size()
super(Gumbel, self).__init__(batch_shape)

def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uni_dist = self.scale.new(shape).uniform_(_finfo(self.scale).eps, 1)
# X ~ Uniform(0, 1)
# Y = loc - scale * ln (-ln (X)) ~ Gumbel(loc, scale)
return self.loc - self.scale * torch.log(-uni_dist.log())

def log_prob(self, value):
self._validate_log_prob_arg(value)
z = (value - self.loc) / self.scale
return -(self.scale.log() + z + torch.exp(-z))
base_dist = Uniform(torch.zeros_like(self.loc), 1)
Review comment:

Maybe we should avoid infinity like

finfo = _finfo(self.loc)
base_dist = Uniform(self.loc.new([finfo.tiny]).expand_as(self.loc), 1 - finfo.eps) 

transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-1),
Review comment:

Nice!

ExpTransform().inv, AffineTransform(loc=loc, scale=-scale)]
super(Gumbel, self).__init__(base_dist, transforms)

@property
def mean(self):
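
Applying the transforms above in order to U ~ Uniform(0, 1): log gives log U, negation gives -log U ~ Exponential(1), the second log gives log(-log U), and the final affine yields loc - scale * log(-log U) — the same inverse-CDF sampler the deleted rsample used. A standalone check of that identity (a sketch, assuming the current torch API):

import torch

euler_constant = 0.5772156649015329   # Euler-Mascheroni constant
loc, scale = 1.0, 2.0

u = torch.rand(200000).clamp(min=1e-12)   # avoid log(0), per the review comment
x = loc - scale * torch.log(-torch.log(u))

print(x.mean().item())                 # empirical mean
print(loc + scale * euler_constant)    # Gumbel mean, ~ 2.1544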
8 changes: 8 additions & 0 deletions torch/distributions/normal.py
@@ -65,6 +65,14 @@ def log_prob(self, value):
log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

def cdf(self, value):
self._validate_log_prob_arg(value)
return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))

def icdf(self, value):
self._validate_log_prob_arg(value)
return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)

def entropy(self):
return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)

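The erf-based forms above follow from Phi(x) = (1 + erf((x - loc) / (scale * sqrt(2)))) / 2 and its inverse loc + scale * sqrt(2) * erfinv(2u - 1). A standalone spot check against familiar values (a sketch, not part of the diff):

import math
import torch

loc, scale = 0.0, 1.0
x = torch.tensor([0.0, 1.0, -1.96])

u = 0.5 * (1 + torch.erf((x - loc) / (scale * math.sqrt(2))))
print(u)  # ~ tensor([0.5000, 0.8413, 0.0250])

x_back = loc + scale * math.sqrt(2) * torch.erfinv(2 * u - 1)
print(torch.allclose(x_back, x, atol=1e-4))  # True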
20 changes: 7 additions & 13 deletions torch/distributions/pareto.py
@@ -4,11 +4,13 @@

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.exponential import Exponential
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.utils import broadcast_all


class Pareto(Distribution):
class Pareto(TransformedDistribution):
r"""
Samples from a Pareto Type 1 distribution.

@@ -23,7 +25,6 @@ class Pareto(Distribution):
scale (float or Tensor or Variable): Scale parameter of the distribution
alpha (float or Tensor or Variable): Shape parameter of the distribution
"""
has_rsample = True
params = {'alpha': constraints.positive, 'scale': constraints.positive}

def __init__(self, scale, alpha):
@@ -32,7 +33,9 @@ def __init__(self, scale, alpha):
batch_shape = torch.Size()
else:
batch_shape = self.scale.size()
super(Pareto, self).__init__(batch_shape)
base_dist = Exponential(alpha)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
super(Pareto, self).__init__(base_dist, transforms)

@property
def mean(self):
Expand All @@ -50,14 +53,5 @@ def variance(self):
def support(self):
return constraints.greater_than(self.scale)

def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
exp_dist = self.alpha.new(shape).exponential_()
return self.scale * torch.exp(exp_dist / self.alpha)

def log_prob(self, value):
self._validate_log_prob_arg(value)
return torch.log(self.alpha / value) + self.alpha * (self.scale / value).log()

def entropy(self):
return ((self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()))
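
The refactor rests on the identity that if E ~ Exponential(alpha), then scale * exp(E) ~ Pareto(scale, alpha), whose CDF is 1 - (scale/x)^alpha for x >= scale. An empirical check of that identity (a sketch using the current torch API):

import torch

scale, alpha = 1.5, 3.0
e = torch.distributions.Exponential(alpha).sample((200000,))
x = scale * torch.exp(e)               # the transform chain above

threshold = 2.0
empirical = (x <= threshold).float().mean().item()
analytic = 1 - (scale / threshold) ** alpha
print(empirical, analytic)             # both ~ 0.578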
21 changes: 21 additions & 0 deletions torch/distributions/transformed_distribution.py
@@ -85,3 +85,24 @@ def log_prob(self, value):
log_prob += _sum_rightmost(self.base_dist.log_prob(y),
event_dim - len(self.base_dist.event_shape))
return log_prob

def cdf(self, value):
"""
Computes the cumulative distribution function by inverting the transform(s)
and computing the CDF of the base distribution
"""
self.base_dist._validate_log_prob_arg(value)
for transform in self.transforms[::-1]:
value = transform.inv(value)
return self.base_dist.cdf(value)

def icdf(self, value):
"""
Computes the inverse cumulative distribution function by computing the
inverse CDF of the base distribution and then applying the transform(s)
"""
self.base_dist._validate_log_prob_arg(value)
Review comment:

I believe the base_dist.icdf() should call _validate_log_prob_arg(value) internally on the following line. Do you think it's worth having the extra check here? I'd be happy either way.

value = self.base_dist.icdf(value)
for transform in self.transforms:
value = transform(value)
return value
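
Concretely, for exp(Normal(0, 1)) — i.e., a log-normal — the generic cdf above inverts ExpTransform and defers to the Normal CDF, and icdf does the reverse. A sketch (assumes the transforms are monotonically increasing, as here):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

base = Normal(0.0, 1.0)
log_normal = TransformedDistribution(base, [ExpTransform()])

x = torch.tensor([0.5, 1.0, 3.0])
print(torch.allclose(log_normal.cdf(x), base.cdf(x.log())))    # True
print(torch.allclose(log_normal.icdf(log_normal.cdf(x)), x))   # True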
10 changes: 10 additions & 0 deletions torch/distributions/uniform.py
@@ -63,5 +63,15 @@ def log_prob(self, value):
ub = value.lt(self.high).type_as(self.low)
return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)

def cdf(self, value):
self._validate_log_prob_arg(value)
result = (value - self.low) / (self.high - self.low)
return result

def icdf(self, value):
self._validate_log_prob_arg(value)
result = value * (self.high - self.low) + self.low
return result

def entropy(self):
return torch.log(self.high - self.low)
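
As with Exponential, the Uniform icdf doubles as an inverse-transform sampler onto an arbitrary interval. A short usage sketch:

import torch
from torch.distributions import Uniform

dist = Uniform(2.0, 6.0)
u = torch.rand(5)

x = dist.icdf(u)                       # maps (0, 1) onto (2, 6)
print(torch.allclose(dist.cdf(x), u))  # True: cdf(icdf(u)) == u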