|
78 | 78 | dims="channel",
|
79 | 79 | )
|
80 | 80 |
|
| 81 | +Create a prior with a custom transform function by registering it with |
| 82 | +`register_tensor_transform`. |
| 83 | +
|
| 84 | +.. code-block:: python |
| 85 | +
|
| 86 | + from pymc_marketing.prior import register_tensor_transform |
| 87 | +
|
| 88 | + def custom_transform(x): |
| 89 | + return x ** 2 |
| 90 | +
|
| 91 | + register_tensor_transform("square", custom_transform) |
| 92 | +
|
| 93 | + custom_distribution = Prior("Normal", transform="square") |
| 94 | +
|
81 | 95 | """
|
82 | 96 |
|
83 | 97 | from __future__ import annotations
|
@@ -198,18 +212,63 @@ def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
|
198 | 212 | return getattr(pm, name)
|
199 | 213 |
|
200 | 214 |
|
| 215 | +Transform = Callable[[pt.TensorLike], pt.TensorLike] |
| 216 | + |
| 217 | +CUSTOM_TRANSFORMS: dict[str, Transform] = {} |
| 218 | + |
| 219 | + |
| 220 | +def register_tensor_transform(name: str, transform: Transform) -> None: |
| 221 | + """Register a tensor transform function to be used in the `Prior` class. |
| 222 | +
|
| 223 | + Parameters |
| 224 | + ---------- |
| 225 | + name : str |
| 226 | + The name of the transform. |
| 227 | + func : Callable[[pt.TensorLike], pt.TensorLike] |
| 228 | + The function to apply to the tensor. |
| 229 | +
|
| 230 | + Examples |
| 231 | + -------- |
| 232 | + Register a custom transform function. |
| 233 | +
|
| 234 | + .. code-block:: python |
| 235 | +
|
| 236 | + from pymc_marketing.prior import ( |
| 237 | + Prior, |
| 238 | + register_tensor_transform, |
| 239 | + ) |
| 240 | +
|
| 241 | + def custom_transform(x): |
| 242 | + return x ** 2 |
| 243 | +
|
| 244 | + register_tensor_transform("square", custom_transform) |
| 245 | +
|
| 246 | + custom_distribution = Prior("Normal", transform="square") |
| 247 | +
|
| 248 | + """ |
| 249 | + CUSTOM_TRANSFORMS[name] = transform |
| 250 | + |
| 251 | + |
201 | 252 | def _get_transform(name: str):
|
| 253 | + if name in CUSTOM_TRANSFORMS: |
| 254 | + return CUSTOM_TRANSFORMS[name] |
| 255 | + |
202 | 256 | for module in (pt, pm.math):
|
203 | 257 | if hasattr(module, name):
|
204 | 258 | break
|
205 | 259 | else:
|
206 | 260 | module = None
|
207 | 261 |
|
208 | 262 | if not module:
|
209 |
| - raise UnknownTransformError( |
210 |
| - f"Neither PyTensor or pm.math have the function {name!r}" |
| 263 | + msg = ( |
| 264 | + f"Neither pytensor.tensor nor pymc.math have the function {name!r}. " |
| 265 | + "If this is a custom function, register it with the " |
| 266 | + "`pymc_marketing.prior.register_tensor_transform` function before " |
| 267 | + "previous function call." |
211 | 268 | )
|
212 | 269 |
|
| 270 | + raise UnknownTransformError(msg) |
| 271 | + |
213 | 272 | return getattr(module, name)
|
214 | 273 |
|
215 | 274 |
|
@@ -243,6 +302,7 @@ class Prior:
|
243 | 302 | transform : str, optional
|
244 | 303 | The name of the transform to apply to the variable after it is
|
245 | 304 | created, by default None or no transform. The transformation must
|
| 305 | + be registered with `register_tensor_transform` function or |
246 | 306 | be available in either `pytensor.tensor` or `pymc.math`.
|
247 | 307 |
|
248 | 308 | """
|
|
0 commit comments