Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CustomDist and Simulator no longer require class_name when creating a dist #6668

Merged
merged 2 commits into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 26 additions & 39 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,14 @@ class _CustomDist(Distribution):
def dist(
cls,
*dist_params,
class_name: str,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
random: Optional[Callable] = None,
moment: Optional[Callable] = None,
ndim_supp: int = 0,
ndims_params: Optional[Sequence[int]] = None,
dtype: str = "floatX",
class_name: str = "CustomDist",
**kwargs,
):
dist_params = [as_tensor_variable(param) for param in dist_params]
Expand Down Expand Up @@ -523,36 +523,36 @@ def dist(

return super().dist(
dist_params,
class_name=class_name,
logp=logp,
logcdf=logcdf,
random=random,
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
class_name=class_name,
**kwargs,
)

@classmethod
def rv_op(
cls,
*dist_params,
class_name: str,
logp: Optional[Callable],
logcdf: Optional[Callable],
random: Optional[Callable],
moment: Optional[Callable],
ndim_supp: int,
ndims_params: Optional[Sequence[int]],
dtype: str,
class_name: str,
**kwargs,
):
rv_type = type(
f"CustomDistRV_{class_name}",
class_name,
(CustomDistRV,),
dict(
name=f"CustomDist_{class_name}",
name=class_name,
inplace=False,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
Expand Down Expand Up @@ -613,20 +613,15 @@ class _CustomSymbolicDist(Distribution):
def dist(
cls,
*dist_params,
class_name: str,
dist: Callable,
logp: Optional[Callable] = None,
logcdf: Optional[Callable] = None,
moment: Optional[Callable] = None,
ndim_supp: int = 0,
dtype: str = "floatX",
class_name: str = "CustomSymbolicDist",
**kwargs,
):
warnings.warn(
"CustomDist with dist function is still experimental. Expect bugs!",
UserWarning,
)

dist_params = [as_tensor_variable(param) for param in dist_params]

if logcdf is None:
Expand Down Expand Up @@ -655,13 +650,13 @@ def dist(
def rv_op(
cls,
*dist_params,
class_name: str,
dist: Callable,
logp: Optional[Callable],
logcdf: Optional[Callable],
moment: Optional[Callable],
size=None,
ndim_supp: int,
class_name: str,
):
size = normalize_size_param(size)
dummy_size_param = size.type()
Expand All @@ -674,7 +669,7 @@ def rv_op(
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))

rv_type = type(
f"CustomSymbolicDistRV_{class_name}",
class_name,
(CustomSymbolicDistRV,),
# If logp is not provided, we try to infer it from the dist graph
dict(
Expand Down Expand Up @@ -758,15 +753,6 @@ class CustomDist:
dist_params : Tuple
A sequence of the distribution's parameter. These will be converted into
Pytensor tensor variables internally.
class_name : str
Name for the class which will wrap the CustomDist methods. When not specified,
it will be given the name of the model variable.

.. warning:: New CustomDists created with the same class_name will override the
methods dispatched onto the previous classes. If using CustomDists with
different methods across separate models, be sure to use distinct
class_names.

dist: Optional[Callable]
A callable that returns a PyTensor graph built from simpler PyMC distributions
which represents the distribution. This can be used by PyMC to take random draws
Expand Down Expand Up @@ -831,6 +817,9 @@ class CustomDist:
The dtype of the distribution. All draws and observations passed into the
distribution will be cast onto this dtype. This is not needed if an PyTensor
dist function is provided, which should already return the right dtype!
class_name : str
Name for the class which will wrap the CustomDist methods. When not specified,
it will be given the name of the model variable.
kwargs :
Extra keyword arguments are passed to the parent's class ``__new__`` method.

Expand Down Expand Up @@ -979,36 +968,36 @@ def __new__(
dist_params = cls.parse_dist_params(dist_params)
cls.check_valid_dist_random(dist, random, dist_params)
if dist is not None:
kwargs.setdefault("class_name", f"CustomSymbolicDist_{name}")
return _CustomSymbolicDist(
name,
*dist_params,
class_name=name,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, was old behavior to use the RV name as the class name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were doing "CustomDist_{class_name}", so yes, if a user created one in a model context it was CustomDIst_mu` for example. Pretty meh

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you have a concern in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No concern, just checking my understanding

dist=dist,
logp=logp,
logcdf=logcdf,
moment=moment,
ndim_supp=ndim_supp,
**kwargs,
)
return _CustomDist(
name,
*dist_params,
class_name=name,
random=random,
logp=logp,
logcdf=logcdf,
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
**kwargs,
)
else:
kwargs.setdefault("class_name", f"CustomDist_{name}")
return _CustomDist(
name,
*dist_params,
random=random,
logp=logp,
logcdf=logcdf,
moment=moment,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
**kwargs,
)

@classmethod
def dist(
cls,
*dist_params,
class_name: str,
dist: Optional[Callable] = None,
random: Optional[Callable] = None,
logp: Optional[Callable] = None,
Expand All @@ -1024,7 +1013,6 @@ def dist(
if dist is not None:
return _CustomSymbolicDist.dist(
*dist_params,
class_name=class_name,
dist=dist,
logp=logp,
logcdf=logcdf,
Expand All @@ -1035,7 +1023,6 @@ def dist(
else:
return _CustomDist.dist(
*dist_params,
class_name=class_name,
random=random,
logp=logp,
logcdf=logcdf,
Expand Down
23 changes: 8 additions & 15 deletions pymc/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,6 @@ class Simulator(Distribution):
Keyword form of ''unnamed_params''.
One of unnamed_params or params must be provided.
If passed both unnamed_params and params, an error is raised.
class_name : str
Name for the RandomVariable class which will wrap the Simulator methods.
When not specified, it will be given the name of the variable.

.. warning:: New Simulators created with the same class_name will override the
methods dispatched onto the previous classes. If using Simulators with
different methods across separate models, be sure to use distinct
class_names.

distance : PyTensor_Op, callable or str, default "gaussian"
Distance function. Available options are ``"gaussian"``, ``"laplace"``,
``"kullback_leibler"`` or a user defined function (or PyTensor_Op) that takes
Expand Down Expand Up @@ -123,6 +114,8 @@ class Simulator(Distribution):
Number of minimum dimensions of each parameter of the RV. For example,
if the Simulator accepts two scalar inputs, it should be ``[0, 0]``.
Default to list of 0 with length equal to the number of parameters.
class_name : str, optional
Suffix name for the RandomVariable class which will wrap the Simulator methods.

Examples
--------
Expand All @@ -149,7 +142,7 @@ def simulator_fn(rng, loc, scale, size):
rv_type = SimulatorRV

def __new__(cls, name, *args, **kwargs):
kwargs.setdefault("class_name", name)
kwargs.setdefault("class_name", f"Simulator_{name}")
return super().__new__(cls, name, *args, **kwargs)

@classmethod
Expand All @@ -158,13 +151,13 @@ def dist( # type: ignore
fn,
*unnamed_params,
params=None,
class_name: str,
distance="gaussian",
sum_stat="identity",
epsilon=1,
ndim_supp=0,
ndims_params=None,
dtype="floatX",
class_name: str = "Simulator",
**kwargs,
):
if not isinstance(distance, Op):
Expand Down Expand Up @@ -213,36 +206,36 @@ def dist( # type: ignore

return super().dist(
params,
class_name=class_name,
fn=fn,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
distance=distance,
sum_stat=sum_stat,
epsilon=epsilon,
class_name=class_name,
**kwargs,
)

@classmethod
def rv_op(
cls,
*params,
class_name,
fn,
ndim_supp,
ndims_params,
dtype,
distance,
sum_stat,
epsilon,
class_name,
**kwargs,
):
sim_op = type(
f"Simulator_{class_name}",
class_name,
(SimulatorRV,),
dict(
name=f"Simulator_{class_name}",
name=class_name,
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
Expand Down
Loading