Skip to content

Commit

Permalink
Update competence methods to work with RandomVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Feb 5, 2021
1 parent c6cfae9 commit 33b240d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 16 deletions.
3 changes: 2 additions & 1 deletion pymc3/step_methods/gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def astep(self, q, logp):

@staticmethod
def competence(var, has_grad):
if isinstance(var.distribution, Categorical):
dist = getattr(var.owner, "op", None)
if isinstance(dist, Categorical):
return Competence.COMPATIBLE
return Competence.INCOMPATIBLE

Expand Down
3 changes: 2 additions & 1 deletion pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def _hamiltonian_step(self, start, p0, step_size):
@staticmethod
def competence(var, has_grad):
"""Check how appropriate this class is for sampling a random variable."""
if var.dtype in continuous_types and has_grad and not isinstance(var.distribution, BART):
dist = getattr(var.owner, "op", None)
if var.dtype in continuous_types and has_grad and not isinstance(dist, BART):
return Competence.IDEAL
return Competence.INCOMPATIBLE

Expand Down
40 changes: 27 additions & 13 deletions pymc3/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import numpy.random as nr
import scipy.linalg
import theano
import theano.tensor as tt

from theano.tensor.random.basic import CategoricalRV

import pymc3 as pm

Expand Down Expand Up @@ -344,11 +347,14 @@ def competence(var):
BinaryMetropolis is only suitable for binary (bool)
and Categorical variables with k=1.
"""
distribution = getattr(var.distribution, "parent_dist", var.distribution)
distribution = getattr(var.owner, "op", None)
if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
return Competence.COMPATIBLE
elif isinstance(distribution, pm.Categorical) and (distribution.k == 2):
return Competence.COMPATIBLE
return Competence.IDEAL

if isinstance(distribution, CategoricalRV):
k = tt.get_scalar_constant_value(distribution.owner.inputs[2])
if k == 2:
return Competence.IDEAL
return Competence.INCOMPATIBLE


Expand Down Expand Up @@ -421,11 +427,14 @@ def competence(var):
BinaryMetropolis is only suitable for Bernoulli
and Categorical variables with k=2.
"""
distribution = getattr(var.distribution, "parent_dist", var.distribution)
distribution = getattr(var.owner, "op", None)
if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
return Competence.IDEAL
elif isinstance(distribution, pm.Categorical) and (distribution.k == 2):
return Competence.IDEAL

if isinstance(distribution, CategoricalRV):
k = tt.get_scalar_constant_value(distribution.owner.inputs[2])
if k == 2:
return Competence.IDEAL
return Competence.INCOMPATIBLE


Expand All @@ -451,8 +460,10 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
# variable with M categories and y being a 3-D variable with N
# categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
for v in vars:
distr = getattr(v.distribution, "parent_dist", v.distribution)
if isinstance(distr, pm.Categorical):

distr = getattr(v.owner, "op", None)

if isinstance(distr, CategoricalRV):
k = draw_values([distr.k])[0]
elif isinstance(distr, pm.Bernoulli) or (v.dtype in pm.bool_types):
k = 2
Expand Down Expand Up @@ -537,13 +548,16 @@ def competence(var):
CategoricalGibbsMetropolis is only suitable for Bernoulli and
Categorical variables.
"""
distribution = getattr(var.distribution, "parent_dist", var.distribution)
if isinstance(distribution, pm.Categorical):
if distribution.k > 2:
distribution = getattr(var.owner, "op", None)
if isinstance(distribution, CategoricalRV):
k = tt.get_scalar_constant_value(distribution.owner.inputs[2])
if k == 2:
return Competence.IDEAL
return Competence.COMPATIBLE
elif isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):

if isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
return Competence.COMPATIBLE

return Competence.INCOMPATIBLE


Expand Down
3 changes: 2 additions & 1 deletion pymc3/step_methods/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def competence(var, has_grad):
"""
PGBART is only suitable for BART distributions
"""
if isinstance(var.distribution, BART):
dist = getattr(var.owner, "op", None)
if isinstance(dist, BART):
return Competence.IDEAL
return Competence.INCOMPATIBLE

Expand Down

0 comments on commit 33b240d

Please sign in to comment.