Add Cumulative Distribution Function, Inverse CDF methods to Distributions #122
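In short, this pull request adds `cdf` (cumulative distribution function) and `icdf` (inverse CDF / quantile function) methods to `torch.distributions` and extends the test suite to cover them. A minimal usage sketch, illustrative only and assuming a distribution such as `Normal` that implements both methods:

```python
import torch
from torch.distributions import Normal

# cdf maps a value to P(X <= value); icdf maps a probability back to the quantile.
dist = Normal(torch.zeros(3), torch.ones(3))
value = dist.sample()
p = dist.cdf(value)       # cumulative probabilities in (0, 1)
recovered = dist.icdf(p)  # should match `value` up to numerical error
print(value, p, recovered)
```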
Changes from all commits
e4e58e2
d9263a6
ee55d13
f76114a
d0bc72c
864d190
a9d74ce
24c2461
2932ffe
@@ -51,7 +51,7 @@
     SigmoidTransform,
     StickBreakingTransform,
     identity_transform)
-from torch.distributions.utils import _finfo, probs_to_logits
+from torch.distributions.utils import _finfo, probs_to_logits, softmax

 TEST_NUMPY = True
 try:
@@ -183,8 +183,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': Variable(torch.Tensor([1.0, 0.0])),
-            'scale': Variable(torch.Tensor([1e-5, 1e-5])),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(LogNormal, [
@@ -197,8 +197,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': torch.Tensor([1.0, 0.0]),
-            'scale': torch.Tensor([1e-5, 1e-5]),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(Normal, [
@@ -211,8 +211,8 @@ def is_all_nan(tensor):
             'scale': Variable(torch.randn(1).abs(), requires_grad=True),
         },
         {
-            'loc': Variable(torch.Tensor([1.0, 0.0])),
-            'scale': Variable(torch.Tensor([1e-5, 1e-5])),
+            'loc': Variable(torch.Tensor([1.0, 0.0]), requires_grad=True),
+            'scale': Variable(torch.Tensor([1e-5, 1e-5]), requires_grad=True),
         },
     ]),
     Example(OneHotCategorical, [
@@ -247,17 +247,17 @@ def is_all_nan(tensor):
     Example(TransformedDistribution, [
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': [],
         },
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': ExpTransform(),
         },
         {
             'base_distribution': Normal(Variable(torch.randn(2, 3), requires_grad=True),
-                                        Variable(torch.randn(2, 3), requires_grad=True)),
+                                        Variable(torch.randn(2, 3).abs(), requires_grad=True)),
             'transforms': [AffineTransform(Variable(torch.randn(1)), Variable(torch.randn(1))),
                            ExpTransform()],
         },
@@ -942,10 +942,10 @@ def test_gamma_sample(self):
     def test_pareto(self):
         scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
         alpha = Variable(torch.randn(2, 3).abs(), requires_grad=True)
-        scale_1d = torch.randn(1).abs()
-        alpha_1d = torch.randn(1).abs()
-        self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).mean, float('inf'), allow_inf=True)
-        self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).variance, float('inf'), allow_inf=True)
+        scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+        alpha_1d = Variable(torch.randn(1).abs(), requires_grad=True)
+        self.assertEqual(Pareto(scale_1d, 0.5).mean, float('inf'), allow_inf=True)
+        self.assertEqual(Pareto(scale_1d, 0.5).variance, float('inf'), allow_inf=True)
         self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
         self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
         self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
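For context on the assertions above (a standard property of the Pareto distribution, not something introduced by this diff): with shape `alpha`, the mean `alpha * scale / (alpha - 1)` is only finite for `alpha > 1` and the variance only for `alpha > 2`, so `Pareto(scale_1d, 0.5)` is expected to report `inf` for both.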
@@ -973,8 +973,8 @@ def test_pareto_sample(self):
     def test_gumbel(self):
         loc = Variable(torch.randn(2, 3), requires_grad=True)
         scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
-        loc_1d = torch.randn(1)
-        scale_1d = torch.randn(1).abs()
+        loc_1d = Variable(torch.randn(1), requires_grad=True)
+        scale_1d = Variable(torch.randn(1).abs(), requires_grad=True)
         self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
         self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
         self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
@@ -1161,6 +1161,47 @@ def test_beta_sample(self):
             x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
             self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

+    def test_cdf_icdf_inverse(self):
+        # Tests the invertibility property on the distributions
+        for Dist, params in EXAMPLES:
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                samples = dist.sample(sample_shape=(20,))
+                try:
+                    cdf = dist.cdf(samples)
+                    actual = dist.icdf(cdf)
+                except NotImplementedError:
+                    continue
+                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
+                self.assertLess(rel_error.max(), 1e-4, msg='\n'.join([
+                    '{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)),
+                    'x = {}'.format(samples),
+                    'cdf(x) = {}'.format(cdf),
+                    'icdf(cdf(x)) = {}'.format(actual),
+                ]))
+
+    def test_cdf_log_prob(self):
+        # Tests if the differentiation of the CDF gives the PDF at a given value
+        for Dist, params in EXAMPLES:
+            for i, param in enumerate(params):
+                dist = Dist(**param)
+                samples = dist.sample()
+                if not samples.requires_grad:
+                    continue
+                try:
+                    cdfs = dist.cdf(samples)
+                    pdfs = dist.log_prob(samples).exp()
+                except NotImplementedError:
+                    continue
+                cdfs_derivative = grad(cdfs.sum(), [samples])[0]
+                self.assertEqual(cdfs_derivative, pdfs, message='\n'.join([
+                    '{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
+                    'x = {}'.format(samples),
+                    'cdf = {}'.format(cdfs),
+                    'pdf = {}'.format(pdfs),
+                    'grad(cdf) = {}'.format(cdfs_derivative),
+                ]))
+
     def test_valid_parameter_broadcasting(self):
         # Test correct broadcasting of parameter sizes for distributions that have multiple
         # parameters.
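For reference, the two new tests check standard identities rather than anything scipy-specific: `icdf(cdf(x)) == x` wherever the CDF is strictly increasing, and `d/dx cdf(x) == exp(log_prob(x))`, i.e. the derivative of the cumulative distribution function is the probability density, which is why `grad(cdfs.sum(), [samples])[0]` is compared against `log_prob(samples).exp()`.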
@@ -1741,6 +1782,16 @@ def test_pareto_shape_scalar_params(self):
         self.assertEqual(pareto.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(pareto.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

+    def test_gumbel_shape_scalar_params(self):
+        gumbel = Gumbel(1, 1)
+        self.assertEqual(gumbel._batch_shape, torch.Size())
+        self.assertEqual(gumbel._event_shape, torch.Size())
+        self.assertEqual(gumbel.sample().size(), torch.Size(SCALAR_SHAPE))
+        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, gumbel.log_prob, self.scalar_sample)
+        self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+
     def test_normal_shape_scalar_params(self):
         normal = Normal(0, 1)
         self.assertEqual(normal._batch_shape, torch.Size())
@@ -2279,7 +2330,7 @@ def setUp(self):
         positive_var2 = Variable(torch.Tensor(20,).normal_()).exp()
         random_var = Variable(torch.Tensor(20,).normal_())
         random_tensor = torch.Tensor(20,).normal_()
-        simplex_tensor = random_tensor.exp() / random_tensor.exp().sum()
+        simplex_tensor = softmax(random_tensor)
         self.distribution_pairs = [
             (
                 Bernoulli(simplex_tensor),
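An illustrative aside (not part of the diff): for a 1-D tensor, `softmax` computes the same exp-and-normalise operation the replaced line spelled out by hand, just in a numerically safer way. A quick check using the public `torch.nn.functional.softmax` (the diff itself imports the helper from `torch.distributions.utils`):

```python
import torch
import torch.nn.functional as F

# For a 1-D tensor these two expressions agree up to floating-point error.
x = torch.randn(20)
manual = x.exp() / x.exp().sum()
print(torch.allclose(manual, F.softmax(x, dim=0)))  # True
```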
@@ -2293,21 +2344,25 @@ def setUp(self):
                 Binomial(10, simplex_tensor),
                 scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor)
             ),
+            (
+                Cauchy(random_var, positive_var),
+                scipy.stats.cauchy(loc=random_var, scale=positive_var)
+            ),
             (
                 Dirichlet(positive_var),
                 scipy.stats.dirichlet(positive_var)
             ),
             (
                 Exponential(positive_var),
-                scipy.stats.expon(scale=1. / positive_var)
+                scipy.stats.expon(scale=positive_var.reciprocal())
             ),
             (
                 FisherSnedecor(positive_var, 4 + positive_var2),  # var for df2<=4 is undefined
                 scipy.stats.f(positive_var, 4 + positive_var2)
             ),
             (
                 Gamma(positive_var, positive_var2),
-                scipy.stats.gamma(positive_var, scale=1 / positive_var2)
+                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal())
             ),
             (
                 Geometric(simplex_tensor),
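For context on the scale arguments above: torch's `Exponential` takes a rate while scipy's `expon` takes a scale, with `scale = 1 / rate`, and `positive_var.reciprocal()` computes the same value as `1. / positive_var`. A small illustrative check (not part of the diff):

```python
import scipy.stats
import torch
from torch.distributions import Exponential

# Exponential mean is 1 / rate in torch and equals `scale` in scipy.
rate = torch.tensor([0.5, 2.0])
print(Exponential(rate).mean)                              # tensor([2.0000, 0.5000])
print(scipy.stats.expon(scale=1.0 / rate.numpy()).mean())  # [2.  0.5]
```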
@@ -2358,17 +2413,41 @@ def setUp(self):

     def test_mean(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
+            if isinstance(pytorch_dist, Cauchy):
+                continue
             self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)

     def test_variance_stddev(self):
         for pytorch_dist, scipy_dist in self.distribution_pairs:
+            if isinstance(pytorch_dist, Cauchy):
+                continue
             if isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                 self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
                 self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist)
             else:
                 self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
                 self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)
+
+    def test_cdf(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        for pytorch_dist, scipy_dist in self.distribution_pairs:
+            samples = pytorch_dist.sample((5,))
+            try:
Reviewer comment on this try block: It's safer to enclose as little as needed in a try-except. Could you refactor to

    try:
        cdf = pytorch_dist.cdf(samples)
    except NotImplementedError:
        continue
    self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)

Author reply: Ah, yes. I saw the discussion in TruncatedNormal. I will modify it accordingly.
+                cdf = pytorch_dist.cdf(samples)
+            except NotImplementedError:
+                continue
+            self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)
+
+    def test_icdf(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        for pytorch_dist, scipy_dist in self.distribution_pairs:
+            samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
+            try:
Reviewer comment on this try block: Ditto, enclose as little as possible in the try-except.

Author reply: Sure.
+                icdf = pytorch_dist.icdf(samples)
+            except NotImplementedError:
+                continue
+            self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)


 class TestTransforms(TestCase):
     def setUp(self):
Reviewer comment: It would be nice to have an additional test that did not rely on scipy, e.g. …, or you could get even fancier by using grad(), like …
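A minimal sketch of what such a scipy-free check might look like (this is not the reviewer's original snippet, which is not preserved above; it assumes a distribution like `Normal` that implements `cdf`/`icdf` and a reparameterized `rsample`):

```python
import torch
from torch.autograd import Variable, grad
from torch.distributions import Normal

# Invertibility and derivative checks that need no scipy reference values.
loc = Variable(torch.randn(5).double(), requires_grad=True)
scale = Variable(torch.rand(5).double() + 0.5, requires_grad=True)
dist = Normal(loc, scale)

x = dist.rsample()  # reparameterized sample, so it takes part in autograd

# icdf should invert cdf up to floating-point error.
print((dist.icdf(dist.cdf(x)) - x).abs().max())

# The derivative of the CDF w.r.t. x should equal the density exp(log_prob(x)).
pdf_from_grad = grad(dist.cdf(x).sum(), [x])[0]
print((pdf_from_grad - dist.log_prob(x).exp()).abs().max())
```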