Improve Categorical and Multinomial checks
- Fixes a bug that occurred when constructing probability parameters from lists or tuples of Aesara variables
- Improves logp checks for valid probability parameters
ricardoV94 committed Nov 7, 2022
1 parent 14aa3d0 commit 7608e30
Showing 4 changed files with 61 additions and 34 deletions.
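
In short: `dist` now converts `p` with `at.as_tensor_variable` up front. If the result is a `TensorConstant`, the negative-value and normalization checks run eagerly on its data; otherwise validation is deferred to the `check_parameters` assertions in `logp`. A minimal sketch of the two paths (my example, not code from the diff):

    import aesara.tensor as at
    import pymc as pm

    # Constant p: a plain list folds into a TensorConstant, so checks run eagerly.
    pm.Categorical.dist(p=[1, 1, 1, 1, 1])  # warns and rescales so that p sums to 1
    # pm.Categorical.dist(p=[-1, 1, 1])     # raises ValueError immediately

    # Symbolic p: a list of Aesara variables used to hit the old
    # isinstance(p, np.ndarray) or isinstance(p, list) branch and break;
    # it now builds a symbolic tensor, checked when logp is evaluated.
    x = at.scalar("x")
    dist = pm.Categorical.dist(p=[x, x, x])
    pm.logp(dist, 0).eval({x: 1 / 3})    # valid
    # pm.logp(dist, 0).eval({x: -1 / 3}) # raises ParameterValueError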
27 changes: 18 additions & 9 deletions pymc/distributions/discrete.py
@@ -16,6 +16,7 @@
 import aesara.tensor as at
 import numpy as np
 
+from aesara.tensor import TensorConstant
 from aesara.tensor.random.basic import (
     RandomVariable,
     ScipyRandomVariable,
@@ -1285,17 +1286,21 @@ def dist(cls, p=None, logit_p=None, **kwargs):
         if logit_p is not None:
             p = pm.math.softmax(logit_p, axis=-1)
 
-        if isinstance(p, np.ndarray) or isinstance(p, list):
-            if (np.asarray(p) < 0).any():
-                raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
-            p_sum = np.sum([p], axis=-1)
-            if not np.all(np.isclose(p_sum, 1.0)):
+        p = at.as_tensor_variable(p)
+        if isinstance(p, TensorConstant):
+            p_ = np.asarray(p.data)
+            if np.any(p_ < 0):
+                raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
+            p_sum_ = np.sum([p_], axis=-1)
+            if not np.all(np.isclose(p_sum_, 1.0)):
                 warnings.warn(
-                    f"`p` parameters sum to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
+                    f"`p` parameters sum to {p_sum_}, instead of 1.0. "
+                    "They will be automatically rescaled. "
+                    "You can rescale them directly to get rid of this warning.",
                     UserWarning,
                 )
-                p = p / at.sum(p, axis=-1, keepdims=True)
-        p = at.as_tensor_variable(floatX(p))
+                p_ = p_ / at.sum(p_, axis=-1, keepdims=True)
+                p = at.as_tensor_variable(p_)
         return super().dist([p], **kwargs)
 
     def moment(rv, size, p):
@@ -1341,7 +1346,11 @@ def logp(value, p):
         )
 
         return check_parameters(
-            res, at.all(p_ >= 0, axis=-1), at.all(p <= 1, axis=-1), msg="0 <= p <=1"
+            res,
+            p_ >= 0,
+            p_ <= 1,
+            at.isclose(at.sum(p, axis=-1), 1),
+            msg="0 <= p <= 1, sum(p) = 1",
         )


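A note on why the eager branch keys on `TensorConstant`: `at.as_tensor_variable` folds plain Python or NumPy input into a `TensorConstant`, whose values can be read off via `.data`, while anything built from symbolic inputs stays a general `TensorVariable` with nothing to inspect at construction time. A quick illustration (my sketch, not part of the commit):

    import aesara.tensor as at
    import numpy as np
    from aesara.tensor import TensorConstant

    isinstance(at.as_tensor_variable([0.3, 0.7]), TensorConstant)      # True: eager checks apply
    isinstance(at.as_tensor_variable(np.ones(3) / 3), TensorConstant)  # True
    x = at.scalar("x")
    isinstance(at.as_tensor_variable([x, x]), TensorConstant)          # False: deferred to logp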
24 changes: 15 additions & 9 deletions pymc/distributions/multivariate.py
@@ -30,7 +30,7 @@
 from aesara.graph.op import Op
 from aesara.raise_op import Assert
 from aesara.sparse.basic import sp_sum
-from aesara.tensor import gammaln, sigmoid
+from aesara.tensor import TensorConstant, gammaln, sigmoid
 from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
 from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
 from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
@@ -543,16 +543,21 @@ class Multinomial(Discrete):
 
     @classmethod
     def dist(cls, n, p, *args, **kwargs):
-        if isinstance(p, np.ndarray) or isinstance(p, list):
-            if (np.asarray(p) < 0).any():
-                raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
-            p_sum = np.sum([p], axis=-1)
-            if not np.all(np.isclose(p_sum, 1.0)):
+        p = at.as_tensor_variable(p)
+        if isinstance(p, TensorConstant):
+            p_ = np.asarray(p.data)
+            if np.any(p_ < 0):
+                raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
+            p_sum_ = np.sum([p_], axis=-1)
+            if not np.all(np.isclose(p_sum_, 1.0)):
                 warnings.warn(
-                    f"`p` parameters sum up to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
+                    f"`p` parameters sum to {p_sum_}, instead of 1.0. "
+                    "They will be automatically rescaled. "
+                    "You can rescale them directly to get rid of this warning.",
                     UserWarning,
                 )
-                p = p / at.sum(p, axis=-1, keepdims=True)
+                p_ = p_ / at.sum(p_, axis=-1, keepdims=True)
+                p = at.as_tensor_variable(p_)
         n = at.as_tensor_variable(n)
         p = at.as_tensor_variable(p)
         return super().dist([n, p], *args, **kwargs)
@@ -591,10 +596,11 @@ def logp(value, n, p):
         )
         return check_parameters(
             res,
+            p >= 0,
             p <= 1,
             at.isclose(at.sum(p, axis=-1), 1),
             at.ge(n, 0),
-            msg="p <= 1, sum(p) = 1, n >= 0",
+            msg="0 <= p <= 1, sum(p) = 1, n >= 0",
         )


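The practical effect of the new `at.isclose(at.sum(p, axis=-1), 1)` and `p >= 0` terms: a symbolic `p` that bypasses the eager checks is still rejected when the logp graph is evaluated, which is what the updated tests below exercise. A hedged usage sketch:

    import aesara.tensor as at
    import numpy as np
    import pymc as pm

    x = at.scalar("x")
    dist = pm.Multinomial.dist(n=3, p=(x, x, x))
    value = np.array([1, 1, 1])

    pm.logp(dist, value).eval({x: 1 / 3})     # fine: p sums to 1
    # pm.logp(dist, value).eval({x: 0.5})     # ParameterValueError: sum(p) = 1.5
    # pm.logp(dist, value).eval({x: -1 / 3})  # ParameterValueError: negative p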
23 changes: 15 additions & 8 deletions pymc/tests/distributions/test_discrete.py
@@ -497,32 +497,39 @@ def test_categorical_bounds(self):
             # entries if there is a single or pair number of negative values
             # and the rest are zero
             np.array([-1, -1, 0, 0]),
+            at.as_tensor_variable([-1, -1, 0, 0]),
         ],
     )
     def test_categorical_negative_p(self, p):
-        with pytest.raises(ValueError, match=f"{p}"):
+        with pytest.raises(ValueError, match="Negative `p` parameters are not valid"):
             with pm.Model():
                 x = pm.Categorical("x", p=p)
 
     def test_categorical_p_not_normalized(self):
         # test UserWarning is raised for p vals that sum to more than 1
         # and normalization is triggered
-        with pytest.warns(UserWarning, match="[5]"):
+        with pytest.warns(UserWarning, match="They will be automatically rescaled"):
             with pm.Model() as m:
                 x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
         assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)
 
     def test_categorical_negative_p_symbolic(self):
+        value = np.array([[1, 1, 1]])
+
+        x = at.scalar("x")
+        invalid_dist = pm.Categorical.dist(p=[x, x, x])
+
         with pytest.raises(ParameterValueError):
-            value = np.array([[1, 1, 1]])
-            invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([-1, 0.5, 0.5]))
-            pm.logp(invalid_dist, value).eval()
+            pm.logp(invalid_dist, value).eval({x: -1 / 3})
 
     def test_categorical_p_not_normalized_symbolic(self):
+        value = np.array([[1, 1, 1]])
+
+        x = at.scalar("x")
+        invalid_dist = pm.Categorical.dist(p=(x, x, x))
+
         with pytest.raises(ParameterValueError):
-            value = np.array([[1, 1, 1]])
-            invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2]))
-            pm.logp(invalid_dist, value).eval()
+            pm.logp(invalid_dist, value).eval({x: 0.5})
 
     @pytest.mark.parametrize("n", [2, 3, 4])
     def test_orderedlogistic(self, n):
21 changes: 13 additions & 8 deletions pymc/tests/distributions/test_multivariate.py
@@ -548,14 +548,14 @@ def test_multinomial_invalid_value(self):
 
     def test_multinomial_negative_p(self):
         # test passing a list/numpy with negative p raises an immediate error
-        with pytest.raises(ValueError, match="[-1, 1, 1]"):
+        with pytest.raises(ValueError, match="Negative `p` parameters are not valid"):
             with pm.Model() as model:
                 x = pm.Multinomial("x", n=5, p=[-1, 1, 1])
 
     def test_multinomial_p_not_normalized(self):
         # test UserWarning is raised for p vals that sum to more than 1
         # and normalization is triggered
-        with pytest.warns(UserWarning, match="[5]"):
+        with pytest.warns(UserWarning, match="They will be automatically rescaled"):
             with pm.Model() as m:
                 x = pm.Multinomial("x", n=5, p=[1, 1, 1, 1, 1])
             # test stored p-vals have been normalised
@@ -564,18 +564,23 @@ def test_multinomial_p_not_normalized(self):
     def test_multinomial_negative_p_symbolic(self):
         # Passing symbolic negative p does not raise an immediate error, but evaluating
         # logp raises a ParameterValueError
+        value = np.array([[1, 1, 1]])
+
+        x = at.scalar("x")
+        invalid_dist = pm.Multinomial.dist(n=1, p=[x, x, x])
+
         with pytest.raises(ParameterValueError):
-            value = np.array([[1, 1, 1]])
-            invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1, 0.5, 0.5]))
-            pm.logp(invalid_dist, value).eval()
+            pm.logp(invalid_dist, value).eval({x: -1 / 3})
 
     def test_multinomial_p_not_normalized_symbolic(self):
         # Passing symbolic p that does not add up to one does not raise any warning, but
         # evaluating logp raises a ParameterValueError
+        value = np.array([[1, 1, 1]])
+
+        x = at.scalar("x")
+        invalid_dist = pm.Multinomial.dist(n=1, p=(x, x, x))
         with pytest.raises(ParameterValueError):
-            value = np.array([[1, 1, 1]])
-            invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([1, 0.5, 0.5]))
-            pm.logp(invalid_dist, value).eval()
+            pm.logp(invalid_dist, value).eval({x: 0.5})
 
     @pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
     @pytest.mark.parametrize(
