Commit
Replace rvs_to_total_sizes mapping by MinibatchRandomVariables
ricardoV94 committed Feb 28, 2023
1 parent 33d641d commit ecba5bb
Showing 15 changed files with 320 additions and 323 deletions.
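This commit drops the model-level rvs_to_total_sizes mapping and instead attaches minibatch logp rescaling to the graph through a new MinibatchRandomVariable Op (pymc/variational/minibatch_rv.py below). The way users supply total_size for observed variables stays the same; a minimal, hedged usage sketch (the data and batch size are illustrative, and the pm.Minibatch call assumes the PyMC v5 API):

import numpy as np
import pymc as pm

data = np.random.randn(10_000)
# Stream 100-row batches of the full 10_000-row dataset.
batch = pm.Minibatch(data, batch_size=100)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    # total_size declares how many rows the batch stands in for, so the
    # batch log-likelihood is rescaled by 10_000 / 100 = 100 during inference.
    pm.Normal("y", mu, 1.0, observed=batch, total_size=10_000)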
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -240,6 +240,7 @@ jobs:
- |
tests/sampling/test_parallel.py
tests/test_data.py
tests/variational/test_minibatch_rv.py
tests/test_model.py
- |
87 changes: 4 additions & 83 deletions pymc/logprob/joint_logprob.py
@@ -39,7 +39,6 @@
from collections import deque
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import pytensor
import pytensor.tensor as pt

@@ -55,7 +54,6 @@
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
from pymc.logprob.utils import rvs_to_value_vars
from pymc.pytensorf import floatX


def logp(rv: TensorVariable, value) -> TensorVariable:
@@ -248,77 +246,6 @@ def factorized_joint_logprob(
return logprob_vars


TOTAL_SIZE = Union[int, Sequence[int], None]


def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
"""
Gets scaling constant for logp.
Parameters
----------
total_size: Optional[int|List[int]]
size of a fully observed data without minibatching,
`None` means data is fully observed
shape: shape
shape of an observed data
ndim: int
ndim hint
Returns
-------
scalar
"""
if total_size is None:
coef = 1.0
elif isinstance(total_size, int):
if ndim >= 1:
denom = shape[0]
else:
denom = 1
coef = floatX(total_size) / floatX(denom)
elif isinstance(total_size, (list, tuple)):
if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
raise TypeError(
"Unrecognized `total_size` type, expected "
"int or list of ints, got %r" % total_size
)
if Ellipsis in total_size:
sep = total_size.index(Ellipsis)
begin = total_size[:sep]
end = total_size[sep + 1 :]
if Ellipsis in end:
raise ValueError(
"Double Ellipsis in `total_size` is restricted, got %r" % total_size
)
else:
begin = total_size
end = []
if (len(begin) + len(end)) > ndim:
raise ValueError(
"Length of `total_size` is too big, "
"number of scalings is bigger that ndim, got %r" % total_size
)
elif (len(begin) + len(end)) == 0:
coef = 1.0
if len(end) > 0:
shp_end = shape[-len(end) :]
else:
shp_end = np.asarray([])
shp_begin = shape[: len(begin)]
begin_coef = [
floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None
]
end_coef = [floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None]
coefs = begin_coef + end_coef
coef = pt.prod(coefs)
else:
raise TypeError(
"Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
)
return pt.as_tensor(coef, dtype=pytensor.config.floatX)


def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
# Raise if there are unexpected RandomVariables in the logp graph
# Only SimulatorRVs MinibatchIndexRVs are allowed
@@ -348,7 +275,6 @@ def joint_logp(
rvs_to_values: Dict[TensorVariable, TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
jacobian: bool = True,
rvs_to_total_sizes: Dict[TensorVariable, TOTAL_SIZE],
**kwargs,
) -> List[TensorVariable]:
"""Thin wrapper around pymc.logprob.factorized_joint_logprob, extended with Model
@@ -371,18 +297,13 @@
**kwargs,
)

# The function returns the logp for every single value term we provided to it. This
# includes the extra values we plugged in above, so we filter those we actually
# wanted in the same order they were given in.
# The function returns the logp for every single value term we provided to it.
# This includes the extra values we plugged in above, so we filter those we
# actually wanted in the same order they were given in.
logp_terms = {}
for rv in rvs:
value_var = rvs_to_values[rv]
logp_term = temp_logp_terms[value_var]
total_size = rvs_to_total_sizes.get(rv, None)
if total_size is not None:
scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
logp_term *= scaling
logp_terms[value_var] = logp_term
logp_terms[value_var] = temp_logp_terms[value_var]

_check_no_rvs(list(logp_terms.values()))
return list(logp_terms.values())
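The scaling behaviour removed above is not lost: the coefficient that _get_scaling used to compute (total_size divided by the observed shape along each scaled axis) now lives in get_scaling in the new pymc/variational/minibatch_rv.py and is applied inside the MinibatchRandomVariable logprob instead of in joint_logp. For the common 1-D case the coefficient is simply (illustrative numbers, not from the diff):

total_size = 10_000   # rows in the full dataset
batch_rows = 100      # rows in the observed minibatch
coef = total_size / batch_rows  # -> 100.0; the batch logp is multiplied by this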
31 changes: 19 additions & 12 deletions pymc/model.py
@@ -564,7 +564,6 @@ def __init__(
self.values_to_rvs = treedict(parent=self.parent.values_to_rvs)
self.rvs_to_values = treedict(parent=self.parent.rvs_to_values)
self.rvs_to_transforms = treedict(parent=self.parent.rvs_to_transforms)
self.rvs_to_total_sizes = treedict(parent=self.parent.rvs_to_total_sizes)
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
self.free_RVs = treelist(parent=self.parent.free_RVs)
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
@@ -578,7 +577,6 @@ def __init__(
self.values_to_rvs = treedict()
self.rvs_to_values = treedict()
self.rvs_to_transforms = treedict()
self.rvs_to_total_sizes = treedict()
self.rvs_to_initial_values = treedict()
self.free_RVs = treelist()
self.observed_RVs = treelist()
@@ -762,7 +760,6 @@ def logp(
rvs=rvs,
rvs_to_values=self.rvs_to_values,
rvs_to_transforms=self.rvs_to_transforms,
rvs_to_total_sizes=self.rvs_to_total_sizes,
jacobian=jacobian,
)
assert isinstance(rv_logps, list)
@@ -1314,8 +1311,6 @@ def register_rv(
name = self.name_for(name)
rv_var.name = name
_add_future_warning_tag(rv_var)
rv_var.tag.total_size = total_size
self.rvs_to_total_sizes[rv_var] = total_size

# Associate previously unknown dimension names with
# the length of the corresponding RV dimension.
@@ -1327,6 +1322,8 @@
self.add_coord(dname, values=None, length=rv_var.shape[d])

if observed is None:
if total_size is not None:
raise ValueError("total_size can only be passed to observed RVs")
self.free_RVs.append(rv_var)
self.create_value_var(rv_var, transform)
self.add_named_variable(rv_var, dims)
@@ -1351,12 +1348,17 @@

# `rv_var` is potentially changed by `make_obs_var`,
# for example into a new graph for imputation of missing data.
rv_var = self.make_obs_var(rv_var, observed, dims, transform)
rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)

return rv_var

def make_obs_var(
self, rv_var: TensorVariable, data: np.ndarray, dims, transform: Optional[Any]
self,
rv_var: TensorVariable,
data: np.ndarray,
dims,
transform: Union[Any, None],
total_size: Union[int, None],
) -> TensorVariable:
"""Create a `TensorVariable` for an observed random variable.
@@ -1392,18 +1394,16 @@ def make_obs_var(

mask = getattr(data, "mask", None)
if mask is not None:
if mask.all():
# If there are no observed values, this variable isn't really
# observed.
return rv_var

impute_message = (
f"Data in {rv_var} contains missing values and"
" will be automatically imputed from the"
" sampling distribution."
)
warnings.warn(impute_message, ImputationWarning)

if total_size is not None:
raise ValueError("total_size is not compatible with imputed variables")

if not isinstance(rv_var.owner.op, RandomVariable):
raise NotImplementedError(
"Automatic inputation is only supported for univariate RandomVariables."
@@ -1471,6 +1471,13 @@ def make_obs_var(
data = sparse.basic.as_sparse(data, name=name)
else:
data = at.as_tensor_variable(data, name=name)

if total_size:
from pymc.variational.minibatch_rv import create_minibatch_rv

rv_var = create_minibatch_rv(rv_var, total_size)
rv_var.name = name

rv_var.tag.observations = data
self.create_value_var(rv_var, transform=None, value_var=data)
self.add_named_variable(rv_var, dims)
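With the mapping removed, register_rv now rejects total_size on free RVs, make_obs_var rejects it for imputed variables, and observed variables are wrapped via create_minibatch_rv. A hedged sketch of the resulting behaviour (the data values are illustrative):

import numpy as np
import pymc as pm
from pymc.variational.minibatch_rv import MinibatchRandomVariable

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    # Observed RV with total_size: wrapped in a MinibatchRandomVariable.
    y = pm.Normal("y", mu, 1.0, observed=np.zeros(100), total_size=10_000)

assert isinstance(y.owner.op, MinibatchRandomVariable)
# Per this diff, total_size on a free RV raises
#   ValueError("total_size can only be passed to observed RVs"),
# and total_size combined with masked/imputed observations raises
#   ValueError("total_size is not compatible with imputed variables").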
1 change: 0 additions & 1 deletion pymc/util.py
@@ -489,7 +489,6 @@ def __getattribute__(self, name):
for deprecated_names, alternative in (
(("value_var", "observations"), "model.rvs_to_values[rv]"),
(("transform",), "model.rvs_to_transforms[rv]"),
(("total_size",), "model.rvs_to_total_sizes[rv]"),
):
if name in deprecated_names:
try:
113 changes: 113 additions & 0 deletions pymc/variational/minibatch_rv.py
@@ -0,0 +1,113 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Sequence, Union, cast

import pytensor.tensor as pt

from pytensor import Variable, config
from pytensor.graph import Apply, Op
from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable

from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
from pymc.logprob.abstract import logprob as logprob_logprob
from pymc.logprob.utils import ignore_logprob


class MinibatchRandomVariable(Op):
"""RV whose logprob should be rescaled to match total_size"""

__props__ = ()
view_map = {0: [0]}

def make_node(self, rv, *total_size):
rv = as_tensor_variable(rv)
total_size = [
as_tensor_variable(t, dtype="int64", ndim=0) if t is not None else NoneConst
for t in total_size
]
assert len(total_size) == rv.ndim
out = rv.type()
return Apply(self, [rv, *total_size], [out])

def perform(self, node, inputs, output_storage):
output_storage[0][0] = inputs[0]


minibatch_rv = MinibatchRandomVariable()


EllipsisType = Any # EllipsisType is not present in Python 3.8 yet


def create_minibatch_rv(
rv: TensorVariable,
total_size: Union[int, None, Sequence[Union[int, EllipsisType, None]]],
) -> TensorVariable:
"""Create variable whose logp is rescaled by total_size."""
if isinstance(total_size, int):
if rv.ndim <= 1:
total_size = [total_size]
else:
missing_ndims = rv.ndim - 1
total_size = [total_size] + [None] * missing_ndims
elif isinstance(total_size, (list, tuple)):
total_size = list(total_size)
if Ellipsis in total_size:
# Replace Ellipsis by None
if total_size.count(Ellipsis) > 1:
raise ValueError("Only one Ellipsis can be present in total_size")
sep = total_size.index(Ellipsis)
begin = total_size[:sep]
end = total_size[sep + 1 :]
missing_ndims = max((rv.ndim - len(begin) - len(end), 0))
total_size = begin + [None] * missing_ndims + end
if len(total_size) > rv.ndim:
raise ValueError(f"Length of total_size {total_size} is langer than RV ndim {rv.ndim}")
else:
raise TypeError(f"Invalid type for total_size: {total_size}")

rv = ignore_logprob(rv)

return cast(TensorVariable, minibatch_rv(rv, *total_size))


def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> TensorVariable:
"""Gets scaling constant for logp."""

# mypy doesn't understand we can convert a shape TensorVariable into a tuple
shape = tuple(shape) # type: ignore

# Scalar RV
if len(shape) == 0: # type: ignore
coef = total_size[0] if not NoneConst.equals(total_size[0]) else 1.0
else:
coefs = [t / shape[i] for i, t in enumerate(total_size) if not NoneConst.equals(t)]
coef = pt.prod(coefs)

return pt.cast(coef, dtype=config.floatX)


MeasurableVariable.register(MinibatchRandomVariable)


@_get_measurable_outputs.register(MinibatchRandomVariable)
def _get_measurable_outputs_minibatch_random_variable(op, node):
return [node.outputs[0]]


@_logprob.register(MinibatchRandomVariable)
def minibatch_rv_logprob(op, values, *inputs, **kwargs):
[value] = values
rv, *total_size = inputs
return logprob_logprob(rv, value, **kwargs) * get_scaling(total_size, value.shape)
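Read together: create_minibatch_rv normalizes total_size to one entry per RV dimension (a bare int scales only the first axis, an Ellipsis pads with None), and minibatch_rv_logprob multiplies the underlying logp by the product of total_size[i] / shape[i] over the non-None entries. A hedged end-to-end sketch; checking the result through pm.logp is an assumption about how the pieces compose, not something shown in this diff:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc.variational.minibatch_rv import create_minibatch_rv

# A (100, 5) minibatch standing in for 10_000 rows along the first axis.
x = pt.random.normal(size=(100, 5))
mb_x = create_minibatch_rv(x, [10_000, None])  # None: no scaling on axis 1

value = np.zeros((100, 5))
scaled = pm.logp(mb_x, value).sum().eval()
plain = pm.logp(x, value).sum().eval()
# The minibatch logp is the plain logp times 10_000 / 100 = 100.
np.testing.assert_allclose(scaled, plain * (10_000 / 100))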
13 changes: 6 additions & 7 deletions pymc/variational/opvi.py
@@ -66,7 +66,6 @@
from pymc.backends.ndarray import NDArray
from pymc.blocking import DictToArrayBijection
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.joint_logprob import _get_scaling
from pymc.model import modelcontext
from pymc.pytensorf import (
SeedSequenceSeed,
@@ -82,6 +81,7 @@
_get_seeds_per_chain,
locally_cachedmethod,
)
from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling
from pymc.variational.updates import adagrad_window
from pymc.vartypes import discrete_types

@@ -1069,9 +1069,11 @@ def symbolic_normalizing_constant(self):
t = self.to_flat_input(
at.max(
[
_get_scaling(self.model.rvs_to_total_sizes.get(v, None), v.shape, v.ndim)
get_scaling(v.owner.inputs[1:], v.shape)
for v in self.group
if isinstance(v.owner.op, MinibatchRandomVariable)
]
+ [1.0] # To avoid empty max
)
)
t = self.symbolic_single_sample(t)
@@ -1237,12 +1239,9 @@ def symbolic_normalizing_constant(self):
t = at.max(
self.collect("symbolic_normalizing_constant")
+ [
_get_scaling(
self.model.rvs_to_total_sizes.get(obs, None),
obs.shape,
obs.ndim,
)
get_scaling(obs.owner.inputs[1:], obs.shape)
for obs in self.model.observed_RVs
if isinstance(obs.owner.op, MinibatchRandomVariable)
]
)
t = at.switch(self._scale_cost_to_minibatch, t, at.constant(1, dtype=t.dtype))
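The variational objectives no longer consult the model's rvs_to_total_sizes mapping: a MinibatchRandomVariable carries its total_size entries as extra Op inputs, so opvi reads the normalizing constant straight off the graph with get_scaling(obs.owner.inputs[1:], obs.shape). A hedged, self-contained sketch of that lookup (the model is illustrative):

import numpy as np
import pymc as pm

from pymc.variational.minibatch_rv import MinibatchRandomVariable, get_scaling

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    y = pm.Normal("y", mu, 1.0, observed=np.zeros(100), total_size=10_000)

# inputs[0] is the wrapped RV; inputs[1:] hold the total_size entries.
assert isinstance(y.owner.op, MinibatchRandomVariable)
scale = get_scaling(y.owner.inputs[1:], y.shape)
print(scale.eval())  # 100.0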
(Diffs for the remaining 9 changed files are not shown here.)
