Skip to content

Commit f1ffb50

Browse files
wd60622 and pre-commit-ci[bot]
authored and committed
Save off media transformations (pymc-labs#882)
* to_dict via lookup_name
* parse to and from dict for attrs
* improve the codecov
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* change test with change in default behavior
* increase the MMM model version

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 634c068 commit f1ffb50

File tree

8 files changed

+305
-16
lines changed

8 files changed

+305
-16
lines changed

pymc_marketing/mmm/components/adstock.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from pymc_marketing.prior import Prior
3030
3131
class MyAdstock(AdstockTransformation):
32+
lookup_name: str = "my_adstock"
33+
3234
def function(self, x, alpha):
3335
return x * alpha
3436
@@ -92,7 +94,7 @@ def __init__(
9294
True, description="Whether to normalize the adstock values."
9395
),
9496
mode: ConvMode = Field(ConvMode.After, description="Convolution mode."),
95-
priors: dict[str, str | InstanceOf[Prior]] | None = Field(
97+
priors: dict[str, InstanceOf[Prior]] | None = Field(
9698
default=None, description="Priors for the parameters."
9799
),
98100
prefix: str | None = Field(None, description="Prefix for the parameters."),
@@ -103,6 +105,27 @@ def __init__(
103105

104106
super().__init__(priors=priors, prefix=prefix)
105107

108+
def __repr__(self) -> str:
    """Return a constructor-like summary of the adstock transformation.

    The summary lists, in order, the prefix, ``l_max``, ``normalize``, the
    name of the convolution ``mode`` and the function priors.
    """
    cls_name = self.__class__.__name__
    settings = ", ".join(
        [
            f"prefix={self.prefix!r}",
            f"l_max={self.l_max}",
            f"normalize={self.normalize}",
            f"mode={self.mode.name!r}",
            f"priors={self.function_priors}",
        ]
    )
    return f"{cls_name}({settings})"
118+
119+
def to_dict(self) -> dict:
    """Convert the adstock transformation to a dictionary.

    Starts from the parent class dictionary and adds the adstock-specific
    settings: ``l_max``, ``normalize`` and the name of the convolution
    ``mode``.

    Returns
    -------
    dict
        The dictionary defining the adstock transformation.

    """
    payload = super().to_dict()
    payload.update(
        l_max=self.l_max,
        normalize=self.normalize,
        mode=self.mode.name,
    )
    return payload
128+
106129
def sample_curve(
107130
self,
108131
parameters: xr.Dataset,
@@ -371,6 +394,21 @@ def function(self, x, lam, k):
371394
}
372395

373396

397+
def register_adstock_transformation(cls: type[AdstockTransformation]) -> None:
    """Register a new adstock transformation.

    The class is stored in ``ADSTOCK_TRANSFORMATIONS`` under its
    ``lookup_name``; an entry already stored under that name is replaced.

    Parameters
    ----------
    cls : type[AdstockTransformation]
        The adstock transformation class to register.

    """
    registry_key = cls.lookup_name
    ADSTOCK_TRANSFORMATIONS[registry_key] = cls
400+
401+
402+
def adstock_from_dict(data: dict) -> AdstockTransformation:
    """Create an adstock transformation from a dictionary.

    The input is shallow-copied, so the caller's dictionary is not modified.
    ``lookup_name`` selects the class from ``ADSTOCK_TRANSFORMATIONS``; every
    other entry is forwarded to that class as a keyword argument.

    Parameters
    ----------
    data : dict
        Dictionary describing the transformation. Must contain
        ``lookup_name``. If ``priors`` is present and not ``None``, each of
        its values is converted with ``Prior.from_json``.

    Returns
    -------
    AdstockTransformation
        Instance of the class registered under ``lookup_name``.

    Raises
    ------
    KeyError
        If ``lookup_name`` is missing from ``data`` or is not a key of
        ``ADSTOCK_TRANSFORMATIONS``.

    """
    kwargs = data.copy()
    cls = ADSTOCK_TRANSFORMATIONS[kwargs.pop("lookup_name")]

    # ``priors`` is optional on the constructor, so only deserialize it when
    # it was actually provided instead of failing with a KeyError.
    serialized_priors = kwargs.get("priors")
    if serialized_priors is not None:
        kwargs["priors"] = {
            name: Prior.from_json(spec) for name, spec in serialized_priors.items()
        }
    return cls(**kwargs)
410+
411+
374412
def _get_adstock_function(
375413
function: str | AdstockTransformation,
376414
**kwargs,

pymc_marketing/mmm/components/base.py

+40-4
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,10 @@ class Transformation:
9797
9898
Parameters
9999
----------
100-
priors : dict, optional
100+
priors : dict[str, Prior], optional
101101
Dictionary with the priors for the parameters of the function. The keys should be the
102-
parameter names and the values should be dictionaries with the distribution and kwargs.
102+
parameter names and the values the priors. If not provided, it will use the default
103+
priors from the subclass.
103104
prefix : str, optional
104105
The prefix for the variables that will be created. If not provided, it will use the prefix
105106
from the subclass.
@@ -112,12 +113,43 @@ class Transformation:
112113
lookup_name: str
113114

114115
def __init__(
115-
self, priors: dict[str, Any | Prior] | None = None, prefix: str | None = None
116+
self, priors: dict[str, Prior] | None = None, prefix: str | None = None
116117
) -> None:
117118
self._checks()
118119
self.function_priors = priors # type: ignore
119120
self.prefix = prefix or self.prefix
120121

122+
def __repr__(self) -> str:
    """Return a constructor-like summary showing the prefix and the priors."""
    cls_name = self.__class__.__name__
    settings = ", ".join(
        [
            f"prefix={self.prefix!r}",
            f"priors={self.function_priors}",
        ]
    )
    return f"{cls_name}({settings})"
129+
130+
def to_dict(self) -> dict[str, Any]:
    """Convert the transformation to a dictionary.

    Each prior in ``function_priors`` is serialized with its ``to_json``
    method and stored under the parameter name.

    Returns
    -------
    dict
        The dictionary defining the transformation, with the keys
        ``lookup_name``, ``prefix`` and ``priors``.

    """
    payload: dict[str, Any] = {
        "lookup_name": self.lookup_name,
        "prefix": self.prefix,
    }

    serialized_priors = {}
    for name, prior in self.function_priors.items():
        serialized_priors[name] = prior.to_json()
    payload["priors"] = serialized_priors

    return payload
146+
147+
def __eq__(self, other: Any) -> bool:
    """Compare two transformations by their dictionary representation.

    Parameters
    ----------
    other : Any
        The object to compare against.

    Returns
    -------
    bool
        Whether ``self.to_dict()`` equals ``other.to_dict()`` when ``other``
        is an instance of this object's class. For any other object
        ``NotImplemented`` is returned so that Python can try the reflected
        comparison; ``==`` still evaluates to ``False`` for unrelated types.

    """
    # NOTE(review): a class that defines __eq__ without __hash__ is
    # unhashable; confirm that is intended for transformations.
    if not isinstance(other, self.__class__):
        # Returning NotImplemented (rather than False) keeps the comparison
        # symmetric when ``other`` is an instance of a parent class.
        return NotImplemented

    return self.to_dict() == other.to_dict()
152+
121153
@property
122154
def function_priors(self) -> dict[str, Prior]:
123155
return self._function_priors
@@ -137,7 +169,7 @@ def update_priors(self, priors: dict[str, Prior]) -> None:
137169
138170
Parameters
139171
----------
140-
priors : dict
172+
priors : dict[str, Prior]
141173
Dictionary with the new priors for the parameters of the function.
142174
143175
Examples
@@ -150,6 +182,7 @@ def update_priors(self, priors: dict[str, Prior]) -> None:
150182
from pymc_marketing.prior import Prior
151183
152184
class MyTransformation(Transformation):
185+
lookup_name: str = "my_transformation"
153186
prefix: str = "transformation"
154187
function = lambda x, lam: x * lam
155188
default_priors = {"lam": Prior("Gamma", alpha=3, beta=1)}
@@ -200,6 +233,9 @@ def _has_all_attributes(self) -> None:
200233
if not hasattr(self, "function"):
201234
raise NotImplementedError("function must be implemented in the subclass")
202235

236+
if not hasattr(self, "lookup_name"):
237+
raise NotImplementedError("lookup_name must be implemented in the subclass")
238+
203239
def _has_defaults_for_all_arguments(self) -> None:
204240
function_signature = signature(self.function)
205241

pymc_marketing/mmm/components/saturation.py

+18
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from pymc_marketing.prior import Prior
2929
3030
class InfiniteReturns(SaturationTransformation):
31+
lookup_name: str = "infinite_returns"
32+
3133
def function(self, x, b):
3234
return b * x
3335
@@ -109,6 +111,7 @@ def infinite_returns(x, b):
109111
return b * x
110112
111113
class InfiniteReturns(SaturationTransformation):
114+
lookup_name = "infinite_returns"
112115
function = infinite_returns
113116
default_priors = {"b": Prior("HalfNormal")}
114117
@@ -417,6 +420,21 @@ def function(self, x, alpha, beta):
417420
}
418421

419422

423+
def register_saturation_transformation(cls: type[SaturationTransformation]) -> None:
    """Register a new saturation transformation.

    The class is stored in ``SATURATION_TRANSFORMATIONS`` under its
    ``lookup_name``; an entry already stored under that name is replaced.

    Parameters
    ----------
    cls : type[SaturationTransformation]
        The saturation transformation class to register.

    """
    registry_key = cls.lookup_name
    SATURATION_TRANSFORMATIONS[registry_key] = cls
426+
427+
428+
def saturation_from_dict(data: dict) -> SaturationTransformation:
    """Create a saturation transformation from a dictionary.

    The input is shallow-copied, so the caller's dictionary is not modified.
    ``lookup_name`` selects the class from ``SATURATION_TRANSFORMATIONS``;
    every other entry is forwarded to that class as a keyword argument.

    Parameters
    ----------
    data : dict
        Dictionary describing the transformation. Must contain
        ``lookup_name``. If ``priors`` is present and not ``None``, each of
        its values is converted with ``Prior.from_json``.

    Returns
    -------
    SaturationTransformation
        Instance of the class registered under ``lookup_name``.

    Raises
    ------
    KeyError
        If ``lookup_name`` is missing from ``data`` or is not a key of
        ``SATURATION_TRANSFORMATIONS``.

    """
    kwargs = data.copy()
    cls = SATURATION_TRANSFORMATIONS[kwargs.pop("lookup_name")]

    # ``priors`` is optional on the constructor, so only deserialize it when
    # it was actually provided instead of failing with a KeyError.
    serialized_priors = kwargs.get("priors")
    if serialized_priors is not None:
        kwargs["priors"] = {
            name: Prior.from_json(spec) for name, spec in serialized_priors.items()
        }
    return cls(**kwargs)
436+
437+
420438
def _get_saturation_function(
421439
function: str | SaturationTransformation,
422440
) -> SaturationTransformation:

pymc_marketing/mmm/delayed_saturated_mmm.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@
3535
AdstockTransformation,
3636
GeometricAdstock,
3737
_get_adstock_function,
38+
adstock_from_dict,
3839
)
3940
from pymc_marketing.mmm.components.saturation import (
4041
LogisticSaturation,
4142
SaturationTransformation,
4243
_get_saturation_function,
44+
saturation_from_dict,
4345
)
4446
from pymc_marketing.mmm.fourier import YearlyFourier
4547
from pymc_marketing.mmm.lift_test import (
@@ -299,8 +301,8 @@ def _generate_and_preprocess_model_data( # type: ignore
299301
def create_idata_attrs(self) -> dict[str, str]:
300302
attrs = super().create_idata_attrs()
301303
attrs["date_column"] = json.dumps(self.date_column)
302-
attrs["adstock"] = json.dumps(self.adstock.lookup_name)
303-
attrs["saturation"] = json.dumps(self.saturation.lookup_name)
304+
attrs["adstock"] = json.dumps(self.adstock.to_dict())
305+
attrs["saturation"] = json.dumps(self.saturation.to_dict())
304306
attrs["adstock_first"] = json.dumps(self.adstock_first)
305307
attrs["control_columns"] = json.dumps(self.control_columns)
306308
attrs["channel_columns"] = json.dumps(self.channel_columns)
@@ -632,8 +634,8 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
632634
"control_columns": json.loads(attrs["control_columns"]),
633635
"channel_columns": json.loads(attrs["channel_columns"]),
634636
"adstock_max_lag": json.loads(attrs["adstock_max_lag"]),
635-
"adstock": json.loads(attrs.get("adstock", '"geometric"')),
636-
"saturation": json.loads(attrs.get("saturation", '"logistic"')),
637+
"adstock": adstock_from_dict(json.loads(attrs["adstock"])),
638+
"saturation": saturation_from_dict(json.loads(attrs["saturation"])),
637639
"adstock_first": json.loads(attrs.get("adstock_first", "true")),
638640
"yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
639641
"time_varying_intercept": json.loads(
@@ -896,7 +898,7 @@ class MMM(
896898
""" # noqa: E501
897899

898900
_model_type: str = "MMM"
899-
version: str = "0.0.1"
901+
version: str = "0.0.2"
900902

901903
def channel_contributions_forward_pass(
902904
self, channel_data: npt.NDArray[np.float64]

tests/mmm/components/test_adstock.py

+67
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,19 @@
2121
from pydantic import ValidationError
2222

2323
from pymc_marketing.mmm.components.adstock import (
24+
ADSTOCK_TRANSFORMATIONS,
2425
AdstockTransformation,
2526
DelayedAdstock,
2627
GeometricAdstock,
2728
WeibullAdstock,
2829
WeibullCDFAdstock,
2930
WeibullPDFAdstock,
3031
_get_adstock_function,
32+
adstock_from_dict,
33+
register_adstock_transformation,
3134
)
35+
from pymc_marketing.mmm.transformers import ConvMode
36+
from pymc_marketing.prior import Prior
3237

3338

3439
def adstocks() -> list[AdstockTransformation]:
@@ -141,3 +146,65 @@ def test_adstock_sample_curve(adstock) -> None:
141146
assert isinstance(curve, xr.DataArray)
142147
assert curve.name == "adstock"
143148
assert curve.shape == (1, 500, adstock.l_max)
149+
150+
151+
def test_adstock_from_dict() -> None:
    """A serialized geometric adstock is rebuilt with the same settings."""
    alpha_spec = {
        "dist": "Beta",
        "kwargs": {
            "alpha": 1,
            "beta": 2,
        },
    }
    serialized = {
        "lookup_name": "geometric",
        "l_max": 10,
        "prefix": "test",
        "mode": "Before",
        "priors": {"alpha": alpha_spec},
    }

    rebuilt = adstock_from_dict(serialized)

    expected = GeometricAdstock(
        l_max=10,
        prefix="test",
        priors={"alpha": Prior("Beta", alpha=1, beta=2)},
        mode=ConvMode.Before,
    )
    assert rebuilt == expected
177+
178+
179+
def test_register_adstock_transformation() -> None:
    """A registered custom adstock can be rebuilt through ``adstock_from_dict``."""

    class NewTransformation(AdstockTransformation):
        lookup_name: str = "new_transformation"
        default_priors = {}

        def function(self, x):
            return x

    register_adstock_transformation(NewTransformation)
    try:
        assert "new_transformation" in ADSTOCK_TRANSFORMATIONS

        data = {
            "lookup_name": "new_transformation",
            "l_max": 10,
            "normalize": False,
            "mode": "Before",
            "priors": {},
        }
        adstock = adstock_from_dict(data)
        assert adstock == NewTransformation(
            l_max=10, mode=ConvMode.Before, normalize=False, priors={}
        )
    finally:
        # The registry is module-level state; remove the test-only entry so
        # it does not leak into other tests that read the registry.
        ADSTOCK_TRANSFORMATIONS.pop("new_transformation", None)
201+
202+
203+
def test_repr() -> None:
    """The repr of a default geometric adstock lists every setting."""
    expected = (
        "GeometricAdstock(prefix='adstock', l_max=10, "
        "normalize=True, "
        "mode='After', "
        "priors={'alpha': Prior(\"Beta\", alpha=1, beta=3)}"
        ")"
    )
    adstock = GeometricAdstock(l_max=10)
    assert repr(adstock) == expected

0 commit comments

Comments
 (0)