Fix copying of shared variables in fgraph_from_model (#7153)
* Do not use deprecated ScalarSharedVariable

* Recreate SharedVariables with exact type in fgraph_from_model
ricardoV94 authored Feb 13, 2024
1 parent ff99e3b commit 0d8ddba
Showing 4 changed files with 55 additions and 30 deletions.
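For context (an illustrative sketch, not part of the commit): the failure mode being fixed is that re-creating a shared variable from its value, e.g. shared(old.get_value()), can return a variable whose type differs from the original's, which breaks the type-checked replacements in fgraph_from_model. The sketch below assumes pytensor's shared() accepts the shape= keyword used in the tests further down and does not infer static shapes from the value:

    import numpy as np
    from pytensor import shared

    sigma = shared(np.array([1.0]), shape=(1,), name="sigma")

    # Round-tripping through the value loses the static shape, so the types no longer match
    sigma_roundtrip = shared(sigma.get_value(borrow=False), name="sigma")
    assert sigma_roundtrip.type != sigma.type  # shape=(None,) vs. shape=(1,)

The new deepcopy_shared_variable helper in pymc/model/fgraph.py avoids this by re-instantiating the variable with its original type and a deep copy of its container.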
6 changes: 3 additions & 3 deletions pymc/model/core.py
@@ -44,7 +44,6 @@
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
-from pytensor.tensor.sharedvar import ScalarSharedVariable
from pytensor.tensor.variable import TensorConstant, TensorVariable
from typing_extensions import Self

@@ -999,6 +998,7 @@ def add_coord(
length = pytensor.shared(length, name=name)
else:
length = pytensor.tensor.constant(length)
+assert length.type.ndim == 0
self._dim_lengths[name] = length
self._coords[name] = values

@@ -1028,7 +1028,7 @@ def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] =
coord_values
Optional sequence of coordinate values.
"""
-if not isinstance(self.dim_lengths[name], ScalarSharedVariable):
+if not isinstance(self.dim_lengths[name], SharedVariable):
raise ValueError(f"The dimension '{name}' is immutable.")
if coord_values is None and self.coords.get(name, None) is not None:
raise ValueError(
@@ -1188,7 +1188,7 @@ def set_data(
actual=new_length,
expected=old_length,
)
-if isinstance(length_tensor, ScalarSharedVariable):
+if isinstance(length_tensor, SharedVariable):
# The dimension is mutable, but was defined without being linked
# to a shared variable. This is allowed, but a little less robust.
self.set_dim(dname, new_length, coord_values=new_coords)
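Illustration (not part of the commit): the isinstance checks above gate whether a dimension can be resized, and they now accept any SharedVariable rather than only the deprecated ScalarSharedVariable. A sketch of the user-facing path that exercises them, modelled on the tests in this commit and assuming the MutableData/coords_mutable API of this PyMC release:

    import numpy as np
    import pymc as pm

    with pm.Model(coords_mutable={"trial": range(3)}) as m:
        x = pm.MutableData("x", np.zeros(3), dims="trial")
        y = pm.Normal("y", mu=x, sigma=1.0, dims="trial")

    # Resizing goes through Model.set_data -> Model.set_dim, which now only
    # requires the dim length to be some SharedVariable
    with m:
        pm.set_data({"x": np.zeros(5)}, coords={"trial": range(5)})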
33 changes: 18 additions & 15 deletions pymc/model/fgraph.py
@@ -11,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from copy import copy
+from copy import copy, deepcopy
from typing import Optional

import pytensor

-from pytensor import Variable, shared
+from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.scalar import Identity
from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.sharedvar import ScalarSharedVariable

from pymc.logprob.transforms import Transform
from pymc.model.core import Model
@@ -113,6 +112,21 @@ def local_remove_identity(fgraph, node):
remove_identity_rewrite = out2in(local_remove_identity)


+def deepcopy_shared_variable(var: SharedVariable) -> SharedVariable:
+# Shared variables don't have a deepcopy method (SharedVariable.clone reuses the old container and contents).
+# We recreate Shared Variables manually after deepcopying their container.
+new_var = type(var)(
+type=var.type,
+value=None,
+strict=None,
+container=deepcopy(var.container),
+name=var.name,
+)
+assert new_var.type == var.type
+new_var.tag = copy(var.tag)
+return new_var


def fgraph_from_model(
model: Model, inlined_views=False
) -> tuple[FunctionGraph, dict[Variable, Variable]]:
@@ -192,18 +206,7 @@ def fgraph_from_model(
shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)]
shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)]
for var in shared_vars_to_copy:
-# FIXME: ScalarSharedVariables are converted to 0d numpy arrays internally,
-# so calling shared(shared(5).get_value()) returns a different type: TensorSharedVariables!
-# Furthermore, PyMC silently ignores mutable dim changes that are SharedTensorVariables...
-# https://github.com/pymc-devs/pytensor/issues/396
-if isinstance(var, ScalarSharedVariable):
-new_var = shared(var.get_value(borrow=False).item())
-else:
-new_var = shared(var.get_value(borrow=False))
-
-assert new_var.type == var.type
-new_var.name = var.name
-new_var.tag = copy(var.tag)
+new_var = deepcopy_shared_variable(var)
# We can replace input variables by placing them in the memo
memo[var] = new_var

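A usage sketch of the new helper (not part of the commit; it assumes deepcopy_shared_variable can be imported from pymc.model.fgraph, where this diff defines it): the copy keeps the exact type of the original, including static shape information, but owns its own storage.

    import numpy as np
    from pytensor import shared

    from pymc.model.fgraph import deepcopy_shared_variable

    sigma = shared(np.array([1.0]), shape=(1,), name="sigma")
    sigma_copy = deepcopy_shared_variable(sigma)

    assert sigma_copy is not sigma
    assert sigma_copy.type == sigma.type  # exact type, unlike shared(sigma.get_value())
    assert np.array_equal(sigma_copy.get_value(), sigma.get_value())

    # Independent storage: updating the copy leaves the original untouched
    sigma_copy.set_value(np.array([2.0]))
    assert np.array_equal(sigma.get_value(), [1.0])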
4 changes: 2 additions & 2 deletions tests/model/test_core.py
@@ -35,7 +35,7 @@
from pytensor.raise_op import Assert
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
-from pytensor.tensor.sharedvar import ScalarSharedVariable
+from pytensor.tensor.sharedvar import TensorSharedVariable
from pytensor.tensor.variable import TensorConstant

import pymc as pm
@@ -823,7 +823,7 @@ def test_add_coord_mutable_kwarg():
m.add_coord("fixed", values=[1], mutable=False)
m.add_coord("mutable1", values=[1, 2], mutable=True)
assert isinstance(m._dim_lengths["fixed"], TensorConstant)
-assert isinstance(m._dim_lengths["mutable1"], ScalarSharedVariable)
+assert isinstance(m._dim_lengths["mutable1"], TensorSharedVariable)
pm.MutableData("mdata", np.ones((1, 2, 3)), dims=("fixed", "mutable1", "mutable2"))
assert isinstance(m._dim_lengths["mutable2"], TensorVariable)

42 changes: 32 additions & 10 deletions tests/model/test_fgraph.py
@@ -107,14 +107,16 @@ def test_data(inline_views):
with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
+sigma = pm.MutableData("sigma", [1.0], shape=(1,))
b0 = pm.ConstantData("b0", np.zeros((1,)))
b1 = pm.DiracDelta("b1", 1.0)
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y, dims=("test_dim",))

m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views)
assert isinstance(memo[x].owner.op, ModelNamed)
assert isinstance(memo[y].owner.op, ModelNamed)
+assert isinstance(memo[sigma].owner.op, ModelNamed)
assert isinstance(memo[b0].owner.op, ModelNamed)
mu_inp = memo[mu].owner.inputs[0]
obs = memo[obs]
@@ -124,10 +126,13 @@
assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0]
# ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
assert obs.owner.inputs[1] is memo[y].owner.inputs[0]
+# ObservedRV(Normal(..., sigma), ...) not ObservedRV(Normal(..., Named(sigma)), ...)
+assert obs.owner.inputs[0].owner.inputs[4] is memo[sigma].owner.inputs[0]
else:
assert mu_inp.owner.inputs[0] is memo[b0]
assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x]
assert obs.owner.inputs[1] is memo[y]
+assert obs.owner.inputs[0].owner.inputs[4] is memo[sigma]

m_new = model_from_fgraph(m_fgraph)

@@ -140,9 +145,17 @@
# Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
assert not same_storage(m_new["x"], x)
assert not same_storage(m_new["y"], y)
assert not same_storage(m_new["sigma"], sigma)
assert not same_storage(m_new["b1"].owner.inputs[0], b1.owner.inputs[0])
assert not same_storage(m_new.dim_lengths["test_dim"], m_old.dim_lengths["test_dim"])

+# Check they have the same type
+assert m_new["x"].type == x.type
+assert m_new["y"].type == y.type
+assert m_new["sigma"].type == sigma.type
+assert m_new["b1"].owner.inputs[0].type == b1.owner.inputs[0].type
+assert m_new.dim_lengths["test_dim"].type == m_old.dim_lengths["test_dim"].type

# Updating model shared variables in new model, doesn't affect old one
with m_new:
pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
@@ -155,22 +168,31 @@
@config.change_flags(floatX="float64") # Avoid downcasting Ops in the graph
def test_shared_variable():
"""Test that user defined shared variables (other than RNGs) aren't copied."""
-x = shared(np.array([1, 2, 3.0]), name="x")
-y = shared(np.array([1, 2, 3.0]), name="y")
+mu = shared(np.array([1, 2, 3.0]), shape=(None,), name="mu")
+sigma = shared(np.array([1.0]), shape=(1,), name="sigma")
+obs = shared(np.array([1, 2, 3.0]), shape=(3,), name="obs")

with pm.Model() as m_old:
test = pm.Normal("test", mu=x, observed=y)
test = pm.Normal("test", mu=mu, sigma=sigma, observed=obs)

-assert test.owner.inputs[3] is x
-assert m_old.rvs_to_values[test] is y
+assert test.owner.inputs[3] is mu
+assert test.owner.inputs[4] is sigma
+assert m_old.rvs_to_values[test] is obs

m_new = clone_model(m_old)
test_new = m_new["test"]
# Shared Variables are cloned but still point to the same memory
-assert test_new.owner.inputs[3] is not x
-assert m_new.rvs_to_values[test_new] is not y
-assert same_storage(test_new.owner.inputs[3], x)
-assert same_storage(m_new.rvs_to_values[test_new], y)
+mu_new, sigma_new = test_new.owner.inputs[3:5]
+obs_new = m_new.rvs_to_values[test_new]
+assert mu_new is not mu
+assert sigma_new is not sigma
+assert obs_new is not obs
+assert mu_new.type == mu.type
+assert sigma_new.type == sigma.type
+assert obs_new.type == obs.type
+assert same_storage(mu, mu_new)
+assert same_storage(sigma, sigma_new)
+assert same_storage(obs, obs_new)


@pytest.mark.parametrize("inline_views", (False, True))
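Taken together, the behaviour the updated tests pin down is: shared variables managed by the model (data containers, dim lengths, RNGs) are deep-copied with their exact types, so the clone gets independent storage, while raw user-defined shared variables are re-created with the same type but keep pointing at the original storage. A sketch, assuming clone_model is importable from pymc.model.fgraph as in the test module above and a float64 floatX (as in test_shared_variable) so no cast is inserted between raw_mu and the RV:

    import numpy as np
    import pymc as pm
    from pytensor import shared

    from pymc.model.fgraph import clone_model

    raw_mu = shared(np.zeros(3), name="raw_mu")

    with pm.Model(coords_mutable={"obs_id": range(3)}) as m:
        data = pm.MutableData("data", np.ones(3), dims="obs_id")
        y = pm.Normal("y", mu=raw_mu, sigma=1.0, observed=data, dims="obs_id")

    m2 = clone_model(m)

    # Model-managed data gets fresh storage in the clone ...
    m2["data"].set_value(np.full(3, 5.0))
    assert np.array_equal(m["data"].get_value(), np.ones(3))

    # ... while the raw user-defined shared variable is a new variable of the
    # same type that still shares storage with the original
    raw_mu_clone = m2["y"].owner.inputs[3]
    assert raw_mu_clone is not raw_mu and raw_mu_clone.type == raw_mu.type
    raw_mu.set_value(np.arange(3.0))
    assert np.array_equal(raw_mu_clone.get_value(), np.arange(3.0))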
