-
-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add option to prune variables after do intervention
- Loading branch information
1 parent
5d937b2
commit eda61e7
Showing
3 changed files
with
73 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from pymc import Model | ||
from pytensor.graph import ancestors | ||
|
||
from pymc_experimental.utils.model_fgraph import ( | ||
ModelObservedRV, | ||
ModelVar, | ||
fgraph_from_model, | ||
model_from_fgraph, | ||
) | ||
|
||
|
||
def prune_vars_detached_from_observed(model: Model) -> Model: | ||
"""Prune model variables that are not related to any observed variable in the Model.""" | ||
|
||
# Potentials are ambiguous as whether they correspond to likelihood or prior terms, | ||
# We simply raise for now | ||
if model.potentials: | ||
raise NotImplementedError("Pruning not implemented for models with Potentials") | ||
|
||
fgraph, _ = fgraph_from_model(model, inlined_views=True) | ||
observed_vars = ( | ||
out | ||
for node in fgraph.apply_nodes | ||
if isinstance(node.op, ModelObservedRV) | ||
for out in node.outputs | ||
) | ||
ancestor_nodes = {var.owner for var in ancestors(observed_vars)} | ||
nodes_to_remove = { | ||
node | ||
for node in fgraph.apply_nodes | ||
if isinstance(node.op, ModelVar) and node not in ancestor_nodes | ||
} | ||
for node_to_remove in nodes_to_remove: | ||
fgraph.remove_node(node_to_remove) | ||
return model_from_fgraph(fgraph) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters