Remove special logprob case for MaxNeg (used for Min logprob)
ricardoV94 committed Aug 19, 2024
1 parent f09e1b4 commit e5899b4
Showing 2 changed files with 57 additions and 181 deletions.
193 changes: 43 additions & 150 deletions pymc/logprob/order.py
@@ -33,28 +33,25 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from typing import cast

import pytensor.tensor as pt

from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableOp,
MeasurableOpMixin,
_logcdf_helper,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import find_negated_var
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold

@@ -73,25 +70,41 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableMax):
return None # pragma: no cover
if isinstance(node.op, MeasurableMax | MeasurableMaxDiscrete):
return None

base_var = cast(TensorVariable, node.inputs[0])
[base_var] = node.inputs

if base_var.owner is None:
return None

if not rv_map_feature.request_measurable(node.inputs):
return None

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
# We allow Max of RandomVariables or Elemwise of univariate RandomVariables
if isinstance(base_var.owner.op, MeasurableElemwise):
latent_base_vars = [
var
for var in base_var.owner.inputs
if (var.owner and isinstance(var.owner.op, MeasurableOp))
]
if len(latent_base_vars) != 1:
return None

[latent_base_var] = latent_base_vars
else:
latent_base_var = base_var

latent_op = latent_base_var.owner.op
if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_var.owner.op.dist_params(base_var.owner):
if not all(params.type.broadcastable):
return None
if not all(
all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
):
return None

base_var = cast(TensorVariable, base_var)

if node.op.axis is None:
axis = tuple(range(base_var.ndim))
@@ -102,16 +115,11 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
return None

# distinguish measurable discrete and continuous (because logprob is different)
measurable_max: Max
if base_var.type.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(axis)
else:
measurable_max = MeasurableMax(axis)

max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

return max_rv
measurable_max_class = (
MeasurableMaxDiscrete if latent_base_var.type.dtype.startswith("int") else MeasurableMax
)
max_rv = cast(TensorVariable, measurable_max_class(axis)(base_var))
return [max_rv]
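
The all(params.type.broadcastable) screen above only accepts scalar parameters, i.e. i.i.d. draws. A minimal sketch of what the rewrite accepts and declines (assuming only the public PyMC API; not part of this diff):

import pymc as pm
import pytensor.tensor as pt

iid_max = pt.max(pm.Normal.dist(0, 1, size=(5,)))         # scalar params: measurable
non_iid_max = pt.max(pm.Normal.dist([0, 1, 2, 3, 4], 1))  # vector mu: rewrite declines
pm.logp(iid_max, 0.0)  # returns a logprob graph
# pm.logp(non_iid_max, 0.0) raises RuntimeError: "Logprob method not implemented",
# as exercised by test_non_iid_fails in the test diff below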


measurable_ir_rewrites_db.register(
@@ -127,13 +135,13 @@ def max_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation."""
(value,) = values

logprob = _logprob_helper(base_rv, value)
logcdf = _logcdf_helper(base_rv, value)
base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
bcast_value = pt.broadcast_to(value, base_rv_shape)
logprob = _logprob_helper(base_rv, bcast_value)[0]
logcdf = _logcdf_helper(base_rv, bcast_value)[0]

[n] = constant_fold([base_rv.size])
logprob = (n - 1) * logcdf + logprob + pt.math.log(n)

return logprob
n = pt.prod(base_rv_shape)
return (n - 1) * logcdf + logprob + pt.math.log(n)
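
For n i.i.d. draws the continuous branch implements ln f_(n)(x) = ln(n) + (n - 1) ln F(x) + ln f(x). A quick numerical check against SciPy (an illustrative sketch, not part of the commit):

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from scipy import stats

n, value = 5, 0.7
x = pm.Normal.dist(0, 1, size=(n,))
max_logp = pm.logp(pt.max(x), value).eval()
# Closed form for the maximum of n i.i.d. standard normals
expected = np.log(n) + (n - 1) * stats.norm.logcdf(value) + stats.norm.logpdf(value)
np.testing.assert_allclose(max_logp, expected)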


@_logprob.register(MeasurableMaxDiscrete)
@@ -146,126 +154,11 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""
(value,) = values
logcdf = _logcdf_helper(base_rv, value)
logcdf_prev = _logcdf_helper(base_rv, value - 1)

[n] = constant_fold([base_rv.size])

logprob = logdiffexp(n * logcdf, n * logcdf_prev)

return logprob


class MeasurableMaxNeg(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
This shows up in the graph of min, which is neg(max(neg(x)))."""
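
With this placeholder removed, a min graph reaches find_measurable_max through its generic MeasurableElemwise branch instead. A hedged check that the negation route recovers the usual min density f_min(x) = n (1 - F(x))^(n - 1) f(x) (illustrative, not part of the diff):

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from scipy import stats

n, value = 3, 0.1
x = pm.Normal.dist(0, 1, size=(n,))
min_logp = pm.logp(pt.min(x), value).eval()  # pt.min builds neg(max(neg(x)))
expected = np.log(n) + (n - 1) * stats.norm.logsf(value) + stats.norm.logpdf(value)
np.testing.assert_allclose(min_logp, expected)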


class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""


@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableMaxNeg):
return None # pragma: no cover

base_var = cast(TensorVariable, node.inputs[0])

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
return None

base_rv = find_negated_var(base_var)

# negation is rv * (-1). Hence the scalar_op must be Mul
if base_rv is None:
return None

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_rv.owner.op.dist_params(base_rv.owner):
if not all(params.type.broadcastable):
return None

if node.op.axis is None:
axis = tuple(range(base_var.ndim))
else:
# Check whether axis is supported or not
axis = tuple(sorted(node.op.axis))
if axis != tuple(range(base_var.ndim)):
return None

if not rv_map_feature.request_measurable([base_rv]):
return None

# distinguish measurable discrete and continuous (because logprob is different)
measurable_min: Max
if base_rv.type.dtype.startswith("int"):
measurable_min = MeasurableDiscreteMaxNeg(axis)
else:
measurable_min = MeasurableMaxNeg(axis)

return measurable_min.make_node(base_rv).outputs


measurable_ir_rewrites_db.register(
"find_measurable_max_neg",
find_measurable_max_neg,
"basic",
"min",
)


@_logprob.register(MeasurableMaxNeg)
def max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
"""
(value,) = values

logprob = _logprob_helper(base_rv, -value)
logcdf = _logcdf_helper(base_rv, -value)

[n] = constant_fold([base_rv.size])
logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)

return logprob
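
The (n - 1) ln(1 - F(x)) term reflects that the remaining n - 1 draws must all exceed the minimum. A Monte Carlo sanity check of the underlying identity P(min <= x) = 1 - (1 - F(x))^n, assuming SciPy only:

import numpy as np
from scipy import stats

n, value = 4, -0.3
rng = np.random.default_rng(0)
# Empirical CDF of the minimum of n i.i.d. standard normals
samples = stats.norm.rvs(size=(200_000, n), random_state=rng).min(axis=1)
empirical = (samples <= value).mean()
closed_form = 1 - stats.norm.sf(value) ** n
np.testing.assert_allclose(empirical, closed_form, atol=1e-2)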


@_logprob.register(MeasurableDiscreteMaxNeg)
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
.. math::
\ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""

(value,) = values

# The cdf of a negative variable is the survival at the negated value
logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))

[n] = constant_fold([base_rv.size])

# Now we can use the same expression as the discrete max
logprob = pt.where(
pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
-pt.inf,
logdiffexp(n * logcdf_prev, n * logcdf),
)
base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
bcast_value = pt.broadcast_to(value, base_rv_shape)
logcdf = _logcdf_helper(base_rv, bcast_value)[0]
logcdf_prev = _logcdf_helper(base_rv, bcast_value - 1)[0]

return logprob
n = pt.prod(base_rv_shape)
return logdiffexp(n * logcdf, n * logcdf_prev)
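
The discrete branch keeps ln P_(n)(x) = ln(F(x)^n - F(x - 1)^n), now computed via logdiffexp on values broadcast to the base shape. A hedged check against SciPy's Poisson, mirroring test_max_discrete in the test diff below (not part of the commit):

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from scipy import stats

mu, n, value = 2, 3, 1
x = pm.Poisson.dist(mu=mu, size=(n,))
max_logp = pm.logp(pt.max(x), value).eval()
cdf = stats.poisson(mu).cdf
# P(max = x) = F(x)^n - F(x - 1)^n for n i.i.d. discrete draws
expected = np.log(cdf(value) ** n - cdf(value - 1) ** n)
np.testing.assert_allclose(max_logp, expected)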
45 changes: 14 additions & 31 deletions tests/logprob/test_order.py
@@ -53,11 +53,11 @@ def test_argmax():
"""Test whether the logprob for ```pt.argmax``` is correctly rejected"""
x = pt.random.normal(0, 1, size=(3,))
x.name = "x"
x_max = pt.argmax(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_argmax = pt.argmax(x, axis=-1)
x_max_value = pt.scalar("x_max_value", dtype=x_argmax.type.dtype)

with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented for Argmax")):
x_max_logprob = logp(x_max, x_max_value)
logp(x_argmax, x_max_value)


@pytest.mark.parametrize(
@@ -72,26 +72,9 @@ def test_non_iid_fails(pt_op):
x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
x.name = "x"
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
x_m_value = pt.scalar("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)


@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_rv_fails(pt_op):
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
x = pt.exp(pt.random.beta(0, 1, size=(3,)))
x.name = "x"
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)
logp(x_m, x_m_value)


@pytest.mark.parametrize(
@@ -107,9 +90,9 @@ def test_multivariate_rv_fails(pt_op):
x = pm.StickBreakingWeights.dist(_alpha, _k)
x.name = "x"
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
x_m_value = pt.scalar("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)
logp(x_m, x_m_value)


@pytest.mark.parametrize(
@@ -124,9 +107,9 @@ def test_categorical(pt_op):
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
x.name = "x"
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
x_m_value = pt.scalar("x_value", dtype=x.type.dtype)
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)
logp(x_m, x_m_value)


@pytest.mark.parametrize(
@@ -230,19 +213,19 @@ def test_min_non_mul_elemwise_fails():
x = pt.log(pt.random.beta(0, 1, size=(3,)))
x.name = "x"
x_min = pt.min(x, axis=-1)
x_min_value = pt.vector("x_min_value")
x_min_value = pt.scalar("x_min_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_min_logprob = logp(x_min, x_min_value)
logp(x_min, x_min_value)


@pytest.mark.parametrize(
"mu, size, value, axis",
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_max_discrete(mu, size, value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
x = pm.Poisson.dist(name="x", mu=mu, size=size)
x_max = pt.max(x, axis=axis)
x_max_value = pt.scalar("x_max_value")
x_max_value = pt.scalar("x_max_value", dtype=x.type.dtype)
x_max_logprob = logp(x_max, x_max_value)

test_value = value
@@ -265,7 +248,7 @@ def test_min_discrete(mu, n, test_value, axis):
def test_min_discrete(mu, n, test_value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
x_min = pt.min(x, axis=axis)
x_min_value = pt.scalar("x_min_value")
x_min_value = pt.scalar("x_min_value", dtype=x.type.dtype)
x_min_logprob = logp(x_min, x_min_value)

sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)
