Skip to content

Commit

Permalink
Added aeppl based log-likelihood graph generation and aeppl based tra…
Browse files Browse the repository at this point in the history
…nsforms

Co-authored-by: kc611 <[email protected]>
  • Loading branch information
ricardoV94 and kc611 committed Oct 21, 2021
1 parent e03f5bf commit 2baa8ce
Show file tree
Hide file tree
Showing 38 changed files with 564 additions and 1,095 deletions.
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-test-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-test-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools
Expand Down
1 change: 1 addition & 0 deletions conda-envs/windows-environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/windows-environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.2
- cachetools
Expand Down
2 changes: 1 addition & 1 deletion pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def transform_replacements(var, replacements):
# potential replacements
return [rv_value_var]

trans_rv_value = transform.backward(rv_var, rv_value_var)
trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
replacements[var] = trans_rv_value

# Walk the transformed variable and make replacements
Expand Down
23 changes: 23 additions & 0 deletions pymc/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import aesara.tensor as at
import numpy as np

from aeppl.logprob import _logprob
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from pandas import DataFrame, Series

Expand Down Expand Up @@ -146,6 +148,20 @@ def __new__(
def dist(cls, *params, **kwargs):
return super().dist(params, **kwargs)

def logp(x, *inputs):
"""Calculate log probability.
Parameters
----------
x: numeric, TensorVariable
Value for which log-probability is calculated.
Returns
-------
TensorVariable
"""
return at.zeros_like(x)


def preprocess_XY(X, Y):
if isinstance(Y, (Series, DataFrame)):
Expand All @@ -156,3 +172,10 @@ def preprocess_XY(X, Y):
Y = Y.astype(float)
X = X.astype(float)
return X, Y


@_logprob.register(BARTRV)
def logp(op, value_var, *dist_params, **kwargs):
_dist_params = dist_params[3:]
value_var = value_var[0]
return BART.logp(value_var, *_dist_params)
3 changes: 1 addition & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from pymc.distributions.logprob import ( # isort:skip
_logcdf,
_logp,
logcdf,
logp,
logcdfpt,
logp_transform,
logpt,
logpt_sum,
Expand Down Expand Up @@ -193,7 +193,6 @@
"PolyaGamma",
"logpt",
"logp",
"_logp",
"logp_transform",
"logcdf",
"_logcdf",
Expand Down
11 changes: 6 additions & 5 deletions pymc/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import numpy as np

from aeppl.logprob import logprob
from aesara.tensor import as_tensor_variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import floatX, intX
from pymc.distributions import _logp
from pymc.distributions.continuous import BoundedContinuous
from pymc.distributions.dist_math import bound
from pymc.distributions.distribution import Continuous, Discrete
Expand All @@ -46,7 +46,7 @@ def rng_fn(cls, rng, distribution, lower, upper, size):

class _ContinuousBounded(BoundedContinuous):
rv_op = boundrv
bound_args_indices = [1, 2]
bound_args_indices = [4, 5]

def logp(value, distribution, lower, upper):
"""
Expand All @@ -67,7 +67,7 @@ def logp(value, distribution, lower, upper):
-------
TensorVariable
"""
logp = _logp(distribution.owner.op, value, {}, *distribution.owner.inputs[3:])
logp = logprob(distribution, value)
return bound(logp, (value >= lower), (value <= upper))


Expand Down Expand Up @@ -107,7 +107,7 @@ def logp(value, distribution, lower, upper):
-------
TensorVariable
"""
logp = _logp(distribution.owner.op, value, {}, *distribution.owner.inputs[3:])
logp = logprob(distribution, value)
return bound(logp, (value >= lower), (value <= upper))


Expand Down Expand Up @@ -166,6 +166,7 @@ def __new__(
raise ValueError("Given dims do not exist in model coordinates.")

lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
distribution.tag.ignore_logprob = True

if isinstance(distribution.owner.op, Continuous):
res = _ContinuousBounded(
Expand Down Expand Up @@ -200,7 +201,7 @@ def dist(

cls._argument_checks(distribution, **kwargs)
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)

distribution.tag.ignore_logprob = True
if isinstance(distribution.owner.op, Continuous):
res = _ContinuousBounded.dist(
[distribution, lower, upper],
Expand Down
Loading

0 comments on commit 2baa8ce

Please sign in to comment.