Skip to content

Commit 255eac1

Browse files
authored
deprecate WeibullAdstock in favor of WeibullCDFAdstock and WeibullPDFAdstock (#957)
* deprecate in favor of WeibullCDFAdstock and WeibullPDFAdstock * Update UML Diagrams * Update UML Diagrams
1 parent 39d38b7 commit 255eac1

File tree

5 files changed

+2
-76
lines changed

5 files changed

+2
-76
lines changed

docs/source/uml/classes_mmm.png

-25.6 KB
Loading

pymc_marketing/mmm/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
AdstockTransformation,
2020
DelayedAdstock,
2121
GeometricAdstock,
22-
WeibullAdstock,
2322
WeibullCDFAdstock,
2423
WeibullPDFAdstock,
2524
adstock_from_dict,
@@ -66,7 +65,6 @@
6665
"TanhSaturationBaselined",
6766
"saturation_from_dict",
6867
"register_saturation_transformation",
69-
"WeibullAdstock",
7068
"WeibullCDFAdstock",
7169
"WeibullPDFAdstock",
7270
"adstock_from_dict",

pymc_marketing/mmm/components/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pymc_marketing.mmm import (
2424
SaturationTransformation,
2525
MMM,
26-
WeibullAdstock,
26+
WeibullPDFAdstock,
2727
)
2828
2929
class InfiniteReturns(SaturationTransformation):
@@ -34,7 +34,7 @@ def function(self, x, b):
3434
3535
3636
saturation = InfiniteReturns()
37-
adstock = WeibullAdstock(l_max=15, kind="PDF")
37+
adstock = WeibullPDFAdstock(l_max=15)
3838
3939
mmm = MMM(
4040
...,

pymc_marketing/mmm/components/adstock.py

-68
Original file line numberDiff line numberDiff line change
@@ -320,79 +320,11 @@ def function(self, x, lam, k):
320320
}
321321

322322

323-
class WeibullAdstock(AdstockTransformation):
324-
"""Wrapper around weibull adstock function.
325-
326-
For more information, see :func:`pymc_marketing.mmm.transformers.weibull_adstock`.
327-
328-
.. plot::
329-
:context: close-figs
330-
331-
import matplotlib.pyplot as plt
332-
import numpy as np
333-
from pymc_marketing.mmm import WeibullAdstock
334-
335-
rng = np.random.default_rng(0)
336-
337-
adstock = WeibullAdstock(l_max=10, kind="CDF")
338-
prior = adstock.sample_prior(random_seed=rng)
339-
curve = adstock.sample_curve(prior)
340-
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
341-
plt.show()
342-
343-
"""
344-
345-
lookup_name = "weibull"
346-
347-
def __init__(
348-
self,
349-
l_max: int,
350-
normalize: bool = True,
351-
kind=WeibullType.PDF,
352-
mode: ConvMode = ConvMode.After,
353-
priors: dict | None = None,
354-
prefix: str | None = None,
355-
) -> None:
356-
self.kind = kind
357-
358-
super().__init__(
359-
l_max=l_max, normalize=normalize, mode=mode, priors=priors, prefix=prefix
360-
)
361-
362-
msg = (
363-
f"Use the Weibull{kind}Adstock class instead for better default priors. "
364-
"This class will deprecate in 0.9.0."
365-
)
366-
warnings.warn(
367-
msg,
368-
UserWarning,
369-
stacklevel=1,
370-
)
371-
372-
def function(self, x, lam, k):
373-
"""Weibull adstock function."""
374-
return weibull_adstock(
375-
x=x,
376-
lam=lam,
377-
k=k,
378-
l_max=self.l_max,
379-
mode=self.mode,
380-
type=self.kind,
381-
normalize=self.normalize,
382-
)
383-
384-
default_priors = {
385-
"lam": Prior("HalfNormal", sigma=1),
386-
"k": Prior("HalfNormal", sigma=1),
387-
}
388-
389-
390323
ADSTOCK_TRANSFORMATIONS: dict[str, type[AdstockTransformation]] = {
391324
cls.lookup_name: cls # type: ignore
392325
for cls in [
393326
GeometricAdstock,
394327
DelayedAdstock,
395-
WeibullAdstock,
396328
WeibullPDFAdstock,
397329
WeibullCDFAdstock,
398330
]

tests/mmm/components/test_adstock.py

-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
AdstockTransformation,
2525
DelayedAdstock,
2626
GeometricAdstock,
27-
WeibullAdstock,
2827
WeibullCDFAdstock,
2928
WeibullPDFAdstock,
3029
adstock_from_dict,
@@ -44,8 +43,6 @@ def adstocks() -> list[AdstockTransformation]:
4443
return [
4544
DelayedAdstock(l_max=10),
4645
GeometricAdstock(l_max=10),
47-
WeibullAdstock(l_max=10, kind="PDF"),
48-
WeibullAdstock(l_max=10, kind="CDF"),
4946
WeibullPDFAdstock(l_max=10),
5047
WeibullCDFAdstock(l_max=10),
5148
]
@@ -95,7 +92,6 @@ def test_default_prefix(adstock) -> None:
9592
[
9693
("delayed", DelayedAdstock, {"l_max": 10}),
9794
("geometric", GeometricAdstock, {"l_max": 10}),
98-
("weibull", WeibullAdstock, {"l_max": 10}),
9995
],
10096
)
10197
def test_get_adstock_function(name, adstock_cls, kwargs):

0 commit comments

Comments
 (0)