Skip to content

Commit

Permalink
Add option to prune variables after do intervention
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 2, 2023
1 parent 5d937b2 commit eda61e7
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
35 changes: 35 additions & 0 deletions pymc_experimental/model_transform/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pymc import Model
from pytensor.graph import ancestors

from pymc_experimental.utils.model_fgraph import (
ModelObservedRV,
ModelVar,
fgraph_from_model,
model_from_fgraph,
)


def prune_vars_detached_from_observed(model: Model) -> Model:
"""Prune model variables that are not related to any observed variable in the Model."""

# Potentials are ambiguous as whether they correspond to likelihood or prior terms,
# We simply raise for now
if model.potentials:
raise NotImplementedError("Pruning not implemented for models with Potentials")

fgraph, _ = fgraph_from_model(model, inlined_views=True)
observed_vars = (
out
for node in fgraph.apply_nodes
if isinstance(node.op, ModelObservedRV)
for out in node.outputs
)
ancestor_nodes = {var.owner for var in ancestors(observed_vars)}
nodes_to_remove = {
node
for node in fgraph.apply_nodes
if isinstance(node.op, ModelVar) and node not in ancestor_nodes
}
for node_to_remove in nodes_to_remove:
fgraph.remove_node(node_to_remove)
return model_from_fgraph(fgraph)
13 changes: 11 additions & 2 deletions pymc_experimental/model_transform/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pymc.pytensorf import _replace_vars_in_graphs
from pytensor.tensor import TensorVariable

from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
from pymc_experimental.utils.model_fgraph import (
ModelDeterministic,
ModelFreeRV,
Expand Down Expand Up @@ -113,7 +114,9 @@ def replacement_fn(var, inner_replacements):
return replaced_graphs


def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any]) -> Model:
def do(
model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any], prune_vars=False
) -> Model:
"""Replace model variables by intervention variables.
Intervention variables will either show up as `Data` or `Deterministics` in the new model,
Expand All @@ -126,6 +129,9 @@ def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], A
Dictionary that maps model variables (or names) to intervention expressions.
Intervention expressions must have a shape and data type that is compatible
with the original model variable.
prune_vars: bool, defaults to False
Whether to prune model variables that are not connected to any observed variables,
after the interventions.
Returns
-------
Expand Down Expand Up @@ -196,4 +202,7 @@ def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], A
# Replace variables by interventions
toposort_replace(fgraph, tuple(replacements.items()))

return model_from_fgraph(fgraph)
model = model_from_fgraph(fgraph)
if prune_vars:
return prune_vars_detached_from_observed(model)
return model
27 changes: 27 additions & 0 deletions pymc_experimental/tests/model_transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,30 @@ def test_do_dims():
},
)
assert do_m.named_vars_to_dims["y"] == ["test_dim"]


@pytest.mark.parametrize("prune", (False, True))
def test_do_prune(prune):

with pm.Model() as m:
x0 = pm.ConstantData("x0", 0)
x1 = pm.ConstantData("x1", 0)
y = pm.Normal("y")
y_det = pm.Deterministic("y_det", y + x0)
z = pm.Normal("z", y_det)
llike = pm.Normal("llike", z + x1, observed=0)

orig_named_vars = {"x0", "x1", "y", "y_det", "z", "llike"}
assert set(m.named_vars) == orig_named_vars

do_m = do(m, {y_det: x0 + 5}, prune_vars=prune)
if prune:
assert set(do_m.named_vars) == {"x0", "x1", "y_det", "z", "llike"}
else:
assert set(do_m.named_vars) == orig_named_vars

do_m = do(m, {z: 0.5}, prune_vars=prune)
if prune:
assert set(do_m.named_vars) == {"x1", "z", "llike"}
else:
assert set(do_m.named_vars) == orig_named_vars

0 comments on commit eda61e7

Please sign in to comment.