Skip to content

Commit 5fc8fd3

Browse files
wd60622radiokosmos
authored andcommitted
Register and allow custom transform for Prior class (pymc-labs#972)
* allow register and use custom transform * add to the example block
1 parent 20b5b6f commit 5fc8fd3

File tree

2 files changed

+101
-3
lines changed

2 files changed

+101
-3
lines changed

pymc_marketing/prior.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,20 @@
7878
dims="channel",
7979
)
8080
81+
Create a prior with a custom transform function by registering it with
82+
`register_tensor_transform`.
83+
84+
.. code-block:: python
85+
86+
from pymc_marketing.prior import register_tensor_transform
87+
88+
def custom_transform(x):
89+
return x ** 2
90+
91+
register_tensor_transform("square", custom_transform)
92+
93+
custom_distribution = Prior("Normal", transform="square")
94+
8195
"""
8296

8397
from __future__ import annotations
@@ -198,18 +212,63 @@ def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
198212
return getattr(pm, name)
199213

200214

215+
Transform = Callable[[pt.TensorLike], pt.TensorLike]
216+
217+
CUSTOM_TRANSFORMS: dict[str, Transform] = {}
218+
219+
220+
def register_tensor_transform(name: str, transform: Transform) -> None:
221+
"""Register a tensor transform function to be used in the `Prior` class.
222+
223+
Parameters
224+
----------
225+
name : str
226+
The name of the transform.
227+
func : Callable[[pt.TensorLike], pt.TensorLike]
228+
The function to apply to the tensor.
229+
230+
Examples
231+
--------
232+
Register a custom transform function.
233+
234+
.. code-block:: python
235+
236+
from pymc_marketing.prior import (
237+
Prior,
238+
register_tensor_transform,
239+
)
240+
241+
def custom_transform(x):
242+
return x ** 2
243+
244+
register_tensor_transform("square", custom_transform)
245+
246+
custom_distribution = Prior("Normal", transform="square")
247+
248+
"""
249+
CUSTOM_TRANSFORMS[name] = transform
250+
251+
201252
def _get_transform(name: str):
253+
if name in CUSTOM_TRANSFORMS:
254+
return CUSTOM_TRANSFORMS[name]
255+
202256
for module in (pt, pm.math):
203257
if hasattr(module, name):
204258
break
205259
else:
206260
module = None
207261

208262
if not module:
209-
raise UnknownTransformError(
210-
f"Neither PyTensor or pm.math have the function {name!r}"
263+
msg = (
264+
f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
265+
"If this is a custom function, register it with the "
266+
"`pymc_marketing.prior.register_tensor_transform` function before "
267+
"previous function call."
211268
)
212269

270+
raise UnknownTransformError(msg)
271+
213272
return getattr(module, name)
214273

215274

@@ -243,6 +302,7 @@ class Prior:
243302
transform : str, optional
244303
The name of the transform to apply to the variable after it is
245304
created, by default None or no transform. The transformation must
305+
be registered with `register_tensor_transform` function or
246306
be available in either `pytensor.tensor` or `pymc.math`.
247307
248308
"""

tests/test_prior.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
UnsupportedParameterizationError,
3232
UnsupportedShapeError,
3333
handle_dims,
34+
register_tensor_transform,
3435
)
3536

3637

@@ -72,7 +73,8 @@ def test_handle_dims(x, dims, desired_dims, expected_fn) -> None:
7273

7374

7475
def test_missing_transform() -> None:
75-
with pytest.raises(UnknownTransformError):
76+
match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'"
77+
with pytest.raises(UnknownTransformError, match=match):
7678
Prior("Normal", transform="foo_bar")
7779

7880

@@ -608,3 +610,39 @@ def test_checks_param_value_types() -> None:
608610
def test_check_equality_with_numpy() -> None:
609611
dist = Prior("Normal", mu=np.array([1, 2, 3]), sigma=1)
610612
assert dist == dist.deepcopy()
613+
614+
615+
def clear_custom_transforms() -> None:
616+
global CUSTOM_TRANSFORMS
617+
CUSTOM_TRANSFORMS = {}
618+
619+
620+
def test_custom_transform() -> None:
621+
new_transform_name = "foo_bar"
622+
with pytest.raises(UnknownTransformError):
623+
Prior("Normal", transform=new_transform_name)
624+
625+
register_tensor_transform(new_transform_name, lambda x: x**2)
626+
627+
dist = Prior("Normal", transform=new_transform_name)
628+
prior = dist.sample_prior(samples=10)
629+
df_prior = prior.to_dataframe()
630+
631+
np.testing.assert_array_equal(
632+
df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2
633+
)
634+
635+
636+
def test_custom_transform_comes_first() -> None:
637+
# function in pytensor.tensor
638+
register_tensor_transform("square", lambda x: 2 * x)
639+
640+
dist = Prior("Normal", transform="square")
641+
prior = dist.sample_prior(samples=10)
642+
df_prior = prior.to_dataframe()
643+
644+
np.testing.assert_array_equal(
645+
df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy()
646+
)
647+
648+
clear_custom_transforms()

0 commit comments

Comments
 (0)