Skip to content

Commit c69eb97

Browse files
ColtAllen authored and twiecki committed
Fix clv plotting bugs and edits to Quickstart (#601)
* move fixtures to conftest
* docstrings and moved set_model_fit to conftest
* fixed pandas quickstart warnings
* revert to MockModel and add ParetoNBD support
* quickstart edit for issue 609
* notebook edit
1 parent 210456f commit c69eb97

File tree

9 files changed

+445
-413
lines changed

9 files changed

+445
-413
lines changed

docs/source/notebooks/clv/clv_quickstart.ipynb

+286-304
Large diffs are not rendered by default.

pymc_marketing/clv/plotting.py

+74-30
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Optional, Sequence, Tuple
1+
from typing import Optional, Sequence, Tuple, Union
22

33
import matplotlib.pyplot as plt
44
import numpy as np
55
import pandas as pd
66
from matplotlib.lines import Line2D
77

8+
from pymc_marketing.clv import BetaGeoModel, ParetoNBDModel
9+
810
__all__ = [
911
"plot_customer_exposure",
1012
"plot_frequency_recency_matrix",
@@ -156,7 +158,7 @@ def _create_frequency_recency_meshes(
156158

157159

158160
def plot_frequency_recency_matrix(
159-
model,
161+
model: Union[BetaGeoModel, ParetoNBDModel],
160162
t=1,
161163
max_frequency: Optional[int] = None,
162164
max_recency: Optional[int] = None,
@@ -172,8 +174,8 @@ def plot_frequency_recency_matrix(
172174
173175
Parameters
174176
----------
175-
model: lifetimes model
176-
A fitted lifetimes model.
177+
model: CLV model
178+
A fitted CLV model.
177179
t: float, optional
178180
Next units of time to make predictions for
179181
max_frequency: int, optional
@@ -197,27 +199,49 @@ def plot_frequency_recency_matrix(
197199
axes: matplotlib.AxesSubplot
198200
"""
199201
if max_frequency is None:
200-
max_frequency = int(model.frequency.max())
202+
max_frequency = int(model.data["frequency"].max())
201203

202204
if max_recency is None:
203-
max_recency = int(model.recency.max())
205+
max_recency = int(model.data["recency"].max())
204206

205207
mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
206208
max_frequency=max_frequency,
207209
max_recency=max_recency,
208210
)
209211

210-
Z = (
211-
model.expected_num_purchases(
212-
customer_id=np.arange(mesh_recency.size), # placeholder
213-
t=t,
214-
frequency=mesh_frequency.ravel(),
215-
recency=mesh_recency.ravel(),
216-
T=max_recency,
212+
# FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
213+
# We should harmonize them!
214+
if isinstance(model, ParetoNBDModel):
215+
transaction_data = pd.DataFrame(
216+
{
217+
"customer_id": np.arange(mesh_recency.size), # placeholder
218+
"frequency": mesh_frequency.ravel(),
219+
"recency": mesh_recency.ravel(),
220+
"T": max_recency,
221+
}
217222
)
218-
.mean(("draw", "chain"))
219-
.values.reshape(mesh_recency.shape)
220-
)
223+
224+
Z = (
225+
model.expected_purchases(
226+
data=transaction_data,
227+
future_t=t,
228+
)
229+
.mean(("draw", "chain"))
230+
.values.reshape(mesh_recency.shape)
231+
)
232+
else:
233+
Z = (
234+
model.expected_num_purchases(
235+
customer_id=np.arange(mesh_recency.size), # placeholder
236+
frequency=mesh_frequency.ravel(),
237+
recency=mesh_recency.ravel(),
238+
T=max_recency,
239+
t=t,
240+
)
241+
.mean(("draw", "chain"))
242+
.values.reshape(mesh_recency.shape)
243+
)
244+
221245
if ax is None:
222246
ax = plt.subplot(111)
223247

@@ -245,7 +269,7 @@ def plot_frequency_recency_matrix(
245269

246270

247271
def plot_probability_alive_matrix(
248-
model,
272+
model: Union[BetaGeoModel, ParetoNBDModel],
249273
max_frequency: Optional[int] = None,
250274
max_recency: Optional[int] = None,
251275
title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer",
@@ -261,8 +285,8 @@ def plot_probability_alive_matrix(
261285
262286
Parameters
263287
----------
264-
model: lifetimes model
265-
A fitted lifetimes model.
288+
model: CLV model
289+
A fitted CLV model.
266290
max_frequency: int, optional
267291
The maximum frequency to plot. Default is max observed frequency.
268292
max_recency: int, optional
@@ -285,26 +309,46 @@ def plot_probability_alive_matrix(
285309
"""
286310

287311
if max_frequency is None:
288-
max_frequency = int(model.frequency.max())
312+
max_frequency = int(model.data["frequency"].max())
289313

290314
if max_recency is None:
291-
max_recency = int(model.recency.max())
315+
max_recency = int(model.data["recency"].max())
292316

293317
mesh_frequency, mesh_recency = _create_frequency_recency_meshes(
294318
max_frequency=max_frequency,
295319
max_recency=max_recency,
296320
)
321+
# FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
322+
# We should harmonize them!
323+
if isinstance(model, ParetoNBDModel):
324+
transaction_data = pd.DataFrame(
325+
{
326+
"customer_id": np.arange(mesh_recency.size), # placeholder
327+
"frequency": mesh_frequency.ravel(),
328+
"recency": mesh_recency.ravel(),
329+
"T": max_recency,
330+
}
331+
)
297332

298-
Z = (
299-
model.expected_probability_alive(
300-
customer_id=np.arange(mesh_recency.size), # placeholder
301-
frequency=mesh_frequency.ravel(),
302-
recency=mesh_recency.ravel(),
303-
T=max_recency,
333+
Z = (
334+
model.expected_probability_alive(
335+
data=transaction_data,
336+
future_t=0, # TODO: This can be a function parameter in the case of ParetoNBDModel
337+
)
338+
.mean(("draw", "chain"))
339+
.values.reshape(mesh_recency.shape)
340+
)
341+
else:
342+
Z = (
343+
model.expected_probability_alive(
344+
customer_id=np.arange(mesh_recency.size), # placeholder
345+
frequency=mesh_frequency.ravel(),
346+
recency=mesh_recency.ravel(),
347+
T=max_recency, # type: ignore
348+
)
349+
.mean(("draw", "chain"))
350+
.values.reshape(mesh_recency.shape)
304351
)
305-
.mean(("draw", "chain"))
306-
.values.reshape(mesh_recency.shape)
307-
)
308352

309353
interpolation = kwargs.pop("interpolation", "none")
310354

tests/clv/models/test_basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from arviz import InferenceData, from_dict
88

99
from pymc_marketing.clv.models.basic import CLVModel
10-
from tests.clv.utils import set_model_fit
10+
from tests.conftest import set_model_fit
1111

1212

1313
class CLVModelTest(CLVModel):

tests/clv/models/test_gamma_gamma.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
GammaGammaModel,
1111
GammaGammaModelIndividual,
1212
)
13-
from tests.clv.utils import set_model_fit
13+
from tests.conftest import set_model_fit
1414

1515

1616
class BaseTestGammaGammaModel:

tests/clv/models/test_pareto_nbd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pymc_marketing.clv import ParetoNBDModel
1111
from pymc_marketing.clv.distributions import ParetoNBD
12-
from tests.clv.utils import set_model_fit
12+
from tests.conftest import set_model_fit
1313

1414

1515
class TestParetoNBDModel:

tests/clv/test_plotting.py

+52-48
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,57 @@
1414
)
1515

1616

17-
@pytest.fixture(scope="module")
18-
def test_summary_data() -> pd.DataFrame:
19-
return pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
17+
class MockModel:
18+
def __init__(self, data: pd.DataFrame):
19+
self.data = data
20+
21+
def _mock_posterior(
22+
self, customer_id: Union[np.ndarray, pd.Series]
23+
) -> xr.DataArray:
24+
n_customers = len(customer_id)
25+
n_chains = 4
26+
n_draws = 10
27+
chains = np.arange(n_chains)
28+
draws = np.arange(n_draws)
29+
return xr.DataArray(
30+
data=np.ones((n_customers, n_chains, n_draws)),
31+
coords={"customer_id": customer_id, "chain": chains, "draw": draws},
32+
dims=["customer_id", "chain", "draw"],
33+
)
34+
35+
def expected_probability_alive(
36+
self,
37+
customer_id: Union[np.ndarray, pd.Series],
38+
frequency: Union[np.ndarray, pd.Series],
39+
recency: Union[np.ndarray, pd.Series],
40+
T: Union[np.ndarray, pd.Series],
41+
):
42+
return self._mock_posterior(customer_id)
43+
44+
def expected_purchases(
45+
self,
46+
customer_id: Union[np.ndarray, pd.Series],
47+
data: pd.DataFrame,
48+
*,
49+
future_t: Union[np.ndarray, pd.Series, TensorVariable],
50+
):
51+
return self._mock_posterior(customer_id)
52+
53+
# TODO: This is required until CLV API is standardized.
54+
def expected_num_purchases(
55+
self,
56+
customer_id: Union[np.ndarray, pd.Series],
57+
t: Union[np.ndarray, pd.Series, TensorVariable],
58+
frequency: Union[np.ndarray, pd.Series, TensorVariable],
59+
recency: Union[np.ndarray, pd.Series, TensorVariable],
60+
T: Union[np.ndarray, pd.Series, TensorVariable],
61+
):
62+
return self._mock_posterior(customer_id)
63+
64+
65+
@pytest.fixture
66+
def mock_model(test_summary_data) -> MockModel:
67+
return MockModel(test_summary_data)
2068

2169

2270
@pytest.mark.parametrize(
@@ -33,7 +81,7 @@ def test_plot_customer_exposure(test_summary_data, kwargs) -> None:
3381
assert isinstance(ax, plt.Axes)
3482

3583

36-
def test_plot_cumstomer_exposure_with_ax(test_summary_data) -> None:
84+
def test_plot_customer_exposure_with_ax(test_summary_data) -> None:
3785
ax = plt.subplot()
3886
plot_customer_exposure(test_summary_data, ax=ax)
3987

@@ -59,50 +107,6 @@ def test_plot_customer_exposure_invalid_args(test_summary_data, kwargs) -> None:
59107
plot_customer_exposure(test_summary_data, **kwargs)
60108

61109

62-
class MockModel:
63-
def __init__(self, frequency, recency):
64-
self.frequency = frequency
65-
self.recency = recency
66-
67-
def _mock_posterior(
68-
self, customer_id: Union[np.ndarray, pd.Series]
69-
) -> xr.DataArray:
70-
n_customers = len(customer_id)
71-
n_chains = 4
72-
n_draws = 10
73-
chains = np.arange(n_chains)
74-
draws = np.arange(n_draws)
75-
return xr.DataArray(
76-
data=np.ones((n_customers, n_chains, n_draws)),
77-
coords={"customer_id": customer_id, "chain": chains, "draw": draws},
78-
dims=["customer_id", "chain", "draw"],
79-
)
80-
81-
def expected_probability_alive(
82-
self,
83-
customer_id: Union[np.ndarray, pd.Series],
84-
frequency: Union[np.ndarray, pd.Series],
85-
recency: Union[np.ndarray, pd.Series],
86-
T: Union[np.ndarray, pd.Series],
87-
):
88-
return self._mock_posterior(customer_id)
89-
90-
def expected_num_purchases(
91-
self,
92-
customer_id: Union[np.ndarray, pd.Series],
93-
t: Union[np.ndarray, pd.Series, TensorVariable],
94-
frequency: Union[np.ndarray, pd.Series, TensorVariable],
95-
recency: Union[np.ndarray, pd.Series, TensorVariable],
96-
T: Union[np.ndarray, pd.Series, TensorVariable],
97-
):
98-
return self._mock_posterior(customer_id)
99-
100-
101-
@pytest.fixture
102-
def mock_model(test_summary_data) -> MockModel:
103-
return MockModel(test_summary_data["frequency"], test_summary_data["recency"])
104-
105-
106110
def test_plot_frequency_recency_matrix(mock_model) -> None:
107111
ax: plt.Axes = plot_frequency_recency_matrix(mock_model)
108112

tests/clv/test_utils.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
rfm_train_test_split,
1717
to_xarray,
1818
)
19-
from tests.clv.utils import set_model_fit
19+
from tests.conftest import set_model_fit
2020

2121

2222
def test_to_xarray():
@@ -42,15 +42,6 @@ def test_to_xarray():
4242
np.testing.assert_array_equal(new_y.coords["test_dim"], customer_id)
4343

4444

45-
@pytest.fixture(scope="module")
46-
def test_summary_data() -> pd.DataFrame:
47-
rng = np.random.default_rng(14)
48-
df = pd.read_csv("tests/clv/datasets/test_summary_data.csv", index_col=0)
49-
df["monetary_value"] = rng.lognormal(size=(len(df)))
50-
df["customer_id"] = df.index
51-
return df
52-
53-
5445
@pytest.fixture(scope="module")
5546
def fitted_bg(test_summary_data) -> BetaGeoModel:
5647
rng = np.random.default_rng(13)
@@ -100,6 +91,7 @@ def fitted_pnbd(test_summary_data) -> ParetoNBDModel:
10091
pnbd_model.build_model()
10192

10293
# Mock an idata object for tests requiring a fitted model
94+
# TODO: This is quite slow. Check similar fixtures in the model tests to speed this up.
10395
fake_fit = pm.sample_prior_predictive(
10496
samples=50, model=pnbd_model.model, random_seed=rng
10597
).prior

tests/clv/utils.py

-18
This file was deleted.

0 commit comments

Comments
 (0)