From 66deb25ea8d17033de9e00903358a0a811d7bb4c Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sat, 8 Jan 2022 15:25:46 -0800 Subject: [PATCH] Update dist parameter hints --- pymc/distributions/continuous.py | 14 +++++++------- pymc/distributions/distribution.py | 2 ++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 00c57baa02b..987eadaf7c0 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -85,7 +85,7 @@ def polyagamma_cdf(*args, **kwargs): normal_lcdf, zvalue, ) -from pymc.distributions.distribution import Continuous +from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous from pymc.distributions.shape_utils import rv_size_is_none from pymc.math import invlogit, logdiffexp, logit from pymc.util import UNSET @@ -692,12 +692,12 @@ class TruncatedNormal(BoundedContinuous): @classmethod def dist( cls, - mu: Optional[Union[float, np.ndarray]] = None, - sigma: Optional[Union[float, np.ndarray]] = None, - tau: Optional[Union[float, np.ndarray]] = None, - sd: Optional[Union[float, np.ndarray]] = None, - lower: Optional[Union[float, np.ndarray]] = None, - upper: Optional[Union[float, np.ndarray]] = None, + mu: Optional[DIST_PARAMETER_TYPES] = None, + sigma: Optional[DIST_PARAMETER_TYPES] = None, + tau: Optional[DIST_PARAMETER_TYPES] = None, + sd: Optional[DIST_PARAMETER_TYPES] = None, + lower: Optional[DIST_PARAMETER_TYPES] = None, + upper: Optional[DIST_PARAMETER_TYPES] = None, transform: str = "auto", *args, **kwargs, diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 33b9ea45782..07a105efa56 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -61,6 +61,8 @@ "NoDistribution", ] +DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable] + vectorized_ppc = contextvars.ContextVar( "vectorized_ppc", default=None ) # type: contextvars.ContextVar[Optional[Callable]]