Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration of aeppl with PyMC #4887

Merged
merged 1 commit into from
Oct 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disclaimer: I skimmed the discussion and have not reviewed the code

What does this dependency mean for the users? Will using aeppl functions/methods be necessary? Or will it be handled under the hood and advanced users will be able to use it if they want?

This is a minor-medium concern for the beta release, but aeppl seems to be very little documented. I was unable to find the link to it's docs anywhere. I also did check and saw there was a gh-pages branch so I went to https://aesara-devs.github.io/aeppl/ which has some api docs. We need to work on aeppl docs if we expect pymc users to work with it directly before releasing stable 4.0

Copy link
Member

@ricardoV94 ricardoV94 Oct 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users don't need to know about aeppl, it will be used by us developers to create more complex distributions.

Indeed we need documentation building on aeppl, we have a PR opened for that, but we welcome help. All our methods and functions are quite well documented on the other hand

- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-dev-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-test-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-test-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools
Expand Down
1 change: 1 addition & 0 deletions conda-envs/windows-environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.4
- cachetools>=4.2.1
Expand Down
1 change: 1 addition & 0 deletions conda-envs/windows-environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
- defaults
dependencies:
# base dependencies (see install guide for Windows)
- aeppl>=0.0.13
- aesara>=2.2.2
- arviz>=0.11.2
- cachetools
Expand Down
2 changes: 1 addition & 1 deletion pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def transform_replacements(var, replacements):
# potential replacements
return [rv_value_var]

trans_rv_value = transform.backward(rv_var, rv_value_var)
trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
replacements[var] = trans_rv_value

# Walk the transformed variable and make replacements
Expand Down
23 changes: 23 additions & 0 deletions pymc/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import aesara.tensor as at
import numpy as np

from aeppl.logprob import _logprob
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from pandas import DataFrame, Series

Expand Down Expand Up @@ -146,6 +148,20 @@ def __new__(
def dist(cls, *params, **kwargs):
return super().dist(params, **kwargs)

def logp(x, *inputs):
"""Calculate log probability.

Parameters
----------
x: numeric, TensorVariable
Value for which log-probability is calculated.

Returns
-------
TensorVariable
"""
return at.zeros_like(x)


def preprocess_XY(X, Y):
if isinstance(Y, (Series, DataFrame)):
Expand All @@ -156,3 +172,10 @@ def preprocess_XY(X, Y):
Y = Y.astype(float)
X = X.astype(float)
return X, Y


@_logprob.register(BARTRV)
def logp(op, value_var, *dist_params, **kwargs):
_dist_params = dist_params[3:]
value_var = value_var[0]
return BART.logp(value_var, *_dist_params)
3 changes: 1 addition & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from pymc.distributions.logprob import ( # isort:skip
_logcdf,
_logp,
logcdf,
logp,
logcdfpt,
logp_transform,
logpt,
logpt_sum,
Expand Down Expand Up @@ -193,7 +193,6 @@
"PolyaGamma",
"logpt",
"logp",
"_logp",
"logp_transform",
"logcdf",
"_logcdf",
Expand Down
11 changes: 6 additions & 5 deletions pymc/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import numpy as np

from aeppl.logprob import logprob
from aesara.tensor import as_tensor_variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import floatX, intX
from pymc.distributions import _logp
from pymc.distributions.continuous import BoundedContinuous
from pymc.distributions.dist_math import bound
from pymc.distributions.distribution import Continuous, Discrete
Expand All @@ -46,7 +46,7 @@ def rng_fn(cls, rng, distribution, lower, upper, size):

class _ContinuousBounded(BoundedContinuous):
rv_op = boundrv
bound_args_indices = [1, 2]
bound_args_indices = [4, 5]

def logp(value, distribution, lower, upper):
"""
Expand All @@ -67,7 +67,7 @@ def logp(value, distribution, lower, upper):
-------
TensorVariable
"""
logp = _logp(distribution.owner.op, value, {}, *distribution.owner.inputs[3:])
logp = logprob(distribution, value)
return bound(logp, (value >= lower), (value <= upper))


Expand Down Expand Up @@ -107,7 +107,7 @@ def logp(value, distribution, lower, upper):
-------
TensorVariable
"""
logp = _logp(distribution.owner.op, value, {}, *distribution.owner.inputs[3:])
logp = logprob(distribution, value)
return bound(logp, (value >= lower), (value <= upper))


Expand Down Expand Up @@ -166,6 +166,7 @@ def __new__(
raise ValueError("Given dims do not exist in model coordinates.")

lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
distribution.tag.ignore_logprob = True

if isinstance(distribution.owner.op, Continuous):
res = _ContinuousBounded(
Expand Down Expand Up @@ -200,7 +201,7 @@ def dist(

cls._argument_checks(distribution, **kwargs)
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)

distribution.tag.ignore_logprob = True
if isinstance(distribution.owner.op, Continuous):
res = _ContinuousBounded.dist(
[distribution, lower, upper],
Expand Down
Loading