Fix copying of shared variables in fgraph_from_model #7153

Merged
6 changes: 3 additions & 3 deletions pymc/model/core.py
@@ -44,7 +44,6 @@
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.sharedvar import ScalarSharedVariable
from pytensor.tensor.variable import TensorConstant, TensorVariable
from typing_extensions import Self

@@ -999,6 +998,7 @@ def add_coord(
length = pytensor.shared(length, name=name)
else:
length = pytensor.tensor.constant(length)
assert length.type.ndim == 0
self._dim_lengths[name] = length
self._coords[name] = values

@@ -1028,7 +1028,7 @@ def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] =
coord_values
Optional sequence of coordinate values.
"""
if not isinstance(self.dim_lengths[name], ScalarSharedVariable):
if not isinstance(self.dim_lengths[name], SharedVariable):
raise ValueError(f"The dimension '{name}' is immutable.")
if coord_values is None and self.coords.get(name, None) is not None:
raise ValueError(
@@ -1188,7 +1188,7 @@ def set_data(
actual=new_length,
expected=old_length,
)
if isinstance(length_tensor, ScalarSharedVariable):
if isinstance(length_tensor, SharedVariable):
# The dimension is mutable, but was defined without being linked
# to a shared variable. This is allowed, but a little less robust.
self.set_dim(dname, new_length, coord_values=new_coords)
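For context on why both isinstance checks above were broadened from ScalarSharedVariable to SharedVariable, here is a small sketch of the type round-trip described by the FIXME comment removed from pymc/model/fgraph.py below (behavior as described in that comment, not re-verified here):

```python
import pytensor
from pytensor.compile import SharedVariable
from pytensor.tensor.sharedvar import ScalarSharedVariable

dim_a = pytensor.shared(5)                  # Python int -> ScalarSharedVariable
dim_b = pytensor.shared(dim_a.get_value())  # 0-d ndarray -> TensorSharedVariable

# The old checks only recognized dim_a as a mutable dim length;
# the broadened checks treat any shared variable as mutable.
assert isinstance(dim_a, ScalarSharedVariable)
assert not isinstance(dim_b, ScalarSharedVariable)
assert isinstance(dim_a, SharedVariable) and isinstance(dim_b, SharedVariable)
```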
33 changes: 18 additions & 15 deletions pymc/model/fgraph.py
@@ -11,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy
from copy import copy, deepcopy
from typing import Optional

import pytensor

from pytensor import Variable, shared
from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.scalar import Identity
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.sharedvar import ScalarSharedVariable

from pymc.logprob.transforms import Transform
from pymc.model.core import Model
@@ -113,6 +112,21 @@
remove_identity_rewrite = out2in(local_remove_identity)


def deepcopy_shared_variable(var: SharedVariable) -> SharedVariable:
# Shared variables don't have a deepcopy method; SharedVariable.clone reuses the old container and its contents,
# so we recreate the shared variable manually after deepcopying its container.
new_var = type(var)(
type=var.type,
value=None,
strict=None,
container=deepcopy(var.container),
name=var.name,
)
assert new_var.type == var.type
new_var.tag = copy(var.tag)
return new_var


def fgraph_from_model(
model: Model, inlined_views=False
) -> tuple[FunctionGraph, dict[Variable, Variable]]:
@@ -192,18 +206,7 @@
shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)]
shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)]
for var in shared_vars_to_copy:
# FIXME: ScalarSharedVariables are converted to 0d numpy arrays internally,
# so calling shared(shared(5).get_value()) returns a different type: TensorSharedVariables!
# Furthermore, PyMC silently ignores mutable dim changes that are SharedTensorVariables...
# https://github.com/pymc-devs/pytensor/issues/396
if isinstance(var, ScalarSharedVariable):
new_var = shared(var.get_value(borrow=False).item())
else:
new_var = shared(var.get_value(borrow=False))

assert new_var.type == var.type
new_var.name = var.name
new_var.tag = copy(var.tag)
new_var = deepcopy_shared_variable(var)
# We can replace input variables by placing them in the memo
memo[var] = new_var

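A minimal usage sketch of the new deepcopy_shared_variable helper (assuming it can be imported from pymc.model.fgraph once this change lands; the asserts mirror the guarantees the copying loop above relies on):

```python
import numpy as np
from pytensor import shared

from pymc.model.fgraph import deepcopy_shared_variable

x = shared(np.array([1.0, 2.0, 3.0]), name="x")
x_copy = deepcopy_shared_variable(x)

assert x_copy is not x
assert x_copy.type == x.type   # exact type preserved, unlike shared(x.get_value())
assert x_copy.name == x.name
x_copy.set_value(np.zeros(3))  # independent container: the original is untouched
assert np.array_equal(x.get_value(), [1.0, 2.0, 3.0])
```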
4 changes: 2 additions & 2 deletions tests/model/test_core.py
@@ -35,7 +35,7 @@
from pytensor.raise_op import Assert
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.sharedvar import ScalarSharedVariable
from pytensor.tensor.sharedvar import TensorSharedVariable
from pytensor.tensor.variable import TensorConstant

import pymc as pm
@@ -823,7 +823,7 @@ def test_add_coord_mutable_kwarg():
m.add_coord("fixed", values=[1], mutable=False)
m.add_coord("mutable1", values=[1, 2], mutable=True)
assert isinstance(m._dim_lengths["fixed"], TensorConstant)
assert isinstance(m._dim_lengths["mutable1"], ScalarSharedVariable)
assert isinstance(m._dim_lengths["mutable1"], TensorSharedVariable)
pm.MutableData("mdata", np.ones((1, 2, 3)), dims=("fixed", "mutable1", "mutable2"))
assert isinstance(m._dim_lengths["mutable2"], TensorVariable)

42 changes: 32 additions & 10 deletions tests/model/test_fgraph.py
@@ -107,14 +107,16 @@ def test_data(inline_views):
with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
sigma = pm.MutableData("sigma", [1.0], shape=(1,))
b0 = pm.ConstantData("b0", np.zeros((1,)))
b1 = pm.DiracDelta("b1", 1.0)
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y, dims=("test_dim",))

m_fgraph, memo = fgraph_from_model(m_old, inlined_views=inline_views)
assert isinstance(memo[x].owner.op, ModelNamed)
assert isinstance(memo[y].owner.op, ModelNamed)
assert isinstance(memo[sigma].owner.op, ModelNamed)
assert isinstance(memo[b0].owner.op, ModelNamed)
mu_inp = memo[mu].owner.inputs[0]
obs = memo[obs]
@@ -124,10 +126,13 @@
assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x].owner.inputs[0]
# ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
assert obs.owner.inputs[1] is memo[y].owner.inputs[0]
# ObservedRV(Normal(..., sigma), ...) not ObservedRV(Normal(..., Named(sigma)), ...)
assert obs.owner.inputs[0].owner.inputs[4] is memo[sigma].owner.inputs[0]
else:
assert mu_inp.owner.inputs[0] is memo[b0]
assert mu_inp.owner.inputs[1].owner.inputs[1] is memo[x]
assert obs.owner.inputs[1] is memo[y]
assert obs.owner.inputs[0].owner.inputs[4] is memo[sigma]

m_new = model_from_fgraph(m_fgraph)

@@ -140,9 +145,17 @@
# Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
assert not same_storage(m_new["x"], x)
assert not same_storage(m_new["y"], y)
assert not same_storage(m_new["sigma"], sigma)
assert not same_storage(m_new["b1"].owner.inputs[0], b1.owner.inputs[0])
assert not same_storage(m_new.dim_lengths["test_dim"], m_old.dim_lengths["test_dim"])

# Check they have the same type
assert m_new["x"].type == x.type
assert m_new["y"].type == y.type
assert m_new["sigma"].type == sigma.type
assert m_new["b1"].owner.inputs[0].type == b1.owner.inputs[0].type
assert m_new.dim_lengths["test_dim"].type == m_old.dim_lengths["test_dim"].type

# Updating model shared variables in the new model doesn't affect the old one
with m_new:
pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
@@ -155,22 +168,31 @@
@config.change_flags(floatX="float64") # Avoid downcasting Ops in the graph
def test_shared_variable():
"""Test that user defined shared variables (other than RNGs) aren't copied."""
x = shared(np.array([1, 2, 3.0]), name="x")
y = shared(np.array([1, 2, 3.0]), name="y")
mu = shared(np.array([1, 2, 3.0]), shape=(None,), name="mu")
sigma = shared(np.array([1.0]), shape=(1,), name="sigma")
obs = shared(np.array([1, 2, 3.0]), shape=(3,), name="obs")

with pm.Model() as m_old:
test = pm.Normal("test", mu=x, observed=y)
test = pm.Normal("test", mu=mu, sigma=sigma, observed=obs)

assert test.owner.inputs[3] is x
assert m_old.rvs_to_values[test] is y
assert test.owner.inputs[3] is mu
assert test.owner.inputs[4] is sigma
assert m_old.rvs_to_values[test] is obs

m_new = clone_model(m_old)
test_new = m_new["test"]
# Shared Variables are cloned but still point to the same memory
assert test_new.owner.inputs[3] is not x
assert m_new.rvs_to_values[test_new] is not y
assert same_storage(test_new.owner.inputs[3], x)
assert same_storage(m_new.rvs_to_values[test_new], y)
mu_new, sigma_new = test_new.owner.inputs[3:5]
obs_new = m_new.rvs_to_values[test_new]
assert mu_new is not mu
assert sigma_new is not sigma
assert obs_new is not obs
assert mu_new.type == mu.type
assert sigma_new.type == sigma.type
assert obs_new.type == obs.type
assert same_storage(mu, mu_new)
assert same_storage(sigma, sigma_new)
assert same_storage(obs, obs_new)


@pytest.mark.parametrize("inline_views", (False, True))
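The tests above rely on a same_storage helper defined elsewhere in tests/model/test_fgraph.py and not shown in this diff. A hypothetical sketch of such a check, assuming it compares the underlying storage containers of two shared variables:

```python
import numpy as np
from pytensor import shared

def same_storage_sketch(var1, var2) -> bool:
    # Hypothetical: two shared variables "share storage" if they are backed by
    # the same container storage cell, so setting one value is visible via the other.
    return var1.container.storage is var2.container.storage

a = shared(np.zeros(3), name="a")
b = a.clone()  # clone() reuses the old container, per the comment added in fgraph.py
assert same_storage_sketch(a, b)
```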