diff --git a/test/test_distributions.py b/test/test_distributions.py
index b36d56ea5d61c4..1cc58abcb519f8 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -51,7 +51,7 @@
                                             SigmoidTransform,
                                             StickBreakingTransform,
                                             identity_transform)
-from torch.distributions.utils import _finfo, probs_to_logits
+from torch.distributions.utils import _finfo, probs_to_logits, softmax
 
 TEST_NUMPY = True
 try:
@@ -183,8 +183,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': Variable(torch.Tensor([1.0, 0.0])),
-            'scale': Variable(torch.Tensor([1e-5, 1e-5])),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(LogNormal, [
@@ -197,8 +197,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': torch.Tensor([1.0, 0.0]),
-            'scale': torch.Tensor([1e-5, 1e-5]),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
        },
     ]),
     Example(Normal, [
@@ -211,8 +211,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': Variable(torch.Tensor([1.0, 0.0])),
-            'scale': Variable(torch.Tensor([1e-5, 1e-5])),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(OneHotCategorical, [
@@ -247,17 +247,17 @@ def is_all_nan(tensor):
     Example(TransformedDistribution, [
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': [],
         },
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': ExpTransform(),
         },
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': [AffineTransform(Variable(torch.randn(1)), Variable(torch.randn(1))),
                            ExpTransform()],
         },
@@ -942,10 +942,10 @@ def test_gamma_sample(self):
     def test_pareto(self):
         scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
         alpha = Variable(torch.randn(2, 3).abs(), requires_grad=True)
-        scale_1d = torch.randn(1).abs()
-        alpha_1d = torch.randn(1).abs()
-        self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).mean, float('inf'), allow_inf=True)
-        self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).variance, float('inf'), allow_inf=True)
+        scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+        alpha_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+        self.assertEqual(Pareto(scale_1d, 0.5).mean, float('inf'), allow_inf=True)
+        self.assertEqual(Pareto(scale_1d, 0.5).variance, float('inf'), allow_inf=True)
         self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
         self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
         self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
@@ -973,8 +973,8 @@ def test_pareto_sample(self):
     def test_gumbel(self):
         loc = Variable(torch.randn(2, 3), requires_grad=True)
         scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
-        loc_1d = torch.randn(1)
-        scale_1d = torch.randn(1).abs()
+        loc_1d = Variable(torch.randn(1), requires_grad=True)
+        scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
         self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
         self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
         self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
@@ -1161,6 +1161,47 @@ def test_beta_sample(self):
             x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
             self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))
 
+    def test_cdf_icdf_inverse(self):
+        # Tests the invertibility property on the distributions
+        for Dist, params in EXAMPLES:
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                samples = dist.sample(sample_shape=(20,))
+                try:
+                    cdf = dist.cdf(samples)
+                    actual = dist.icdf(cdf)
+                except NotImplementedError:
+                    continue
+                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
+                self.assertLess(rel_error.max(), 1e-4, msg='\n'.join([
+                    '{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)),
+                    'x = {}'.format(samples),
+                    'cdf(x) = {}'.format(cdf),
+                    'icdf(cdf(x)) = {}'.format(actual),
+                ]))
+
+    def test_cdf_log_prob(self):
+        # Tests if the differentiation of the CDF gives the PDF at a given value
+        for Dist, params in EXAMPLES:
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                samples = dist.sample()
+                if not samples.requires_grad:
+                    continue
+                try:
+                    cdfs = dist.cdf(samples)
+                    pdfs = dist.log_prob(samples).exp()
+                except NotImplementedError:
+                    continue
+                cdfs_derivative = grad(cdfs.sum(), [samples])[0]
+                self.assertEqual(cdfs_derivative, pdfs, message='\n'.join([
+                    '{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
+                    'x = {}'.format(samples),
+                    'cdf = {}'.format(cdfs),
+                    'pdf = {}'.format(pdfs),
+                    'grad(cdf) = {}'.format(cdfs_derivative),
+                ]))
+
     def test_valid_parameter_broadcasting(self):
         # Test correct broadcasting of parameter sizes for distributions that have multiple
         # parameters.
@@ -1741,6 +1782,16 @@ def test_pareto_shape_scalar_params(self):
         self.assertEqual(pareto.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(pareto.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
 
+    def test_gumbel_shape_scalar_params(self):
+        gumbel = Gumbel(1, 1)
+        self.assertEqual(gumbel._batch_shape, torch.Size())
+        self.assertEqual(gumbel._event_shape, torch.Size())
+        self.assertEqual(gumbel.sample().size(), torch.Size(SCALAR_SHAPE))
+        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, gumbel.log_prob, self.scalar_sample)
+        self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+
     def test_normal_shape_scalar_params(self):
         normal = Normal(0, 1)
         self.assertEqual(normal._batch_shape, torch.Size())
@@ -2279,7 +2330,7 @@ def setUp(self):
         positive_var2 = Variable(torch.Tensor(20,).normal_()).exp()
         random_var = Variable(torch.Tensor(20,).normal_())
         random_tensor = torch.Tensor(20,).normal_()
-        simplex_tensor = random_tensor.exp() / random_tensor.exp().sum()
+        simplex_tensor = softmax(random_tensor)
         self.distribution_pairs = [
             (
                 Bernoulli(simplex_tensor),
@@ -2293,13 +2344,17 @@ def setUp(self):
                 Binomial(10, simplex_tensor),
                 scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor)
             ),
+            (
+                Cauchy(random_var, positive_var),
+                scipy.stats.cauchy(loc=random_var, scale=positive_var)
+            ),
             (
                 Dirichlet(positive_var),
                 scipy.stats.dirichlet(positive_var)
             ),
             (
                 Exponential(positive_var),
-                scipy.stats.expon(scale=1. / positive_var)
+                scipy.stats.expon(scale=positive_var.reciprocal())
             ),
             (
                 FisherSnedecor(positive_var, 4 + positive_var2),  # var for df2<=4 is undefined
@@ -2307,7 +2362,7 @@ def setUp(self):
             ),
             (
                 Gamma(positive_var, positive_var2),
-                scipy.stats.gamma(positive_var, scale=1 / positive_var2)
+                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal())
             ),
             (
                 Geometric(simplex_tensor),
@@ -2358,10 +2413,14 @@ def setUp(self):
 
     def test_mean(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
+            if isinstance(pytorch_dist, Cauchy):
+                continue
             self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)
 
     def test_variance_stddev(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
+            if isinstance(pytorch_dist, Cauchy):
+                continue
             if isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                 self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
                 self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist)
@@ -2369,6 +2428,26 @@ def test_variance_stddev(self):
                 self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
                 self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)
 
+    def test_cdf(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        for pytorch_dist, scipy_dist in self.distribution_pairs:
+            samples = pytorch_dist.sample((5,))
+            try:
+                cdf = pytorch_dist.cdf(samples)
+            except NotImplementedError:
+                continue
+            self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)
+
+    def test_icdf(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        for pytorch_dist, scipy_dist in self.distribution_pairs:
+            samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
+            try:
+                icdf = pytorch_dist.icdf(samples)
+            except NotImplementedError:
+                continue
+            self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)
+
 
 class TestTransforms(TestCase):
     def setUp(self):
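Note: the two generic tests added above assert that icdf inverts cdf and that differentiating cdf recovers the density, for every example distribution that implements the new methods. The same two properties can be checked in isolation; a minimal standalone sketch using Normal and the Variable-era autograd API this branch targets (not part of the patch):

    import torch
    from torch.autograd import Variable, grad
    from torch.distributions import Normal

    d = Normal(Variable(torch.zeros(3)), Variable(torch.ones(3)))
    x = Variable(d.sample().data, requires_grad=True)
    # icdf should undo cdf up to floating-point error
    assert (d.icdf(d.cdf(x)) - x).abs().max().data[0] < 1e-4
    # d/dx cdf(x) should equal the density exp(log_prob(x))
    dcdf_dx = grad(d.cdf(x).sum(), [x])[0]
    assert (dcdf_dx - d.log_prob(x).exp()).abs().max().data[0] < 1e-4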
diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py
index 6a3600b637df9e..b3b7f06e4fdfb8 100644
--- a/torch/distributions/cauchy.py
+++ b/torch/distributions/cauchy.py
@@ -53,5 +53,13 @@ def log_prob(self, value):
         self._validate_log_prob_arg(value)
         return -math.log(math.pi) - self.scale.log() - (1 + ((value - self.loc) / self.scale)**2).log()
 
+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
+
     def entropy(self):
         return math.log(4 * math.pi) + self.scale.log()
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index 4201def601bd42..ffa6dfbfe2607c 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -101,6 +101,26 @@ def log_prob(self, value):
         """
         raise NotImplementedError
 
+    def cdf(self, value):
+        """
+        Returns the cumulative distribution function evaluated at
+        `value`.
+
+        Args:
+            value (Tensor or Variable):
+        """
+        raise NotImplementedError
+
+    def icdf(self, value):
+        """
+        Returns the inverse cumulative distribution function evaluated at
+        `value`.
+
+        Args:
+            value (Tensor or Variable):
+        """
+        raise NotImplementedError
+
     def enumerate_support(self):
         """
         Returns tensor containing all values supported by a discrete
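Note: the Cauchy methods above are the textbook closed forms F(x) = atan((x - loc) / scale) / pi + 1/2 and its inverse F^-1(p) = loc + scale * tan(pi * (p - 1/2)). A scalar spot check with hypothetical values:

    import math

    loc, scale, x = 0.0, 2.0, 1.0
    p = math.atan((x - loc) / scale) / math.pi + 0.5   # cdf(1.0) ~= 0.6476
    # icdf(cdf(x)) recovers x
    assert abs(math.tan(math.pi * (p - 0.5)) * scale + loc - x) < 1e-12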
diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py
index dc0dbe57d4742b..523404675e9093 100644
--- a/torch/distributions/exponential.py
+++ b/torch/distributions/exponential.py
@@ -50,6 +50,14 @@ def log_prob(self, value):
         self._validate_log_prob_arg(value)
         return self.rate.log() - self.rate * value
 
+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        return 1 - torch.exp(-self.rate * value)
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        return -torch.log(1 - value) / self.rate
+
     def entropy(self):
         return 1.0 - torch.log(self.rate)
 
diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py
index c4f8fcc03d617b..8b8dc0ecee7f2e 100644
--- a/torch/distributions/gumbel.py
+++ b/torch/distributions/gumbel.py
@@ -2,13 +2,15 @@
 import math
 import torch
 from torch.distributions import constraints
-from torch.distributions.distribution import Distribution
+from torch.distributions.uniform import Uniform
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
 from torch.distributions.utils import _finfo, broadcast_all
 
 euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant
 
 
-class Gumbel(Distribution):
+class Gumbel(TransformedDistribution):
     r"""
     Samples from a Gumbel Distribution.
 
@@ -23,29 +25,21 @@ class Gumbel(Distribution):
         loc (float or Tensor or Variable): Location parameter of the distribution
         scale (float or Tensor or Variable): Scale parameter of the distribution
     """
-    has_rsample = True
     params = {'loc': constraints.real, 'scale': constraints.positive}
     support = constraints.real
 
     def __init__(self, loc, scale):
         self.loc, self.scale = broadcast_all(loc, scale)
+        finfo = _finfo(self.loc)
         if isinstance(loc, Number) and isinstance(scale, Number):
             batch_shape = torch.Size()
+            base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
         else:
             batch_shape = self.scale.size()
-        super(Gumbel, self).__init__(batch_shape)
-
-    def rsample(self, sample_shape=torch.Size()):
-        shape = self._extended_shape(sample_shape)
-        uni_dist = self.scale.new(shape).uniform_(_finfo(self.scale).eps, 1)
-        # X ~ Uniform(0, 1)
-        # Y = loc - scale * ln (-ln (X)) ~ Gumbel(loc, scale)
-        return self.loc - self.scale * torch.log(-uni_dist.log())
-
-    def log_prob(self, value):
-        self._validate_log_prob_arg(value)
-        z = (value - self.loc) / self.scale
-        return -(self.scale.log() + z + torch.exp(-z))
+            base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
+        transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
+                      ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
+        super(Gumbel, self).__init__(base_dist, transforms)
 
     @property
     def mean(self):
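Note: applied left to right to u ~ Uniform(0, 1), the transform chain in the new Gumbel.__init__ computes log(u), then -log(u), then log(-log(u)), and finally loc - scale * log(-log(u)), which is exactly the inverse-CDF sampler the removed rsample implemented by hand. A scalar sanity check with hypothetical values:

    import math

    loc, scale, u = 1.0, 2.0, 0.3
    x = loc - scale * math.log(-math.log(u))
    # the Gumbel CDF exp(-exp(-(x - loc) / scale)) maps x back to u
    assert abs(math.exp(-math.exp(-(x - loc) / scale)) - u) < 1e-12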
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index dc36e9b8374956..8720cedc3584a7 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -65,6 +65,14 @@ def log_prob(self, value):
         log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
         return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
 
+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
+
     def entropy(self):
         return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
 
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index 15854b9da0636e..56dd2a7ca443d5 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -4,11 +4,13 @@
 import torch
 from torch.distributions import constraints
-from torch.distributions.distribution import Distribution
+from torch.distributions.exponential import Exponential
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
 from torch.distributions.utils import broadcast_all
 
 
-class Pareto(Distribution):
+class Pareto(TransformedDistribution):
     r"""
     Samples from a Pareto Type 1 distribution.
 
     Example::
@@ -23,16 +25,13 @@ class Pareto(Distribution):
         scale (float or Tensor or Variable): Scale parameter of the distribution
         alpha (float or Tensor or Variable): Shape parameter of the distribution
     """
-    has_rsample = True
     params = {'alpha': constraints.positive, 'scale': constraints.positive}
 
     def __init__(self, scale, alpha):
         self.scale, self.alpha = broadcast_all(scale, alpha)
-        if isinstance(scale, Number) and isinstance(alpha, Number):
-            batch_shape = torch.Size()
-        else:
-            batch_shape = self.scale.size()
-        super(Pareto, self).__init__(batch_shape)
+        base_dist = Exponential(alpha)
+        transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
+        super(Pareto, self).__init__(base_dist, transforms)
 
     @property
     def mean(self):
@@ -50,14 +49,5 @@ def variance(self):
     def support(self):
         return constraints.greater_than(self.scale)
 
-    def rsample(self, sample_shape=torch.Size()):
-        shape = self._extended_shape(sample_shape)
-        exp_dist = self.alpha.new(shape).exponential_()
-        return self.scale * torch.exp(exp_dist / self.alpha)
-
-    def log_prob(self, value):
-        self._validate_log_prob_arg(value)
-        return torch.log(self.alpha / value) + self.alpha * (self.scale / value).log()
-
     def entropy(self):
         return ((self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()))
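Note: the Pareto rewrite relies on the identity that if E ~ Exponential(alpha), then scale * exp(E) is Pareto: P(scale * exp(E) > x) = P(E > log(x / scale)) = exp(-alpha * log(x / scale)) = (scale / x) ** alpha, the Pareto Type 1 survival function, matching the removed rsample. A scalar check with hypothetical values:

    import math

    scale, alpha, x = 1.0, 3.0, 2.0
    pareto_cdf = 1 - (scale / x) ** alpha                  # 1 - 2 ** -3 = 0.875
    exp_cdf = 1 - math.exp(-alpha * math.log(x / scale))   # Exponential(alpha) CDF at log(x / scale)
    assert abs(pareto_cdf - exp_cdf) < 1e-12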
diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py
index 2db3b568c21880..2b463323d0f296 100644
--- a/torch/distributions/transformed_distribution.py
+++ b/torch/distributions/transformed_distribution.py
@@ -74,6 +74,7 @@ def log_prob(self, value):
         Scores the sample by inverting the transform(s) and computing the score using the score of
         the base distribution and the log abs det jacobian
         """
+        self.base_dist._validate_log_prob_arg(value)
         event_dim = len(self.event_shape)
         log_prob = 0.0
         y = value
@@ -85,3 +86,23 @@ def log_prob(self, value):
         log_prob += _sum_rightmost(self.base_dist.log_prob(y),
                                    event_dim - len(self.base_dist.event_shape))
         return log_prob
+
+    def cdf(self, value):
+        """
+        Computes the cumulative distribution function by inverting the transform(s) and computing
+        the CDF of the base distribution.
+        """
+        self.base_dist._validate_log_prob_arg(value)
+        for transform in self.transforms[::-1]:
+            value = transform.inv(value)
+        return self.base_dist.cdf(value)
+
+    def icdf(self, value):
+        """
+        Computes the inverse cumulative distribution function using the transform(s) and the
+        inverse CDF of the base distribution.
+        """
+        value = self.base_dist.icdf(value)
+        for transform in self.transforms:
+            value = transform(value)
+        return value
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index 1d233a3ad5d644..9750511af02ec3 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -63,5 +63,15 @@ def log_prob(self, value):
         ub = value.lt(self.high).type_as(self.low)
         return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
 
+    def cdf(self, value):
+        self._validate_log_prob_arg(value)
+        result = (value - self.low) / (self.high - self.low)
+        return result.clamp(min=0, max=1)
+
+    def icdf(self, value):
+        self._validate_log_prob_arg(value)
+        result = value * (self.high - self.low) + self.low
+        return result
+
     def entropy(self):
         return torch.log(self.high - self.low)
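Note: with TransformedDistribution.cdf/icdf in place, the rewritten Pareto and Gumbel inherit both methods for free: cdf pulls the value back through the inverted transforms and defers to the base distribution, and icdf runs the composition forward. A minimal usage sketch, assuming this branch is installed:

    import torch
    from torch.distributions import Pareto

    d = Pareto(torch.Tensor([1.0]), torch.Tensor([3.0]))   # scale=1, alpha=3
    p = d.cdf(torch.Tensor([2.0]))   # inverts the transforms, then Exponential(3).cdf(log(2)) = 0.875
    x = d.icdf(p)                    # Exponential(3).icdf(0.875) = log(2), pushed forward to 2.0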