Add typing to discrete distributions #6410

Closed
93 changes: 76 additions & 17 deletions pymc/distributions/discrete.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 import warnings
 
+from typing import Optional
+
 import numpy as np
 import pytensor.tensor as pt
 
@@ -45,7 +47,7 @@
     normal_lccdf,
     normal_lcdf,
 )
-from pymc.distributions.distribution import Discrete
+from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Discrete
 from pymc.distributions.mixture import Mixture
 from pymc.distributions.shape_utils import rv_size_is_none
 from pymc.logprob.basic import logp
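
For readers unfamiliar with the alias: DIST_PARAMETER_TYPES comes from pymc/distributions/distribution.py, where this PR's companion changes define it. A minimal sketch of what such an alias plausibly looks like (the exact members are an assumption here, not copied from the PR):

from typing import Union

import numpy as np
from pytensor.tensor import TensorVariable

# Hypothetical sketch: a union of everything pt.as_tensor_variable accepts
# as a distribution parameter -- Python scalars, NumPy arrays, and symbolic
# PyTensor variables. The real alias lives in pymc/distributions/distribution.py.
DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable]
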
@@ -122,7 +124,14 @@ class Binomial(Discrete):
     rv_op = binomial
 
     @classmethod
-    def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
+    def dist(
+        cls,
+        n: DIST_PARAMETER_TYPES,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
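
As a usage illustration (my own example values, not part of the diff), the annotated signature still enforces that exactly one of p and logit_p is given:

import pymc as pm

x = pm.Binomial.dist(n=10, p=0.3)          # probability parametrization
y = pm.Binomial.dist(n=10, logit_p=-0.85)  # logit-probability parametrization
# pm.Binomial.dist(n=10, p=0.3, logit_p=-0.85) would raise the ValueError above
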
@@ -238,7 +247,14 @@ def BetaBinom(a, b, n, x):
     rv_op = betabinom
 
     @classmethod
-    def dist(cls, alpha, beta, n, *args, **kwargs):
+    def dist(
+        cls,
+        alpha: DIST_PARAMETER_TYPES,
+        beta: DIST_PARAMETER_TYPES,
+        n: DIST_PARAMETER_TYPES,
+        *args,
+        **kwargs,
+    ):
         alpha = pt.as_tensor_variable(floatX(alpha))
         beta = pt.as_tensor_variable(floatX(beta))
         n = pt.as_tensor_variable(intX(n))
@@ -344,7 +360,13 @@ class Bernoulli(Discrete):
     rv_op = bernoulli
 
     @classmethod
-    def dist(cls, p=None, logit_p=None, *args, **kwargs):
+    def dist(
+        cls,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
@@ -460,7 +482,7 @@ def DiscreteWeibull(q, b, x):
     rv_op = discrete_weibull
 
     @classmethod
-    def dist(cls, q, beta, *args, **kwargs):
+    def dist(cls, q: DIST_PARAMETER_TYPES, beta: DIST_PARAMETER_TYPES, *args, **kwargs):
         q = pt.as_tensor_variable(floatX(q))
         beta = pt.as_tensor_variable(floatX(beta))
         return super().dist([q, beta], **kwargs)
@@ -549,7 +571,7 @@ class Poisson(Discrete):
     rv_op = poisson
 
     @classmethod
-    def dist(cls, mu, *args, **kwargs):
+    def dist(cls, mu: DIST_PARAMETER_TYPES, *args, **kwargs):
         mu = pt.as_tensor_variable(floatX(mu))
         return super().dist([mu], *args, **kwargs)
 
@@ -671,7 +693,15 @@ def NegBinom(a, m, x):
     rv_op = nbinom
 
     @classmethod
-    def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
+    def dist(
+        cls,
+        mu: Optional[DIST_PARAMETER_TYPES] = None,
+        alpha: Optional[DIST_PARAMETER_TYPES] = None,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        n: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         n, p = cls.get_n_p(mu=mu, alpha=alpha, p=p, n=n)
         n = pt.as_tensor_variable(floatX(n))
         p = pt.as_tensor_variable(floatX(p))
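
For context, get_n_p reconciles the two accepted parametrizations. A sketch of equivalent calls (illustrative values, not from the diff):

import pymc as pm

a = pm.NegativeBinomial.dist(mu=5.0, alpha=2.0)   # mean/overdispersion form
b = pm.NegativeBinomial.dist(n=2.0, p=2.0 / 7.0)  # successes/probability form
# get_n_p maps (mu, alpha) to n = alpha and p = alpha / (mu + alpha),
# so the two calls above describe the same distribution.
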
@@ -784,7 +814,7 @@ class Geometric(Discrete):
     rv_op = geometric
 
     @classmethod
-    def dist(cls, p, *args, **kwargs):
+    def dist(cls, p: DIST_PARAMETER_TYPES, *args, **kwargs):
         p = pt.as_tensor_variable(floatX(p))
         return super().dist([p], *args, **kwargs)
 
@@ -881,7 +911,14 @@ class HyperGeometric(Discrete):
     rv_op = hypergeometric
 
     @classmethod
-    def dist(cls, N, k, n, *args, **kwargs):
+    def dist(
+        cls,
+        N: DIST_PARAMETER_TYPES,
+        k: DIST_PARAMETER_TYPES,
+        n: DIST_PARAMETER_TYPES,
+        *args,
+        **kwargs,
+    ):
         good = pt.as_tensor_variable(intX(k))
         bad = pt.as_tensor_variable(intX(N - k))
         n = pt.as_tensor_variable(intX(n))
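
A small illustration of the N/k/n semantics behind the good/bad split above (example values are mine):

import pymc as pm

# N = population size, k = successes in the population, n = draws.
# Internally the RV is parametrized as (good, bad, n) = (k, N - k, n).
h = pm.HyperGeometric.dist(N=50, k=10, n=5)
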
@@ -1018,7 +1055,7 @@ class DiscreteUniform(Discrete):
     rv_op = discrete_uniform
 
     @classmethod
-    def dist(cls, lower, upper, *args, **kwargs):
+    def dist(cls, lower: DIST_PARAMETER_TYPES, upper: DIST_PARAMETER_TYPES, *args, **kwargs):
         lower = intX(pt.floor(lower))
         upper = intX(pt.floor(upper))
         return super().dist([lower, upper], **kwargs)
@@ -1108,7 +1145,12 @@ class Categorical(Discrete):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, p=None, logit_p=None, **kwargs):
+    def dist(
+        cls,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        logit_p: Optional[DIST_PARAMETER_TYPES] = None,
+        **kwargs,
+    ):
         if p is not None and logit_p is not None:
             raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
         elif p is None and logit_p is None:
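
As with Bernoulli and Binomial, p and logit_p are mutually exclusive. A quick sketch (my own example values) of the two parametrizations:

import numpy as np
import pymc as pm

c1 = pm.Categorical.dist(p=[0.2, 0.3, 0.5])
# logit_p is unnormalized and softmax-ed internally, so log-probabilities
# recover the same distribution:
c2 = pm.Categorical.dist(logit_p=np.log([0.2, 0.3, 0.5]))
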
@@ -1210,7 +1252,7 @@ class DiracDelta(Discrete):
     rv_op = diracdelta
 
     @classmethod
-    def dist(cls, c, *args, **kwargs):
+    def dist(cls, c: DIST_PARAMETER_TYPES, *args, **kwargs):
         c = pt.as_tensor_variable(c)
         if c.dtype in continuous_types:
             c = floatX(c)
@@ -1328,7 +1370,7 @@ def __new__(cls, name, psi, mu, **kwargs):
         )
 
     @classmethod
-    def dist(cls, psi, mu, **kwargs):
+    def dist(cls, psi: DIST_PARAMETER_TYPES, mu: DIST_PARAMETER_TYPES, **kwargs):
         return _zero_inflated_mixture(
             name=None, nonzero_p=psi, nonzero_dist=Poisson.dist(mu=mu), **kwargs
         )
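
Since ZeroInflatedPoisson is built through _zero_inflated_mixture, psi is the weight of the Poisson component. A usage sketch (values are illustrative):

import pymc as pm

# psi = 0.8: an observation comes from Poisson(mu=3.0) with probability 0.8
# and is a structural zero with probability 0.2.
z = pm.ZeroInflatedPoisson.dist(psi=0.8, mu=3.0)
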
@@ -1393,7 +1435,9 @@ def __new__(cls, name, psi, n, p, **kwargs):
         )
 
     @classmethod
-    def dist(cls, psi, n, p, **kwargs):
+    def dist(
+        cls, psi: DIST_PARAMETER_TYPES, n: DIST_PARAMETER_TYPES, p: DIST_PARAMETER_TYPES, **kwargs
+    ):
         return _zero_inflated_mixture(
             name=None, nonzero_p=psi, nonzero_dist=Binomial.dist(n=n, p=p), **kwargs
         )
@@ -1490,7 +1534,15 @@ def __new__(cls, name, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
         )
 
     @classmethod
-    def dist(cls, psi, mu=None, alpha=None, p=None, n=None, **kwargs):
+    def dist(
+        cls,
+        psi: DIST_PARAMETER_TYPES,
+        mu: Optional[DIST_PARAMETER_TYPES] = None,
+        alpha: Optional[DIST_PARAMETER_TYPES] = None,
+        p: Optional[DIST_PARAMETER_TYPES] = None,
+        n: Optional[DIST_PARAMETER_TYPES] = None,
+        **kwargs,
+    ):
         return _zero_inflated_mixture(
             name=None,
             nonzero_p=psi,
@@ -1507,7 +1559,7 @@ class _OrderedLogistic(Categorical):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, eta, cutpoints, *args, **kwargs):
+    def dist(cls, eta: DIST_PARAMETER_TYPES, cutpoints: DIST_PARAMETER_TYPES, *args, **kwargs):
         eta = pt.as_tensor_variable(floatX(eta))
         cutpoints = pt.as_tensor_variable(cutpoints)
 
@@ -1613,7 +1665,14 @@ class _OrderedProbit(Categorical):
     rv_op = categorical
 
     @classmethod
-    def dist(cls, eta, cutpoints, sigma=1, *args, **kwargs):
+    def dist(
+        cls,
+        eta: DIST_PARAMETER_TYPES,
+        cutpoints: DIST_PARAMETER_TYPES,
+        sigma: DIST_PARAMETER_TYPES = 1,
+        *args,
+        **kwargs,
+    ):
         eta = pt.as_tensor_variable(floatX(eta))
         cutpoints = pt.as_tensor_variable(cutpoints)
 
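A usage sketch for the public wrapper (assuming pm.OrderedProbit forwards to _OrderedProbit.dist; the example values are mine):

import pymc as pm

# Three cutpoints split the latent probit scale into four ordered categories;
# sigma (default 1) scales the latent noise.
op = pm.OrderedProbit.dist(eta=0.0, cutpoints=[-1.0, 0.0, 1.0], sigma=1.0)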