From fec9eeb998e68aac2207894854d67c92046d036e Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 15 Sep 2024 20:41:33 +0200 Subject: [PATCH] Support more kinds of marginalization via dim analysis This commit lifts the restriction that only Elemwise operations may link marginalized to dependent RVs. We map input dims to output dims, to assess whether an operation mixes information from different dims or not. Graphs where information is not mixed can be efficiently marginalized. --- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- .../model/marginal/distributions.py | 179 ++++++--- .../model/marginal/graph_analysis.py | 364 ++++++++++++++++-- .../model/marginal/marginal_model.py | 129 +++---- requirements.txt | 2 +- tests/model/marginal/test_distributions.py | 52 ++- tests/model/marginal/test_graph_analysis.py | 172 ++++++++- tests/model/marginal/test_marginal_model.py | 240 ++++++++++-- 9 files changed, 926 insertions(+), 216 deletions(-) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 360a8199..4deda063 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -10,6 +10,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.16.1 # CI was failing to resolve + - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 360a8199..4deda063 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -10,6 +10,6 @@ dependencies: - xhistogram - statsmodels - pip: - - pymc>=5.16.1 # CI was failing to resolve + - pymc>=5.17.0 # CI was failing to resolve - blackjax - scikit-learn diff --git a/pymc_experimental/model/marginal/distributions.py b/pymc_experimental/model/marginal/distributions.py index a3a2adbe..661665e9 100644 --- a/pymc_experimental/model/marginal/distributions.py +++ b/pymc_experimental/model/marginal/distributions.py @@ -1,30 +1,55 @@ from collections.abc import Sequence import numpy as np +import pytensor.tensor as pt -from pymc import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable, logp -from pymc.logprob import conditional_logp -from pymc.logprob.abstract import _logprob +from pymc.distributions import Bernoulli, Categorical, DiscreteUniform +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold -from pytensor import Mode, clone_replace, graph_replace, scan -from pytensor import map as scan_map -from pytensor import tensor as pt -from pytensor.graph import vectorize_graph -from pytensor.tensor import TensorType, TensorVariable +from pytensor import Variable +from pytensor.compile.builders import OpFromGraph +from pytensor.compile.mode import Mode +from pytensor.graph import Op, vectorize_graph +from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.scan import map as scan_map +from pytensor.scan import scan +from pytensor.tensor import TensorVariable from pymc_experimental.distributions import DiscreteMarkovChain -class MarginalRV(SymbolicRandomVariable): +class MarginalRV(OpFromGraph, MeasurableOp): """Base class for Marginalized RVs""" + def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None: + self.dims_connections = dims_connections + super().__init__(*args, **kwargs) -class FiniteDiscreteMarginalRV(MarginalRV): - """Base class for Finite 
Discrete Marginalized RVs""" + @property + def support_axes(self) -> tuple[tuple[int]]: + """Dimensions of dependent RVs that belong to the core (non-batched) marginalized variable.""" + marginalized_ndim_supp = self.inner_outputs[0].owner.op.ndim_supp + support_axes_vars = [] + for dims_connection in self.dims_connections: + ndim = len(dims_connection) + marginalized_supp_axes = ndim - marginalized_ndim_supp + support_axes_vars.append( + tuple( + -i + for i, dim in enumerate(reversed(dims_connection), start=1) + if (dim is None or dim > marginalized_supp_axes) + ) + ) + return tuple(support_axes_vars) -class DiscreteMarginalMarkovChainRV(MarginalRV): - """Base class for Discrete Marginal Markov Chain RVs""" +class MarginalFiniteDiscreteRV(MarginalRV): + """Base class for Marginalized Finite Discrete RVs""" + + +class MarginalDiscreteMarkovChainRV(MarginalRV): + """Base class for Marginalized Discrete Markov Chain RVs""" def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: @@ -34,7 +59,8 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: return (0, 1) elif isinstance(op, Categorical): [p_param] = dist_params - return tuple(range(pt.get_vector_length(p_param))) + [p_param_length] = constant_fold([p_param.shape[-1]]) + return tuple(range(p_param_length)) elif isinstance(op, DiscreteUniform): lower, upper = constant_fold(dist_params) return tuple(np.arange(lower, upper + 1)) @@ -45,31 +71,81 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: raise NotImplementedError(f"Cannot compute domain for op {op}") -def _add_reduce_batch_dependent_logps( - marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable] -): - """Add the logps of dependent RVs while reducing extra batch dims relative to `marginalized_type`.""" +def reduce_batch_dependent_logps( + dependent_dims_connections: Sequence[tuple[int | None, ...]], + dependent_ops: Sequence[Op], + dependent_logps: Sequence[TensorVariable], +) -> TensorVariable: + """Combine the logps of dependent RVs and align them with the marginalized logp. + + This requires reducing extra batch dims and transposing when they are not aligned. + + idx = pm.Bernoulli(idx, shape=(3, 2)) # 0, 1 + pm.Normal("dep1", mu=idx.T[..., None] * 2, shape=(3, 2, 5)) + pm.Normal("dep2", mu=idx * 2, shape=(7, 2, 3)) + + marginalize(idx) + + The marginalized op will have dims_connections = [(1, 0, None), (None, 0, 1)] + which tells us we need to reduce the last axis of dep1 logp and the first of dep2 logp, + as well as transpose the remaining axis of dep1 logp before adding the two element-wise. 
+ + """ + from pymc_experimental.model.marginal.graph_analysis import get_support_axes - mbcast = marginalized_type.broadcastable reduced_logps = [] - for dependent_logp in dependent_logps: - dbcast = dependent_logp.type.broadcastable - dim_diff = len(dbcast) - len(mbcast) - mbcast_aligned = (True,) * dim_diff + mbcast - vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v] - reduced_logps.append(dependent_logp.sum(vbcast_axis)) - return pt.add(*reduced_logps) + for dependent_op, dependent_logp, dependent_dims_connection in zip( + dependent_ops, dependent_logps, dependent_dims_connections + ): + if dependent_logp.type.ndim > 0: + # Find which support axis implied by the MarginalRV need to be reduced + # Some may have already been reduced by the logp expression of the dependent RV (e.g., multivariate RVs) + dep_supp_axes = get_support_axes(dependent_op)[0] + # Dependent RV support axes are already collapsed in the logp, so we ignore them + supp_axes = [ + -i + for i, dim in enumerate(reversed(dependent_dims_connection), start=1) + if (dim is None and -i not in dep_supp_axes) + ] + dependent_logp = dependent_logp.sum(supp_axes) -@_logprob.register(FiniteDiscreteMarginalRV) -def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): - # Clone the inner RV graph of the Marginalized RV - marginalized_rvs_node = op.make_node(*inputs) - marginalized_rv, *inner_rvs = clone_replace( + # Finally, we need to align the dependent logp batch dimensions with the marginalized logp + dims_alignment = [dim for dim in dependent_dims_connection if dim is not None] + dependent_logp = dependent_logp.transpose(*dims_alignment) + + reduced_logps.append(dependent_logp) + + reduced_logp = pt.add(*reduced_logps) + return reduced_logp + + +def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> TensorVariable: + """Align the logp with the order specified in dims.""" + dims_alignment = [dim for dim in dims if dim is not None] + return logp.transpose(*dims_alignment) + + +def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: + """Inline the inner graph (outputs) of an OpFromGraph Op. + + Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" + the inner graph. 
+ """ + return clone_replace( op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, + replace=tuple(zip(op.inner_inputs, inputs)), ) + +DUMMY_ZERO = pt.constant(0, name="dummy_zero") + + +@_logprob.register(MarginalFiniteDiscreteRV) +def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inputs, **kwargs): + # Clone the inner RV graph of the Marginalized RV + marginalized_rv, *inner_rvs = inline_ofg_outputs(op, inputs) + # Obtain the joint_logp graph of the inner RV graph inner_rv_values = dict(zip(inner_rvs, values)) marginalized_vv = marginalized_rv.clone() @@ -78,8 +154,10 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): # Reduce logp dimensions corresponding to broadcasted variables marginalized_logp = logps_dict.pop(marginalized_vv) - joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( - marginalized_rv.type, logps_dict.values() + joint_logp = marginalized_logp + reduce_batch_dependent_logps( + dependent_dims_connections=op.dims_connections, + dependent_ops=[inner_rv.owner.op for inner_rv in inner_rvs], + dependent_logps=[logps_dict[value] for value in values], ) # Compute the joint_logp for all possible n values of the marginalized RV. We assume @@ -116,21 +194,20 @@ def logp_fn(marginalized_rv_const, *non_sequences): mode=Mode().including("local_remove_check_parameter"), ) - joint_logps = pt.logsumexp(joint_logps, axis=0) + joint_logp = pt.logsumexp(joint_logps, axis=0) + + # Align logp with non-collapsed batch dimensions of first RV + joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp) # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise - return joint_logps, *(pt.constant(0),) * (len(values) - 1) + dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) + return joint_logp, *dummy_logps -@_logprob.register(DiscreteMarginalMarkovChainRV) +@_logprob.register(MarginalDiscreteMarkovChainRV) def marginal_hmm_logp(op, values, *inputs, **kwargs): - marginalized_rvs_node = op.make_node(*inputs) - inner_rvs = clone_replace( - op.inner_outputs, - replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, - ) + chain_rv, *dependent_rvs = inline_ofg_outputs(op, inputs) - chain_rv, *dependent_rvs = inner_rvs P, n_steps_, init_dist_, rng = chain_rv.owner.inputs domain = pt.arange(P.shape[-1], dtype="int32") @@ -145,8 +222,10 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs): logp_emissions_dict = conditional_logp(dict(zip(dependent_rvs, values))) # Reduce and add the batch dims beyond the chain dimension - reduced_logp_emissions = _add_reduce_batch_dependent_logps( - chain_rv.type, logp_emissions_dict.values() + reduced_logp_emissions = reduce_batch_dependent_logps( + dependent_dims_connections=op.dims_connections, + dependent_ops=[dependent_rv.owner.op for dependent_rv in dependent_rvs], + dependent_logps=[logp_emissions_dict[value] for value in values], ) # Add a batch dimension for the domain of the chain @@ -185,7 +264,13 @@ def step_alpha(logp_emission, log_alpha, log_P): # Final logp is just the sum of the last scan state joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0) + # Align logp with non-collapsed batch dimensions of first RV + remaining_dims_first_emission = list(op.dims_connections[0]) + # The last dim of chain_rv was removed when computing the logp + remaining_dims_first_emission.remove(chain_rv.type.ndim - 1) + joint_logp = align_logp_dims(remaining_dims_first_emission, joint_logp) + # If there 
are multiple emission streams, we have to add dummy logps for the remaining value variables. The first - # return is the joint probability of everything together, but PyMC still expects one logp for each one. - dummy_logps = (pt.constant(0),) * (len(values) - 1) + # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream. + dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) return joint_logp, *dummy_logps diff --git a/pymc_experimental/model/marginal/graph_analysis.py b/pymc_experimental/model/marginal/graph_analysis.py index 58c11a1e..62ac2abb 100644 --- a/pymc_experimental/model/marginal/graph_analysis.py +++ b/pymc_experimental/model/marginal/graph_analysis.py @@ -1,8 +1,22 @@ +import itertools + +from collections.abc import Sequence +from itertools import zip_longest + +from pymc import SymbolicRandomVariable from pytensor.compile import SharedVariable -from pytensor.graph import Constant, FunctionGraph, ancestors -from pytensor.tensor import TensorVariable -from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.graph import Constant, Variable, ancestors +from pytensor.graph.basic import io_toposort +from pytensor.tensor import TensorType, TensorVariable +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.shape import Shape +from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list +from pytensor.tensor.type_other import NoneTypeT + +from pymc_experimental.model.marginal.distributions import MarginalRV def static_shape_ancestors(vars): @@ -48,45 +62,311 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs): ] -def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs): - # TODO: No need to consider apply nodes outside the subgraph... - fg = FunctionGraph(outputs=output_rvs, clone=False) - - non_elemwise_blockers = [ - o - for node in fg.apply_nodes - if not ( - isinstance(node.op, Elemwise) - # Allow expand_dims on the left - or ( - isinstance(node.op, DimShuffle) - and not node.op.drop - and node.op.shuffle == sorted(node.op.shuffle) - ) - ) - for o in node.outputs - ] - blocker_candidates = [rv_to_marginalize, *other_input_rvs, *non_elemwise_blockers] - blockers = [var for var in blocker_candidates if var not in output_rvs] +def get_support_axes(op) -> tuple[tuple[int, ...], ...]: + if isinstance(op, MarginalRV): + return op.support_axes + else: + # For vanilla RVs, the support axes are the last ndim_supp + return (tuple(range(-op.ndim_supp, 0)),) - truncated_inputs = [ - var - for var in ancestors(output_rvs, blockers=blockers) - if ( - var in blockers - or (var.owner is None and not isinstance(var, Constant | SharedVariable)) - ) + +def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]: + """Find the output axis and dimensionality of the advanced indexing group (i.e., array indexing). + + There is a special case: when there are non-consecutive advanced indexing groups, the advanced indexing + group is always moved to the front. 
+ + See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + """ + adv_group_axis = None + simple_group_after_adv = False + for axis, idx in enumerate(idxs): + if isinstance(idx.type, TensorType): + if simple_group_after_adv: + # Special non-consecutive case + adv_group_axis = 0 + break + elif adv_group_axis is None: + adv_group_axis = axis + elif adv_group_axis is not None: + # Special non-consecutive case + simple_group_after_adv = True + + adv_group_ndim = max(idx.type.ndim for idx in idxs if isinstance(idx.type, TensorType)) + return adv_group_axis, adv_group_ndim + + +DIMS = tuple[int | None, ...] +VAR_DIMS = dict[Variable, DIMS] + + +def _broadcast_dims( + inputs_dims: Sequence[DIMS], +) -> DIMS: + output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0) + + # Add missing dims + inputs_dims = [ + (None,) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims ] - # Check that we reach the marginalized rv following a pure elemwise graph - if rv_to_marginalize not in truncated_inputs: - return False - - # Check that none of the truncated inputs depends on the marginalized_rv - other_truncated_inputs = [inp for inp in truncated_inputs if inp is not rv_to_marginalize] - # TODO: We don't need to go all the way to the root variables - if rv_to_marginalize in ancestors( - other_truncated_inputs, blockers=[rv_to_marginalize, *other_input_rvs] - ): - return False - return True + # Find which known dims show in the output, while checking no mixing + output_dims = [] + for inputs_dim in zip(*inputs_dims): + output_dim = None + for input_dim in inputs_dim: + if input_dim is None: + continue + if output_dim is not None and output_dim != input_dim: + raise ValueError("Different known dimensions mixed via broadcasting") + output_dim = input_dim + output_dims.append(output_dim) + + # Check for duplicates + known_dims = [dim for dim in output_dims if dim is not None] + if len(known_dims) > len(set(known_dims)): + raise ValueError("Same known dimension used in different axis after broadcasting") + + return tuple(output_dims) + + +def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS: + for node in io_toposort(input_vars, output_vars): + inputs_dims = [ + var_dims.get(inp, ((None,) * inp.type.ndim) if hasattr(inp.type, "ndim") else ()) + for inp in node.inputs + ] + + if all(dim is None for input_dims in inputs_dims for dim in input_dims): + # None of the inputs are related to the batch_axes of the input_vars + continue + + elif isinstance(node.op, DimShuffle): + [input_dims] = inputs_dims + output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order) + var_dims[node.outputs[0]] = output_dims + + elif isinstance(node.op, MarginalRV) or ( + isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None + ): + # MarginalRV and SymbolicRandomVariables without signature are a wild-card, + # so we need to introspect the inner graph. 
+ op = node.op + inner_inputs = op.inner_inputs + inner_outputs = op.inner_outputs + + inner_var_dims = _subgraph_batch_dim_connection( + dict(zip(inner_inputs, inputs_dims)), inner_inputs, inner_outputs + ) + + support_axes = iter(get_support_axes(op)) + if isinstance(op, MarginalRV): + # The first output is the marginalized variable for which we don't compute support axes + support_axes = itertools.chain(((),), support_axes) + for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)): + if not isinstance(out.type, TensorType): + continue + support_axes_out = next(support_axes) + + if inner_out in inner_var_dims: + out_dims = inner_var_dims[inner_out] + if any( + dim is not None for dim in (out_dims[axis] for axis in support_axes_out) + ): + raise ValueError(f"Known dim corresponds to core dimension of {node.op}") + var_dims[out] = out_dims + + elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable): + # NOTE: User-provided CustomDist may not respect core dimensions on the left. + + if isinstance(node.op, Elemwise): + op_batch_ndim = node.outputs[0].type.ndim + else: + op_batch_ndim = node.op.batch_ndim(node) + + if isinstance(node.op, SymbolicRandomVariable): + # SymbolicRandomVariable don't have explicit expand_dims unlike the other Ops considered in this + [_, _, param_idxs], _ = node.op.get_input_output_type_idxs( + node.op.extended_signature + ) + for param_idx, param_core_ndim in zip(param_idxs, node.op.ndims_params): + param_dims = inputs_dims[param_idx] + missing_ndim = op_batch_ndim - (len(param_dims) - param_core_ndim) + inputs_dims[param_idx] = (None,) * missing_ndim + param_dims + + if any( + dim is not None for input_dim in inputs_dims for dim in input_dim[op_batch_ndim:] + ): + raise ValueError( + f"Use of known dimensions as core dimensions of op {node.op} not supported." + ) + + batch_dims = _broadcast_dims( + tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims) + ) + for out in node.outputs: + if isinstance(out.type, TensorType): + core_ndim = out.type.ndim - op_batch_ndim + output_dims = batch_dims + (None,) * core_ndim + var_dims[out] = output_dims + + elif isinstance(node.op, CAReduce): + [input_dims] = inputs_dims + + axes = node.op.axis + if isinstance(axes, int): + axes = (axes,) + elif axes is None: + axes = tuple(range(node.inputs[0].type.ndim)) + + if any(input_dims[axis] for axis in axes): + raise ValueError( + f"Use of known dimensions as reduced dimensions of op {node.op} not supported." + ) + + output_dims = [dims for i, dims in enumerate(input_dims) if i not in axes] + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, Subtensor): + value_dims, *keys_dims = inputs_dims + # Dims in basic indexing must belong to the value variable, since indexing keys are always scalar + assert not any(dim is None for dim in keys_dims) + keys = get_idx_list(node.inputs, node.op.idx_list) + + output_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if idx == slice(None): + # Dim is kept + output_dims.append(value_dim) + elif value_dim is not None: + raise ValueError( + "Partial slicing or indexing of known dimensions not supported." + ) + elif isinstance(idx, slice): + # Unknown dimensions kept by partial slice. 
+ output_dims.append(None) + + var_dims[node.outputs[0]] = tuple(output_dims) + + elif isinstance(node.op, AdvancedSubtensor): + # AdvancedSubtensor dimensions can show up as both the indexed variable and indexing variables + value, *keys = node.inputs + value_dims, *keys_dims = inputs_dims + + # Just to stay sane, we forbid any boolean indexing... + if any(isinstance(idx.type, TensorType) and idx.type.dtype == "bool" for idx in keys): + raise NotImplementedError( + f"Array indexing with boolean variables in node {node} not supported." + ) + + if any(dim is not None for dim in value_dims) and any( + dim is not None for key_dims in keys_dims for dim in key_dims + ): + # Both indexed variable and indexing variables have known dimensions + # I am to lazy to think through these, so we raise for now. + raise NotImplementedError( + f"Simultaneous use of known dimensions in indexed and indexing variables in node {node} not supported." + ) + + adv_group_axis, adv_group_ndim = _advanced_indexing_axis_and_ndim(keys) + + if any(dim is not None for dim in value_dims): + # Indexed variable has known dimensions + + if any(isinstance(idx.type, NoneTypeT) for idx in keys): + # Corresponds to an expand_dims, for now not supported + raise NotImplementedError( + f"Advanced indexing in node {node} which introduces new axis is not supported." + ) + + non_adv_dims = [] + for value_dim, idx in zip_longest(value_dims, keys, fillvalue=slice(None)): + if is_full_slice(idx): + non_adv_dims.append(value_dim) + elif value_dim is not None: + # We are trying to partially slice or index a known dimension + raise ValueError( + "Partial slicing or advanced integer indexing of known dimensions not supported." + ) + elif isinstance(idx, slice): + # Unknown dimensions kept by partial slice. + non_adv_dims.append(None) + + # Insert unknown dimensions corresponding to advanced indexing + output_dims = tuple( + non_adv_dims[:adv_group_axis] + + [None] * adv_group_ndim + + non_adv_dims[adv_group_axis:] + ) + + else: + # Indexing keys have known dimensions. + # Only array indices can have dimensions, the rest are just slices or newaxis + + # Advanced indexing variables broadcast together, so we apply same rules as in Elemwise + adv_dims = _broadcast_dims(keys_dims) + + start_non_adv_dims = (None,) * adv_group_axis + end_non_adv_dims = (None,) * ( + node.outputs[0].type.ndim - adv_group_axis - adv_group_ndim + ) + output_dims = start_non_adv_dims + adv_dims + end_non_adv_dims + + var_dims[node.outputs[0]] = output_dims + + else: + raise NotImplementedError(f"Marginalization through operation {node} not supported.") + + return var_dims + + +def subgraph_batch_dim_connection(input_var, output_vars) -> list[DIMS]: + """Identify how the batch dims of input map to the batch dimensions of the output_rvs. + + Example: + ------- + In the example below `idx` has two batch dimensions (indexed 0, 1 from left to right). + The two uncommented dependent variables each have 2 batch dimensions where each entry + results from a mapping of a single entry from one of these batch dimensions. + + This mapping is transposed in the case of the first dependent variable, and shows up in + the same order for the second dependent variable. Each of the variables as a further + batch dimension encoded as `None`. + + The commented out third dependent variable combines information from the batch dimensions + of `idx` via the `sum` operation. A `ValueError` would be raised if we requested the + connection of batch dims. + + .. 
code-block:: python + import pymc as pm + + idx = pm.Bernoulli.dist(shape=(3, 2)) + dep1 = pm.Normal.dist(mu=idx.T[..., None] * 2, shape=(3, 2, 5)) + dep2 = pm.Normal.dist(mu=idx * 2, shape=(7, 2, 3)) + # dep3 = pm.Normal.dist(mu=idx.sum()) # Would raise if requested + + print(subgraph_batch_dim_connection(idx, [], [dep1, dep2])) + # [(1, 0, None), (None, 0, 1)] + + Returns: + ------- + list of tuples + Each tuple corresponds to the batch dimensions of the output_rv in the order they are found in the output. + None is used to indicate a batch dimension that is not mapped from the input. + + Raises: + ------ + ValueError + If input batch dimensions are mixed in the graph leading to output_vars. + + NotImplementedError + If variable related to marginalized batch_dims is used in an operation that is not yet supported + """ + var_dims = {input_var: tuple(range(input_var.type.ndim))} + var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars) + ret = [] + for output_var in output_vars: + output_dims = var_dims.get(output_var, (None,) * output_var.type.ndim) + assert len(output_dims) == output_var.type.ndim + ret.append(output_dims) + return ret diff --git a/pymc_experimental/model/marginal/marginal_model.py b/pymc_experimental/model/marginal/marginal_model.py index 94c577c4..b4700c3d 100644 --- a/pymc_experimental/model/marginal/marginal_model.py +++ b/pymc_experimental/model/marginal/marginal_model.py @@ -16,8 +16,7 @@ from pymc.pytensorf import compile_pymc, constant_fold from pymc.util import RandomState, _get_seeds_per_chain, treedict from pytensor.compile import SharedVariable -from pytensor.graph import FunctionGraph, clone_replace -from pytensor.graph.basic import graph_inputs +from pytensor.graph import FunctionGraph, clone_replace, graph_inputs from pytensor.graph.replace import vectorize_graph from pytensor.tensor import TensorVariable from pytensor.tensor.special import log_softmax @@ -26,16 +25,16 @@ from pymc_experimental.distributions import DiscreteMarkovChain from pymc_experimental.model.marginal.distributions import ( - DiscreteMarginalMarkovChainRV, - FiniteDiscreteMarginalRV, - _add_reduce_batch_dependent_logps, + MarginalDiscreteMarkovChainRV, + MarginalFiniteDiscreteRV, get_domain_of_finite_discrete_rv, + reduce_batch_dependent_logps, ) from pymc_experimental.model.marginal.graph_analysis import ( find_conditional_dependent_rvs, find_conditional_input_rvs, is_conditional_dependent, - is_elemwise_subgraph, + subgraph_batch_dim_connection, ) ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str] @@ -424,17 +423,22 @@ def transform_input(inputs): m = self.clone() marginalized_rv = m.vars_to_clone[marginalized_rv] m.unmarginalize([marginalized_rv]) - dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs) - joint_logps = m.logp(vars=[marginalized_rv, *dependent_vars], sum=False) - - marginalized_value = m.rvs_to_values[marginalized_rv] - other_values = [v for v in m.value_vars if v is not marginalized_value] + dependent_rvs = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs) + logps = m.logp(vars=[marginalized_rv, *dependent_rvs], sum=False) # Handle batch dims for marginalized value and its dependent RVs - marginalized_logp, *dependent_logps = joint_logps - joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps( - marginalized_rv.type, dependent_logps + dependent_rvs_dim_connections = subgraph_batch_dim_connection( + marginalized_rv, dependent_rvs ) + marginalized_logp, *dependent_logps = 
logps + joint_logp = marginalized_logp + reduce_batch_dependent_logps( + dependent_rvs_dim_connections, + [dependent_var.owner.op for dependent_var in dependent_rvs], + dependent_logps, + ) + + marginalized_value = m.rvs_to_values[marginalized_rv] + other_values = [v for v in m.value_vars if v is not marginalized_value] rv_shape = constant_fold(tuple(marginalized_rv.shape), raise_not_constant=False) rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv) @@ -448,37 +452,30 @@ def transform_input(inputs): 0, ) - joint_logps = vectorize_graph( + batched_joint_logp = vectorize_graph( joint_logp, replace={marginalized_value: rv_domain_tensor}, ) - joint_logps = pt.moveaxis(joint_logps, 0, -1) + batched_joint_logp = pt.moveaxis(batched_joint_logp, 0, -1) - rv_loglike_fn = None - joint_logps_norm = log_softmax(joint_logps, axis=-1) + joint_logp_norm = log_softmax(batched_joint_logp, axis=-1) if return_samples: - sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps) + rv_draws = pymc.Categorical.dist(logit_p=batched_joint_logp) if isinstance(marginalized_rv.owner.op, DiscreteUniform): - sample_rv_outs += rv_domain[0] - - rv_loglike_fn = compile_pymc( - inputs=other_values, - outputs=[joint_logps_norm, sample_rv_outs], - on_unused_input="ignore", - random_seed=seed, - ) + rv_draws += rv_domain[0] + outputs = [joint_logp_norm, rv_draws] else: - rv_loglike_fn = compile_pymc( - inputs=other_values, - outputs=joint_logps_norm, - on_unused_input="ignore", - random_seed=seed, - ) + outputs = joint_logp_norm + + rv_loglike_fn = compile_pymc( + inputs=other_values, + outputs=outputs, + on_unused_input="ignore", + random_seed=seed, + ) logvs = [rv_loglike_fn(**vs) for vs in posterior_pts] - logps = None - samples = None if return_samples: logps, samples = zip(*logvs) logps = np.array(logps) @@ -552,61 +549,47 @@ def collect_shared_vars(outputs, blockers): def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs): - # TODO: This should eventually be integrated in a more general routine that can - # identify other types of supported marginalization, of which finite discrete - # RVs is just one - dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) if not dependent_rvs: raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") - ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} - if len(ndim_supp) != 1: - raise NotImplementedError( - "Marginalization with dependent variables of different support dimensionality not implemented" - ) - [ndim_supp] = ndim_supp - if ndim_supp > 0: - raise NotImplementedError("Marginalization with dependent Multivariate RVs not implemented") - marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) - dependent_rvs_input_rvs = [ + other_direct_rv_ancestors = [ rv for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) if rv is not rv_to_marginalize ] - # If the marginalized RV has batched dimensions, check that graph between - # marginalized RV and dependent RVs is composed strictly of Elemwise Operations. - # This implies (?) that the dimensions are completely independent and a logp graph - # can ultimately be generated that is proportional to the support domain and not - # to the variables dimensions - # We don't need to worry about this if the RV is scalar. 
- if np.prod(constant_fold(tuple(rv_to_marginalize.shape), raise_not_constant=False)) != 1: - if not is_elemwise_subgraph(rv_to_marginalize, dependent_rvs_input_rvs, dependent_rvs): - raise NotImplementedError( - "The subgraph between a marginalized RV and its dependents includes non Elemwise operations. " - "This is currently not supported", - ) + # If the marginalized RV has multiple dimensions, check that graph between + # marginalized RV and dependent RVs does not mix information from batch dimensions + # (otherwise logp would require enumerating over all combinations of batch dimension values) + try: + dependent_rvs_dim_connections = subgraph_batch_dim_connection( + rv_to_marginalize, dependent_rvs + ) + except (ValueError, NotImplementedError) as e: + # For the perspective of the user this is a NotImplementedError + raise NotImplementedError( + "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. " + "You can try splitting the marginalized RV into separate components and marginalizing them separately." + ) from e - input_rvs = list(set((*marginalized_rv_input_rvs, *dependent_rvs_input_rvs))) - rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs] + input_rvs = list(set((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))) + output_rvs = [rv_to_marginalize, *dependent_rvs] - outputs = rvs_to_marginalize # We are strict about shared variables in SymbolicRandomVariables - inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs) + inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs) if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain): - marginalize_constructor = DiscreteMarginalMarkovChainRV + marginalize_constructor = MarginalDiscreteMarkovChainRV else: - marginalize_constructor = FiniteDiscreteMarginalRV + marginalize_constructor = MarginalFiniteDiscreteRV marginalization_op = marginalize_constructor( inputs=inputs, - outputs=outputs, - ndim_supp=ndim_supp, + outputs=output_rvs, # TODO: Add RNG updates to outputs so this can be used in the generative graph + dims_connections=dependent_rvs_dim_connections, ) - - marginalized_rvs = marginalization_op(*inputs) - fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) - return rvs_to_marginalize, marginalized_rvs + new_output_rvs = marginalization_op(*inputs) + fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs))) + return output_rvs, new_output_rvs diff --git a/requirements.txt b/requirements.txt index a7141a82..b992ad37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pymc>=5.16.1 +pymc>=5.17.0 scikit-learn diff --git a/tests/model/marginal/test_distributions.py b/tests/model/marginal/test_distributions.py index 7c0e0fd5..ecbc8817 100644 --- a/tests/model/marginal/test_distributions.py +++ b/tests/model/marginal/test_distributions.py @@ -8,7 +8,7 @@ from pymc_experimental import MarginalModel from pymc_experimental.distributions import DiscreteMarkovChain -from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV +from pymc_experimental.model.marginal.distributions import MarginalFiniteDiscreteRV def test_marginalized_bernoulli_logp(): @@ -17,13 +17,10 @@ def test_marginalized_bernoulli_logp(): idx = pm.Bernoulli.dist(0.7, name="idx") y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y") - marginal_rv_node = FiniteDiscreteMarginalRV( + marginal_rv_node = MarginalFiniteDiscreteRV( [mu], [idx, y], - ndim_supp=0, - n_updates=0, - # Ignore the fact we didn't specify shared RNG 
input/outputs for idx,y - strict=False, + dims_connections=(((),),), )(mu)[0].owner y_vv = y.clone() @@ -78,9 +75,7 @@ def test_marginalized_hmm_categorical_emission(categorical_emission): init_dist = pm.Categorical.dist(p=[0.375, 0.625]) chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=2) if categorical_emission: - emission = pm.Categorical( - "emission", p=pt.where(pt.eq(chain, 0)[..., None], [0.8, 0.2], [0.4, 0.6]) - ) + emission = pm.Categorical("emission", p=pt.constant([[0.8, 0.2], [0.4, 0.6]])[chain]) else: emission = pm.Bernoulli("emission", p=pt.where(pt.eq(chain, 0), 0.2, 0.6)) m.marginalize([chain]) @@ -91,29 +86,46 @@ def test_marginalized_hmm_categorical_emission(categorical_emission): np.testing.assert_allclose(logp_fn({"emission": test_value}), expected_logp) +@pytest.mark.parametrize("batch_chain", (False, True)) @pytest.mark.parametrize("batch_emission1", (False, True)) @pytest.mark.parametrize("batch_emission2", (False, True)) -def test_marginalized_hmm_multiple_emissions(batch_emission1, batch_emission2): - emission1_shape = (2, 4) if batch_emission1 else (4,) - emission2_shape = (2, 4) if batch_emission2 else (4,) +def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch_emission2): + chain_shape = (3, 1, 4) if batch_chain else (4,) + emission1_shape = ( + (2, *reversed(chain_shape)) if batch_emission1 else tuple(reversed(chain_shape)) + ) + emission2_shape = (*chain_shape, 2) if batch_emission2 else chain_shape with MarginalModel() as m: P = [[0, 1], [1, 0]] init_dist = pm.Categorical.dist(p=[1, 0]) - chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=3) - emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=1e-1, shape=emission1_shape) - emission_2 = pm.Normal( - "emission_2", mu=(1 - chain) * 2 - 1, sigma=1e-1, shape=emission2_shape + chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, shape=chain_shape) + emission_1 = pm.Normal( + "emission_1", mu=(chain * 2 - 1).T, sigma=1e-1, shape=emission1_shape ) + emission2_mu = (1 - chain) * 2 - 1 + if batch_emission2: + emission2_mu = emission2_mu[..., None] + emission_2 = pm.Normal("emission_2", mu=emission2_mu, sigma=1e-1, shape=emission2_shape) + with pytest.warns(UserWarning, match="multiple dependent variables"): m.marginalize([chain]) - logp_fn = m.compile_logp() + logp_fn = m.compile_logp(sum=False) test_value = np.array([-1, 1, -1, 1]) multiplier = 2 + batch_emission1 + batch_emission2 + if batch_chain: + multiplier *= 3 expected_logp = norm.logpdf(np.zeros_like(test_value), 0, 1e-1).sum() * multiplier - test_value_emission1 = np.broadcast_to(test_value, emission1_shape) - test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) + + test_value = np.broadcast_to(test_value, chain_shape) + test_value_emission1 = np.broadcast_to(test_value.T, emission1_shape) + if batch_emission2: + test_value_emission2 = np.broadcast_to(-test_value[..., None], emission2_shape) + else: + test_value_emission2 = np.broadcast_to(-test_value, emission2_shape) test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2} - np.testing.assert_allclose(logp_fn(test_point), expected_logp) + res_logp, dummy_logp = logp_fn(test_point) + assert res_logp.shape == ((1, 3) if batch_chain else ()) + np.testing.assert_allclose(res_logp.sum(), expected_logp) diff --git a/tests/model/marginal/test_graph_analysis.py b/tests/model/marginal/test_graph_analysis.py index 58d65dbe..2382247b 100644 --- a/tests/model/marginal/test_graph_analysis.py 
+++ b/tests/model/marginal/test_graph_analysis.py @@ -1,6 +1,13 @@ -from pytensor import tensor as pt +import pytensor.tensor as pt +import pytest -from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent +from pymc.distributions import CustomDist +from pytensor.tensor.type_other import NoneTypeT + +from pymc_experimental.model.marginal.graph_analysis import ( + is_conditional_dependent, + subgraph_batch_dim_connection, +) def test_is_conditional_dependent_static_shape(): @@ -12,3 +19,164 @@ def test_is_conditional_dependent_static_shape(): x2 = pt.matrix("x2", shape=(9, 5)) y2 = pt.random.normal(size=pt.shape(x2)) assert not is_conditional_dependent(y2, x2, [x2, y2]) + + +class TestSubgraphBatchDimConnection: + def test_dimshuffle(self): + inp = pt.tensor(shape=(5, 1, 4, 3)) + out1 = pt.matrix_transpose(inp) + out2 = pt.expand_dims(inp, 1) + out3 = pt.squeeze(inp) + [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3]) + assert dims1 == (0, 1, 3, 2) + assert dims2 == (0, None, 1, 2, 3) + assert dims3 == (0, 2, 3) + + def test_careduce(self): + inp = pt.tensor(shape=(4, 3, 2)) + + out = pt.sum(inp[:, None], axis=(1,)) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2) + + invalid_out = pt.sum(inp, axis=(1,)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + def test_subtensor(self): + inp = pt.tensor(shape=(4, 3, 2)) + + invalid_out = inp[0, :1] + with pytest.raises( + ValueError, + match="Partial slicing or indexing of known dimensions not supported", + ): + subgraph_batch_dim_connection(inp, [invalid_out]) + + # If we are selecting dummy / unknown dimensions that's fine + valid_out = pt.expand_dims(inp, (0, 1))[0, :1] + [dims] = subgraph_batch_dim_connection(inp, [valid_out]) + assert dims == (None, 0, 1, 2) + + def test_advanced_subtensor_value(self): + inp = pt.tensor(shape=(2, 4)) + intermediate_out = inp[:, None, :, None] + pt.zeros((2, 3, 4, 5)) + + # Index on an unlabled dim introduced by broadcasting with zeros + out = intermediate_out[:, [0, 0, 1, 2]] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, None, 1, None) + + # Indexing that introduces more dimensions + out = intermediate_out[:, [[0, 0], [1, 2]], :] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, None, None, 1, None) + + # Special case where advanced dims are moved to the front of the output + out = intermediate_out[:, [0, 0, 1, 2], :, 0] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (None, 0, 1) + + # Indexing on a labeled dim fails + out = intermediate_out[:, :, [0, 0, 1, 2]] + with pytest.raises(ValueError, match="Partial slicing or advanced integer indexing"): + subgraph_batch_dim_connection(inp, [out]) + + def test_advanced_subtensor_key(self): + inp = pt.tensor(shape=(5, 5), dtype=int) + base = pt.zeros((2, 3, 4)) + + out = base[inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None, None) + + out = base[:, :, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == ( + None, + None, + 0, + 1, + ) + + out = base[1:, 0, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (None, 0, 1) + + # Special case where advanced dims are moved to the front of the output + out = base[0, :, inp] + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None) + + # Mix keys dimensions + out = base[:, inp, inp.T] + 
with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"): + subgraph_batch_dim_connection(inp, [out]) + + def test_elemwise(self): + inp = pt.tensor(shape=(5, 5)) + + out = inp + inp + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1) + + out = inp + inp.T + with pytest.raises(ValueError, match="Different known dimensions mixed via broadcasting"): + subgraph_batch_dim_connection(inp, [out]) + + out = inp[None, :, None, :] + inp[:, None, :, None] + with pytest.raises( + ValueError, match="Same known dimension used in different axis after broadcasting" + ): + subgraph_batch_dim_connection(inp, [out]) + + def test_blockwise(self): + inp = pt.tensor(shape=(5, 4)) + + invalid_out = inp @ pt.ones((4, 3)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3)) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, None, None) + + def test_random_variable(self): + inp = pt.tensor(shape=(5, 4, 3)) + + out1 = pt.random.normal(loc=inp) + out2 = pt.random.categorical(p=inp[..., None]) + out3 = pt.random.multivariate_normal(mean=inp[..., None], cov=pt.eye(1)) + [dims1, dims2, dims3] = subgraph_batch_dim_connection(inp, [out1, out2, out3]) + assert dims1 == (0, 1, 2) + assert dims2 == (0, 1, 2) + assert dims3 == (0, 1, 2, None) + + invalid_out = pt.random.categorical(p=inp) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + invalid_out = pt.random.multivariate_normal(mean=inp, cov=pt.eye(3)) + with pytest.raises(ValueError, match="Use of known dimensions"): + subgraph_batch_dim_connection(inp, [invalid_out]) + + def test_symbolic_random_variable(self): + inp = pt.tensor(shape=(4, 3, 2)) + + # Test univariate + out = CustomDist.dist( + inp, + dist=lambda mu, size: pt.random.normal(loc=mu, size=size), + ) + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2) + + # Test multivariate + def dist(mu, size): + if isinstance(size.type, NoneTypeT): + size = mu.shape + return pt.random.normal(loc=mu[..., None], size=(*size, 2)) + + out = CustomDist.dist(inp, dist=dist, size=(4, 3, 2), signature="()->(2)") + [dims] = subgraph_batch_dim_connection(inp, [out]) + assert dims == (0, 1, 2, None) diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py index a94499cf..c93cdb74 100644 --- a/tests/model/marginal/test_marginal_model.py +++ b/tests/model/marginal/test_marginal_model.py @@ -10,6 +10,7 @@ from arviz import InferenceData, dict_to_dataset from pymc.distributions import transforms +from pymc.distributions.transforms import ordered from pymc.model.fgraph import fgraph_from_model from pymc.pytensorf import inputvars from pymc.util import UNSET @@ -117,6 +118,36 @@ def test_one_to_many_marginalized_rvs(): np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y) +def test_one_to_many_unaligned_marginalized_rvs(): + """Test that marginalization works when there is more than one dependent RV with batch dimensions that are not aligned""" + + def build_model(build_batched: bool): + with MarginalModel() as m: + if build_batched: + idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2)) + else: + idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)] + idx = pt.stack(idxs, axis=0).reshape((3, 2)) + + x = pm.Normal("x", mu=idx.T[:, :, 
None], shape=(2, 3, 1)) + y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2)) + + return m + + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(["idx"]) + ref_m.marginalize([f"idx_{i}" for i in range(6)]) + + test_point = m.initial_point() + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + def test_many_to_one_marginalized_rvs(): """Test when random variables depend on multiple marginalized variables""" with MarginalModel() as m: @@ -132,40 +163,127 @@ def test_many_to_one_marginalized_rvs(): np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3) -def test_nested_marginalized_rvs(): +@pytest.mark.parametrize("batched", (False, "left", "right")) +def test_nested_marginalized_rvs(batched): """Test that marginalization works when there are nested marginalized RVs""" - with MarginalModel() as m: - sigma = pm.HalfNormal("sigma") + def build_model(build_batched: bool) -> MarginalModel: + idx_shape = (3,) if build_batched else () + sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5) - idx = pm.Bernoulli("idx", p=0.75) - dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma) + with MarginalModel() as m: + sigma = pm.HalfNormal("sigma") - sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95), shape=(5,)) - sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma, shape=(5,)) + idx = pm.Bernoulli("idx", p=0.75, shape=idx_shape) + dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma) - ref_logp_fn = m.compile_logp(vars=[idx, dep, sub_idx, sub_dep]) + sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95) + if build_batched and batched == "right": + sub_idx_p = sub_idx_p[..., None] + dep = dep[..., None] + sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape) + sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma) - with pytest.warns(UserWarning, match="There are multiple dependent variables"): - m.marginalize([idx, sub_idx]) + return m - assert set(m.marginalized_rvs) == {idx, sub_idx} + m = build_model(build_batched=batched) + with pytest.warns(UserWarning, match="There are multiple dependent variables"): + m.marginalize(["idx", "sub_idx"]) + assert sorted(m.name for m in m.marginalized_rvs) == ["idx", "sub_idx"] # Test logp + ref_m = build_model(build_batched=False) + ref_logp_fn = ref_m.compile_logp( + vars=[ref_m["idx"], ref_m["dep"], ref_m["sub_idx"], ref_m["sub_dep"]] + ) + + test_point = ref_m.initial_point() + test_point["dep"] = np.full_like(test_point["dep"], 1000) + test_point["sub_dep"] = np.full_like(test_point["sub_dep"], 1000 + 100) + ref_logp = logsumexp( + [ + ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}}) + for idx in (0, 1) + for sub_idxs in itertools.product((0, 1), repeat=5) + ] + ) + if batched: + ref_logp *= 3 + test_point = m.initial_point() - test_point["dep"] = 1000 - test_point["sub_dep"] = np.full((5,), 1000 + 100) - - ref_logp = [ - ref_logp_fn({**test_point, **{"idx": idx, "sub_idx": np.array(sub_idxs)}}) - for idx in (0, 1) - for sub_idxs in itertools.product((0, 1), repeat=5) - ] - logp = m.compile_logp(vars=[dep, sub_dep])(test_point) - - np.testing.assert_almost_equal( - logp, - logsumexp(ref_logp), + test_point["dep"] = np.full_like(test_point["dep"], 1000) + test_point["sub_dep"] = 
np.full_like(test_point["sub_dep"], 1000 + 100) + logp = m.compile_logp(vars=[m["dep"], m["sub_dep"]])(test_point) + + np.testing.assert_almost_equal(logp, ref_logp) + + +@pytest.mark.parametrize("advanced_indexing", (False, True)) +def test_marginalized_index_as_key(advanced_indexing): + """Test we can marginalize graphs where indexing is used as a mapping.""" + + w = [0.1, 0.3, 0.6] + mu = pt.as_tensor([-1, 0, 1]) + + if advanced_indexing: + y_val = pt.as_tensor([[-1, -1], [0, 1]]) + shape = (2, 2) + else: + y_val = -1 + shape = () + + with MarginalModel() as m: + x = pm.Categorical("x", p=w, shape=shape) + y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val) + + m.marginalize(x) + + marginal_logp = m.compile_logp(sum=False)({})[0] + ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval() + + np.testing.assert_allclose(marginal_logp, ref_logp) + + +def test_marginalized_index_as_value_and_key(): + """Test we can marginalize graphs were marginalized_rv is indexed.""" + + def build_model(build_batched: bool) -> MarginalModel: + with MarginalModel() as m: + if build_batched: + latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,)) + else: + latent_state = pm.math.stack( + [pm.Bernoulli(f"latent_state_{i}", p=0.3) for i in range(4)] + ) + # latent state is used as the indexed variable + latent_intensities = pt.where(latent_state[:, None], [0.0, 1.0, 2.0], [0.0, 10.0, 20.0]) + picked_intensity = pm.Categorical("picked_intensity", p=[0.2, 0.2, 0.6]) + # picked intensity is used as the indexing variable + pm.Normal( + "intensity", + mu=latent_intensities[:, picked_intensity], + observed=[0.5, 1.5, 5.0, 15.0], + ) + return m + + # We compare with the equivalent but less efficient batched model + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + m.marginalize(["latent_state"]) + ref_m.marginalize([f"latent_state_{i}" for i in range(4)]) + test_point = {"picked_intensity": 1} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + + m.marginalize(["picked_intensity"]) + ref_m.marginalize(["picked_intensity"]) + test_point = {} + np.testing.assert_allclose( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), ) @@ -229,6 +347,15 @@ def test_mixed_dims_via_support_dimension(self): with pytest.raises(NotImplementedError): m.marginalize(x) + def test_mixed_dims_via_nested_marginalization(self): + with MarginalModel() as m: + x = pm.Bernoulli("x", p=0.7, shape=(3,)) + y = pm.Bernoulli("y", p=0.7, shape=(2,)) + z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2)) + + with pytest.raises(NotImplementedError): + m.marginalize([x, y]) + def test_marginalized_deterministic_and_potential(): rng = np.random.default_rng(299) @@ -531,6 +658,62 @@ def dist(idx, size): pt = {"norm": test_value} np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt)) + def test_k_censored_clusters_model(self): + def build_model(build_batched: bool) -> MarginalModel: + data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]) + nobs = data.shape[0] + n_clusters = 5 + coords = { + "cluster": range(n_clusters), + "ndim": ("x", "y"), + "obs": range(nobs), + } + with MarginalModel(coords=coords) as m: + if build_batched: + idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"]) + else: + idx = pm.math.stack( + [ + pm.Categorical(f"idx_{i}", p=np.ones(n_clusters) / n_clusters) + for i in range(nobs) + ] + ) + + mu_x = pm.Normal( + "mu_x", + dims=["cluster"], + 
transform=ordered, + initval=np.linspace(-1, 1, n_clusters), + ) + mu_y = pm.Normal("mu_y", dims=["cluster"]) + mu = pm.math.stack([mu_x, mu_y], axis=-1) # (cluster, ndim) + mu_indexed = mu[idx, :] + + sigma = pm.HalfNormal("sigma") + + y = pm.Censored( + "y", + dist=pm.Normal.dist(mu_indexed, sigma), + lower=-3, + upper=3, + observed=data, + dims=["obs", "ndim"], + ) + + return m + + m = build_model(build_batched=True) + ref_m = build_model(build_batched=False) + + m.marginalize([m["idx"]]) + ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")]) + + test_point = m.initial_point() + np.testing.assert_almost_equal( + m.compile_logp()(test_point), + ref_m.compile_logp()(test_point), + ) + class TestRecoverMarginals: def test_basic(self): @@ -608,7 +791,7 @@ def test_batched(self): with MarginalModel() as m: sigma = pm.HalfNormal("sigma") idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2)) - y = pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2)) + y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3)) m.marginalize([idx]) @@ -626,10 +809,9 @@ def test_batched(self): idata = m.recover_marginals(idata, return_samples=True) post = idata.posterior - assert "idx" in post - assert "lp_idx" in post - assert post.idx.shape == post.y.shape - assert post.lp_idx.shape == (*post.idx.shape, 2) + assert post["y"].shape == (1, 20, 2, 3) + assert post["idx"].shape == (1, 20, 3, 2) + assert post["lp_idx"].shape == (1, 20, 3, 2, 2) def test_nested(self): """Test that marginalization works when there are nested marginalized RVs"""
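
Illustrative usage notes follow; these are sketches based on the new docstrings and tests, and the module paths, shapes, and settings in them are assumptions rather than verbatim excerpts from the diff.

The dim analysis described in the commit message can be queried directly through `subgraph_batch_dim_connection`; the snippet below is a minimal sketch assuming the module path introduced in this diff, with shapes chosen so the result matches the `[(1, 0, None), (None, 0, 1)]` example from its docstring.

.. code-block:: python

    import pymc as pm

    from pymc_experimental.model.marginal.graph_analysis import subgraph_batch_dim_connection

    # idx has two batch dimensions, labelled 0 and 1 from left to right.
    idx = pm.Bernoulli.dist(p=0.5, shape=(3, 2))

    # dep1 sees those dimensions transposed, plus an unrelated trailing dimension;
    # dep2 sees them in the original order behind an unrelated leading dimension.
    dep1 = pm.Normal.dist(mu=idx.T[..., None], shape=(2, 3, 5))
    dep2 = pm.Normal.dist(mu=idx * 2, shape=(7, 3, 2))

    print(subgraph_batch_dim_connection(idx, [dep1, dep2]))
    # [(1, 0, None), (None, 0, 1)]

    # Operations that mix batch dimensions are rejected, because the joint logp
    # would no longer factorize over them.
    mixed = pm.Normal.dist(mu=idx.sum())
    try:
        subgraph_batch_dim_connection(idx, [mixed])
    except ValueError as err:
        print(err)  # known dimensions used as reduced dimensions are not supported

Because no entry of `idx` influences more than one position of each dependent variable, the logp can be marginalized one batch element at a time, which is what `reduce_batch_dependent_logps` and `align_logp_dims` rely on.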
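
Marginalization now also works when the marginalized RV reaches its dependents through operations such as advanced indexing or transposition, as long as no batch dimensions are mixed. Below is a sketch adapted from the new `test_marginalized_index_as_key` test; treating each observation as a `NormalMixture` is the expected equivalence for this particular graph, and the observed values are placeholders.

.. code-block:: python

    import numpy as np
    import pymc as pm
    import pytensor.tensor as pt

    from pymc_experimental import MarginalModel

    w = [0.1, 0.3, 0.6]
    mu = pt.as_tensor([-1.0, 0.0, 1.0])
    y_obs = np.zeros((2, 2))

    with MarginalModel() as m:
        # idx picks a component mean via advanced indexing, followed by a transpose:
        # neither operation is Elemwise, but no batch dimensions are mixed.
        idx = pm.Categorical("idx", p=w, shape=(2, 2))
        pm.Normal("y", mu=mu[idx].T, sigma=1.0, observed=y_obs)

    m.marginalize([idx])

    # Each observation is marginally a three-component normal mixture.
    marginal_logp = m.compile_logp(sum=False)({})[0]
    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu, sigma=1.0, shape=(2, 2)), y_obs).eval()
    np.testing.assert_allclose(marginal_logp, ref_logp)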
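
`recover_marginals` follows the same dim connections, so the recovered draws and per-draw log-probabilities keep the original batch shape of the marginalized RV even when the dependent RV sees those dimensions in a different order. The sketch below follows the updated `test_batched` test; the sampler settings are placeholders.

.. code-block:: python

    import pymc as pm

    from pymc_experimental import MarginalModel

    with MarginalModel() as m:
        sigma = pm.HalfNormal("sigma")
        idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
        # y sees idx's batch dimensions transposed.
        y = pm.Normal("y", mu=idx.T, sigma=sigma, shape=(2, 3))

    m.marginalize([idx])

    with m:
        idata = pm.sample(draws=20, tune=100, chains=1, random_seed=0)

    idata = m.recover_marginals(idata, return_samples=True)
    print(idata.posterior["y"].shape)       # (1, 20, 2, 3)
    print(idata.posterior["idx"].shape)     # (1, 20, 3, 2), the original batch shape of idx
    print(idata.posterior["lp_idx"].shape)  # (1, 20, 3, 2, 2), with the domain of idx appended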