Skip to content

Commit dc4296c

Browse files
PabloRoquewd60622pre-commit-ci[bot]
authored andcommitted
Enforce check_parameters for alpha in geometric_adstock (#960)
* Enforce check_parameters in geometric_adstock * Remove check_parameters on l_max. Revert to original test * Add test_geometric_adstock_bad_alpha * use ge, le instead of gt,lt * Update tests/mmm/test_transformers.py Co-authored-by: Will Dean <[email protected]> * Simplify test parameters * Update tests/mmm/test_transformers.py Co-authored-by: Will Dean <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Will Dean <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 57378f2 commit dc4296c

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

pymc_marketing/mmm/transformers.py

+5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy.typing as npt
2121
import pymc as pm
2222
import pytensor.tensor as pt
23+
from pymc.distributions.dist_math import check_parameters
2324
from pytensor.tensor.random.utils import params_broadcast_shapes
2425

2526

@@ -235,6 +236,10 @@ def geometric_adstock(
235236
with carryover and shape effects." (2017).
236237
237238
"""
239+
alpha = check_parameters(
240+
alpha, [pt.ge(alpha, 0), pt.le(alpha, 1)], msg="0 <= alpha <= 1"
241+
)
242+
238243
w = pt.power(pt.as_tensor(alpha)[..., None], pt.arange(l_max, dtype=x.dtype))
239244
w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w
240245
return batched_convolution(x, w, axis=axis, mode=mode)

tests/mmm/test_transformers.py

+18
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytensor.tensor as pt
1919
import pytest
2020
import scipy as sp
21+
from pymc.logprob.utils import ParameterValueError
2122
from pytensor.tensor.variable import TensorVariable
2223

2324
from pymc_marketing.mmm.transformers import (
@@ -148,6 +149,23 @@ def test_geometric_adstock_good_alpha(self, x, alpha, l_max):
148149
assert y_np[1] == x[1] + alpha * x[0]
149150
assert y_np[2] == x[2] + alpha * x[1] + (alpha**2) * x[0]
150151

152+
@pytest.mark.parametrize(
153+
"alpha",
154+
[-0.3, -2, 22.5, 2],
155+
ids=[
156+
"less_than_zero_0",
157+
"less_than_zero_1",
158+
"greater_than_one_0",
159+
"greater_than_one_1",
160+
],
161+
)
162+
def test_geometric_adstock_bad_alpha(self, alpha):
163+
l_max = 10
164+
x = np.ones(shape=100)
165+
y = geometric_adstock(x=x, alpha=alpha, l_max=l_max)
166+
with pytest.raises(ParameterValueError):
167+
y.eval()
168+
151169
@pytest.mark.parametrize(
152170
argnames="mode",
153171
argvalues=[ConvMode.After, ConvMode.Before, ConvMode.Overlap],

0 commit comments

Comments
 (0)