Skip to content

Commit

Permalink
Add initial distribution parameter types
Browse files Browse the repository at this point in the history
  • Loading branch information
canyon289 committed Jan 6, 2022
1 parent 75ea2a8 commit d6f63c4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,9 @@ def logcdf(value, mu, sigma):
)


from pymc.distributions.distribution import DIST_PARAMETER_TYPES


class TruncatedNormalRV(RandomVariable):
name = "truncated_normal"
ndim_supp = 0
Expand All @@ -594,7 +597,7 @@ class TruncatedNormalRV(RandomVariable):
def rng_fn(
cls,
rng: np.random.RandomState,
mu: Union[np.ndarray, float],
mu: DIST_PARAMETER_TYPES,
sigma: Union[np.ndarray, float],
lower: Union[np.ndarray, float],
upper: Union[np.ndarray, float],
Expand Down
5 changes: 4 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

from abc import ABCMeta
from functools import singledispatch
from typing import Callable, Optional, Sequence
from typing import Callable, Optional, Sequence, Union

import aesara
import numpy as np

from aeppl.logprob import _logcdf, _logprob
from aesara import tensor as at
Expand Down Expand Up @@ -56,6 +57,8 @@
"NoDistribution",
]

DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable]

vectorized_ppc = contextvars.ContextVar(
"vectorized_ppc", default=None
) # type: contextvars.ContextVar[Optional[Callable]]
Expand Down

0 comments on commit d6f63c4

Please sign in to comment.