From 0595bcfe462631fcefd3c13ab6cd54cd2630d6a4 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 16 Feb 2024 16:47:35 +0100
Subject: [PATCH] Add gufunc signature to pre-built CustomSymbolicDistributions

---
 pymc/distributions/censored.py     |  2 +-
 pymc/distributions/distribution.py | 31 +++++++++++++++++++++++++++---
 pymc/distributions/mixture.py      |  9 ++++++++-
 pymc/distributions/multivariate.py |  8 +++++---
 pymc/distributions/timeseries.py   | 13 ++++++++++---
 5 files changed, 52 insertions(+), 11 deletions(-)

diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py
index 87b700f7f6d..fe18ed2caa0 100644
--- a/pymc/distributions/censored.py
+++ b/pymc/distributions/censored.py
@@ -115,7 +115,7 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
         return CensoredRV(
             inputs=[dist_, lower_, upper_],
             outputs=[censored_rv_],
-            ndim_supp=0,
+            gufunc_signature="(),(),()->()",
         )(dist, lower, upper)
 
 
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 6de3c1a41eb..567a8902cf7 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -38,6 +38,7 @@
 from pytensor.tensor.random.type import RandomGeneratorType, RandomType
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.rewriting.shape import ShapeFeature
+from pytensor.tensor.utils import _parse_gufunc_signature
 from pytensor.tensor.variable import TensorVariable
 from typing_extensions import TypeAlias
 
@@ -256,11 +257,28 @@ class SymbolicRandomVariable(OpFromGraph):
     If `False`, a logprob function must be dispatched directly to the subclass type.
     """
 
+    signature: str = None
+    """Numpy-like vectorized signature of the distribution."""
+
     _print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
     """Tuple of (name, latex name) used for for pretty-printing variables of this type"""
 
-    def __init__(self, *args, ndim_supp, **kwargs):
-        """Initialitze a SymbolicRandomVariable class."""
+    def __init__(
+        self,
+        *args,
+        ndim_supp: Optional[int] = None,
+        gufunc_signature: Optional[str] = None,
+        **kwargs,
+    ):
+        """Initialize a SymbolicRandomVariable class."""
+        if gufunc_signature is not None:
+            self.gufunc_signature = gufunc_signature
+        if ndim_supp is None:
+            if gufunc_signature is not None:
+                _, outputs_sig = _parse_gufunc_signature(gufunc_signature)
+                ndim_supp = max(len(out_sig) for out_sig in outputs_sig)
+            else:
+                raise ValueError("ndim_supp must be specified if gufunc_signature is not.")
         self.ndim_supp = ndim_supp
         kwargs.setdefault("inline", True)
         super().__init__(*args, **kwargs)
@@ -274,6 +292,11 @@ def update(self, node: Node):
         """
         return {}
 
+    def batch_ndim(self, node: Node) -> int:
+        """Number of dimensions of the distribution's batch shape."""
+        out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
+        return out_ndim - self.ndim_supp
+
 
 class Distribution(metaclass=DistributionMeta):
     """Statistical distribution"""
@@ -682,7 +705,8 @@ def dist(
         logp: Optional[Callable] = None,
         logcdf: Optional[Callable] = None,
         moment: Optional[Callable] = None,
-        ndim_supp: int = 0,
+        ndim_supp: Optional[int] = None,
+        gufunc_signature: Optional[str] = None,
         dtype: str = "floatX",
         class_name: str = "CustomDist",
         **kwargs,
@@ -700,6 +724,7 @@ def dist(
             dist=dist,
             moment=moment,
             ndim_supp=ndim_supp,
+            gufunc_signature=gufunc_signature,
             **kwargs,
         )
 
diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 11ac3ce2437..6ecec663183 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -297,10 +297,17 @@ def rv_op(cls, weights, *components, size=None):
         # Output mix_indexes rng update so that it can be updated in place
         mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]
 
+        s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
+        if len(components) == 1:
+            comp_s = ",".join((*s, "w"))
+            gufunc_signature = f"(),(w),({comp_s})->({s})"
+        else:
+            comps_s = ",".join(f"({s})" for _ in components)
+            gufunc_signature = f"(),(w),{comps_s}->({s})"
         mix_op = MarginalMixtureRV(
             inputs=[mix_indexes_rng_, weights_, *components_],
             outputs=[mix_indexes_rng_next_, mix_out_],
-            ndim_supp=components[0].owner.op.ndim_supp,
+            gufunc_signature=gufunc_signature,
         )
 
         # Create the actual MarginalMixture variable
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 5a24bbc7cd6..63e201f643c 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -1221,7 +1221,7 @@ def rv_op(cls, n, eta, sd_dist, size=None):
         return _LKJCholeskyCovRV(
             inputs=[rng_, n_, eta_, sd_dist_],
             outputs=[next_rng_, lkjcov_],
-            ndim_supp=1,
+            gufunc_signature="(),(),(),(n)->(),(n)",
         )(rng, n, eta, sd_dist)
 
 
@@ -2790,10 +2790,12 @@ def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
         for axis in range(n_zerosum_axes):
             zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)
 
+        support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
+        gufunc_signature = f"({support_str}),(),(s)->({support_str})"
         return ZeroSumNormalRV(
             inputs=[normal_dist_, sigma_, support_shape_],
-            outputs=[zerosum_rv_, support_shape_],
-            ndim_supp=n_zerosum_axes,
+            outputs=[zerosum_rv_],
+            gufunc_signature=gufunc_signature,
         )(normal_dist, sigma, support_shape)
 
 
diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index 6bf653c3561..8ef0e3e5e4a 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -195,12 +195,17 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None):
         # shape = (B, T, S)
         grw_ = pt.concatenate([init_dist_dimswapped_, innovation_dist_dimswapped_], axis=-ndim_supp)
         grw_ = pt.cumsum(grw_, axis=-ndim_supp)
+
+        innov_supp_dims = [f"d{i}" for i in range(ndim_supp)]
+        innov_supp_str = ",".join(innov_supp_dims)
+        out_supp_str = ",".join(["t", *innov_supp_dims])
+        gufunc_signature = f"({innov_supp_str}),({innov_supp_str}),(s)->({out_supp_str})"
         return RandomWalkRV(
             [init_dist_, innovation_dist_, steps_],
             # We pass steps_ through just so we can keep a reference to it, even though
             # it's no longer needed at this point
-            [grw_, steps_],
-            ndim_supp=ndim_supp,
+            [grw_],
+            gufunc_signature=gufunc_signature,
         )(init_dist, innovation_dist, steps)
 
 
@@ -655,6 +660,7 @@ def step(*args):
             outputs=[noise_next_rng, ar_],
             ar_order=ar_order,
             constant_term=constant_term,
+            gufunc_signature="(o),(),(o),(s)->(),(t)",
             ndim_supp=1,
         )
 
@@ -825,7 +831,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
         garch11_op = GARCH11RV(
             inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
             outputs=[noise_next_rng, garch11_],
-            ndim_supp=1,
+            gufunc_signature="(),(),(),(),(),(s)->(),(t)",
         )
 
         garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
@@ -1006,6 +1012,7 @@ def step(*prev_args):
             outputs=[noise_next_rng, sde_out_],
             dt=dt,
             sde_fn=sde_fn,
+            gufunc_signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)",
             ndim_supp=1,
         )
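
A minimal sketch of the new code path, assuming a pytensor version that exposes the
private helper _parse_gufunc_signature imported by this patch: when only
gufunc_signature is given, SymbolicRandomVariable.__init__ derives ndim_supp as the
largest number of core dimensions among the outputs. The example signature below is
the one this patch builds for ZeroSumNormalRV with n_zerosum_axes=2; the variable
names in the snippet are illustrative only.

    from pytensor.tensor.utils import _parse_gufunc_signature

    # ZeroSumNormalRV signature for n_zerosum_axes=2:
    # (d0,d1) base normal draw, () sigma, (s) support_shape -> (d0,d1) zero-sum draw
    gufunc_signature = "(d0,d1),(),(s)->(d0,d1)"

    # Same derivation as the new SymbolicRandomVariable.__init__ branch
    _, outputs_sig = _parse_gufunc_signature(gufunc_signature)
    ndim_supp = max(len(out_sig) for out_sig in outputs_sig)
    assert ndim_supp == 2  # matches the previously explicit ndim_supp=n_zerosum_axes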