.WIP
ricardoV94 committed Sep 18, 2024
1 parent 6169cf3 commit 01ed4c0
Showing 4 changed files with 117 additions and 110 deletions.
21 changes: 8 additions & 13 deletions pymc_experimental/model/marginal/distributions.py
@@ -1,21 +1,18 @@
from typing import Sequence
from collections.abc import Sequence

import numpy as np
import pytensor.tensor as pt
from pymc.distributions import (
Bernoulli,
Categorical,
DiscreteUniform,
SymbolicRandomVariable
)
from pymc.logprob.basic import conditional_logp, logp

from pymc.distributions import Bernoulli, Categorical, DiscreteUniform, SymbolicRandomVariable
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import scan, map as scan_map
from pytensor.compile.mode import Mode
from pytensor.graph import vectorize_graph
from pytensor.tensor import TensorVariable, TensorType
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorType, TensorVariable

from pymc_experimental.distributions import DiscreteMarkovChain

@@ -80,8 +77,6 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
inner_rv_values = dict(zip(inner_rvs, values))
marginalized_vv = marginalized_rv.clone()
rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
print("")
print("Inner conditional logp call >> ")
logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

# Reduce logp dimensions corresponding to broadcasted variables
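The call above relies on conditional_logp, which takes a mapping from each random variable to its value variable and returns one logp term per value variable. A minimal sketch of that contract, assuming the current pymc API and two hypothetical dependent RVs:

import pymc as pm
from pymc.logprob.basic import conditional_logp

# Hypothetical pair of dependent RVs: y depends on x
x = pm.Normal.dist()
y = pm.Normal.dist(mu=x)

# Value variables standing in for each RV, as marginalized_vv does above
x_value, y_value = x.clone(), y.clone()

# Dict keyed by the value variables; logps[y_value] is the graph of logp(y | x)
# with x replaced by x_value
logps = conditional_logp({x: x_value, y: y_value})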
88 changes: 57 additions & 31 deletions pymc_experimental/model/marginal/graph_analysis.py
@@ -1,19 +1,21 @@
from itertools import zip_longest, chain
from typing import Sequence
from collections.abc import Sequence
from itertools import chain, zip_longest

from pymc import SymbolicRandomVariable
from pytensor.compile import SharedVariable
from pytensor.graph import ancestors, Constant, graph_inputs, Variable
from pytensor.graph import Constant, Variable, ancestors, graph_inputs
from pytensor.graph.basic import io_toposort
from pytensor.tensor import TensorVariable, TensorType
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise, CAReduce
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import Subtensor, get_idx_list, AdvancedSubtensor
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list
from pytensor.tensor.type_other import NoneTypeT

from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV


def static_shape_ancestors(vars):
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
@@ -58,7 +60,6 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs):
]



def collect_shared_vars(outputs, blockers):
return [
inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)
@@ -86,32 +87,24 @@ def _advanced_indexing_axis_and_ndim(idxs) -> tuple[int, int]:
return adv_group_axis, adv_group_ndim


def _broadcast_dims(inputs_dims: Sequence[tuple[tuple[int, ...], ...]]) -> tuple[tuple[int, ...], ...]:
def _broadcast_dims(
inputs_dims: Sequence[tuple[tuple[int, ...], ...]],
) -> tuple[tuple[int, ...], ...]:
output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
# Add missing dims
inputs_dims = [
((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims
]
inputs_dims = [((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims]
# Combine aligned dims
output_dims = tuple(tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims))
output_dims = tuple(
tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
)
return output_dims


def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[tuple[int, ...], ...]]:
"""Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.
Raises
------
NotImplementedError
If variable related to marginalized batch_dims is used in an operation that is not yet supported
"""
VAR_DIMS = dict[Variable, tuple[tuple[int, ...], ...]]

var_dims: dict[Variable, tuple[tuple[int, ...], ...]] = {
input_var: tuple((i,) for i in range(input_var.type.ndim))
}

for node in io_toposort([input_var, *other_inputs], output_vars):
def _subgraph_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) -> VAR_DIMS:
for node in io_toposort(input_vars, output_vars):
inputs_dims = [var_dims.get(inp, ()) for inp in node.inputs]

if not any(inputs_dims):
@@ -126,6 +119,20 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
)
var_dims[node.outputs[0]] = output_dims

elif isinstance(node.op, FiniteDiscreteMarginalRV):
# FiniteDiscreteMarginalRV does not behave like a standard SymbolicRandomVariable, due to how we truncate the graph.
# We analyze the inner graph of the Marginalized RV to find the true dim connections
inner_var_dims = {
inner_inp: input_dims
for inner_inp, input_dims in zip(node.op.inner_inputs, inputs_dims)
}
inner_var_dims = _subgraph_dim_connection(
inner_var_dims, node.op.inner_inputs, node.op.inner_outputs
)
for out, inner_out in zip(node.outputs, node.op.inner_outputs):
if inner_out in inner_var_dims:
var_dims[out] = inner_var_dims[inner_out]

elif isinstance(node.op, Elemwise | Blockwise | RandomVariable | SymbolicRandomVariable):
# NOTE: User-provided CustomDist may not respect core dimensions on the left.

@@ -135,13 +142,16 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
op_batch_ndim = node.op.batch_ndim(node)

# Collapse all core_dims
core_dims = tuple(sorted(chain.from_iterable([i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]])))
batch_dims = _broadcast_dims(
tuple(
input_dims[:op_batch_ndim]
for input_dims in inputs_dims
core_dims = tuple(
sorted(
chain.from_iterable(
[i for input_dim in inputs_dims for i in input_dim[op_batch_ndim:]]
)
)
)
batch_dims = _broadcast_dims(
tuple(input_dims[:op_batch_ndim] for input_dims in inputs_dims)
)
# Add batch dims to each output_dims
batch_dims = tuple(batch_dim + core_dims for batch_dim in batch_dims)
for out in node.outputs:
@@ -221,7 +231,7 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
elif value_dim:
# We are trying to partially slice or index a known dimension
raise NotImplementedError(
f"Partial slicing or advanced integer indexing of known dimensions not supported"
"Partial slicing or advanced integer indexing of known dimensions not supported"
)
elif isinstance(idx, slice):
# Unknown dimensions kept by partial slice.
@@ -252,4 +262,20 @@ def subgraph_dim_connection(input_var, other_inputs, output_vars) -> list[tuple[
else:
raise NotImplementedError(f"Marginalization through operation {node} not supported")

return var_dims


def subgraph_dim_connection(
input_var, other_inputs, output_vars
) -> list[tuple[tuple[int, ...], ...]]:
"""Identify how the dims of rv_to_marginalize are consumed by the dims of the output_rvs.
Raises
------
NotImplementedError
If a variable related to the marginalized batch_dims is used in an operation that is not yet supported.
"""
var_dims = {input_var: tuple((i,) for i in range(input_var.type.ndim))}
var_dims = _subgraph_dim_connection(var_dims, [input_var, *other_inputs], output_vars)
return [var_dims[output_rv] for output_rv in output_vars]
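For intuition on the dim bookkeeping in this file, the sketch below restates _broadcast_dims as shown above and exercises it on toy dim tuples (the dims are illustrative, not taken from a real graph):

from itertools import chain

def broadcast_dims_sketch(inputs_dims):
    # Right-align the inputs by left-padding shorter ones with empty dim tuples,
    # then merge the known marginalized dims at each aligned position.
    output_ndim = max((len(input_dim) for input_dim in inputs_dims), default=0)
    inputs_dims = [((),) * (output_ndim - len(input_dim)) + input_dim for input_dim in inputs_dims]
    return tuple(
        tuple(sorted(set(chain.from_iterable(inputs_dim)))) for inputs_dim in zip(*inputs_dims)
    )

# A 2d input carrying marginalized dims 0 and 1, broadcast with a 1d input carrying dim 1:
# the shorter input is padded on the left and the aligned dims are merged.
assert broadcast_dims_sketch([((0,), (1,)), ((1,),)]) == ((0,), (1,))
# With no inputs the output has no dims at all.
assert broadcast_dims_sketch([]) == ()

# subgraph_dim_connection returns one such tuple of dim tuples per output RV; for example
# ((0,), ()) would mean the output's first dim aligns with marginalized dim 0 and its
# second dim is an extra batch dim, which is the layout the alignment check in
# marginal_model.py expects.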
86 changes: 34 additions & 52 deletions pymc_experimental/model/marginal/marginal_model.py
@@ -13,21 +13,29 @@
from pymc.distributions.transforms import Chain
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import compile_pymc, constant_fold, toposort_replace
from pymc.pytensorf import compile_pymc, constant_fold
from pymc.util import RandomState, _get_seeds_per_chain, treedict
from pytensor.graph import FunctionGraph, clone_replace
from pytensor.graph.basic import truncated_graph_inputs, Constant, ancestors
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor import TensorVariable, extract_constant
from pytensor.tensor import TensorVariable
from pytensor.tensor.special import log_softmax

__all__ = ["MarginalModel", "marginalize"]

from pymc_experimental.distributions import DiscreteMarkovChain
from pymc_experimental.model.marginal.distributions import FiniteDiscreteMarginalRV, DiscreteMarginalMarkovChainRV, \
get_domain_of_finite_discrete_rv, _add_reduce_batch_dependent_logps
from pymc_experimental.model.marginal.graph_analysis import find_conditional_input_rvs, is_conditional_dependent, \
find_conditional_dependent_rvs, subgraph_dim_connection, collect_shared_vars
from pymc_experimental.model.marginal.distributions import (
DiscreteMarginalMarkovChainRV,
FiniteDiscreteMarginalRV,
_add_reduce_batch_dependent_logps,
get_domain_of_finite_discrete_rv,
)
from pymc_experimental.model.marginal.graph_analysis import (
collect_shared_vars,
find_conditional_dependent_rvs,
find_conditional_input_rvs,
is_conditional_dependent,
subgraph_dim_connection,
)

ModelRVs = TensorVariable | Sequence[TensorVariable] | str | Sequence[str]

@@ -537,10 +545,6 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:


def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
# TODO: This should eventually be integrated in a more general routine that can
# identify other types of supported marginalization, of which finite discrete
# RVs is just one

dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
if not dependent_rvs:
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")
@@ -552,7 +556,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
if rv is not rv_to_marginalize
]

if all (rv_to_marginalize.type.broadcastable):
if rv_to_marginalize.type.ndim == 0:
ndim_supp = max([dependent_rv.type.ndim for dependent_rv in dependent_rvs])
else:
# If the marginalized RV has multiple dimensions, check that graph between
@@ -561,23 +565,27 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
dependent_rvs_dim_connections = subgraph_dim_connection(
rv_to_marginalize, other_direct_rv_ancestors, dependent_rvs
)
# dependent_rvs_dim_connections = subgraph_dim_connection(
# rv_to_marginalize, other_inputs, dependent_rvs
# )

ndim_supp = max((dependent_rv.type.ndim - rv_to_marginalize.type.ndim) for dependent_rv in dependent_rvs)
ndim_supp = max(
(dependent_rv.type.ndim - rv_to_marginalize.type.ndim) for dependent_rv in dependent_rvs
)

if any(len(dim) > 1 for rv_dim_connections in dependent_rvs_dim_connections for dim in rv_dim_connections):
if any(
len(dim) > 1
for rv_dim_connections in dependent_rvs_dim_connections
for dim in rv_dim_connections
):
raise NotImplementedError("Multiple dimensions are mixed")

# We further check that:
# 1) Dimensions of dependent RVs are aligned with those of the marginalized RV
# 2) Any extra batch dimensions of dependent RVs beyond those implied by the MarginalizedRV
# show up on the right, so that collapsing logic in logp can be more straightforward.
# This also ensures the MarginalizedRV still behaves as an RV itself
marginal_batch_ndim = rv_to_marginalize.owner.op.batch_ndim(rv_to_marginalize.owner)
marginal_batch_dims = tuple((i,) for i in range(marginal_batch_ndim))
for dependent_rv, dependent_rv_batch_dims in zip(dependent_rvs, dependent_rvs_dim_connections):
for dependent_rv, dependent_rv_batch_dims in zip(
dependent_rvs, dependent_rvs_dim_connections
):
extra_batch_ndim = dependent_rv.type.ndim - marginal_batch_ndim
valid_dependent_batch_dims = marginal_batch_dims + (((),) * extra_batch_ndim)
if dependent_rv_batch_dims != valid_dependent_batch_dims:
@@ -587,47 +595,21 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
)

input_rvs = [*marginalized_rv_input_rvs, *other_direct_rv_ancestors]
rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]
output_rvs = [rv_to_marginalize, *dependent_rvs]

outputs = rvs_to_marginalize
# We are strict about shared variables in SymbolicRandomVariables
inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)
# inputs = [
# inp
# for rv in rvs_to_marginalize # should be toposort
# for inp in rv.owner.inputs
# if not(all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
# ]
# inputs = [
# inp for inp in truncated_graph_inputs(outputs, ancestors_to_include=inputs)
# if not (all(isinstance(a, Constant) for a in ancestors([inp], blockers=all_rvs)))
# ]
# inputs = truncated_graph_inputs(outputs, ancestors_to_include=[
# # inp
# # for output in outputs
# # for inp in output.owner.inputs
# # ])
# inputs = [inp for inp in inputs if not isinstance(constant_fold([inp], raise_not_constant=False)[0], Constant | np.ndarray)]
inputs = input_rvs + collect_shared_vars(output_rvs, blockers=input_rvs)

if isinstance(rv_to_marginalize.owner.op, DiscreteMarkovChain):
marginalize_constructor = DiscreteMarginalMarkovChainRV
else:
marginalize_constructor = FiniteDiscreteMarginalRV

marginalization_op = marginalize_constructor(
inputs=inputs,
outputs=outputs,
outputs=output_rvs, # TODO: Add RNG updates to outputs
ndim_supp=ndim_supp,
)

marginalized_rvs = marginalization_op(*inputs)
print()
import pytensor
pytensor.dprint(marginalized_rvs, print_type=True)
fgraph.replace_all(reversed(tuple(zip(rvs_to_marginalize, marginalized_rvs))))
# assert 0
# fgraph.dprint()
# assert 0
# toposort_replace(fgraph, tuple(zip(rvs_to_marginalize, marginalized_rvs)))
# assert 0
return rvs_to_marginalize, marginalized_rvs

new_output_rvs = marginalization_op(*inputs)
fgraph.replace_all(tuple(zip(output_rvs, new_output_rvs)))
return output_rvs, new_output_rvs
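replace_finite_discrete_marginal_subgraph is ultimately driven by the public marginalize entry point defined earlier in this file. A minimal usage sketch with a toy two-component model; the model is hypothetical and the import path is the one used on this branch:

import numpy as np
import pymc as pm

from pymc_experimental.model.marginal.marginal_model import marginalize

with pm.Model() as model:
    # Finite discrete RV to be marginalized out
    idx = pm.Bernoulli("idx", p=0.7)
    # Dependent RV: its mean switches on the value of idx
    mu = pm.math.switch(idx, 1.0, -1.0)
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=np.array([-0.9, 1.1, 0.8]))

# Returns a MarginalModel where idx is folded into a FiniteDiscreteMarginalRV, so the
# logp of y sums over idx in {0, 1} and sampling no longer needs a discrete sampler.
marginal_model = marginalize(model, ["idx"])
with marginal_model:
    idata = pm.sample()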
