1
- from typing import Optional , Sequence , Tuple
1
+ from typing import Optional , Sequence , Tuple , Union
2
2
3
3
import matplotlib .pyplot as plt
4
4
import numpy as np
5
5
import pandas as pd
6
6
from matplotlib .lines import Line2D
7
7
8
+ from pymc_marketing .clv import BetaGeoModel , ParetoNBDModel
9
+
8
10
__all__ = [
9
11
"plot_customer_exposure" ,
10
12
"plot_frequency_recency_matrix" ,
@@ -156,7 +158,7 @@ def _create_frequency_recency_meshes(
156
158
157
159
158
160
def plot_frequency_recency_matrix (
159
- model ,
161
+ model : Union [ BetaGeoModel , ParetoNBDModel ] ,
160
162
t = 1 ,
161
163
max_frequency : Optional [int ] = None ,
162
164
max_recency : Optional [int ] = None ,
@@ -172,8 +174,8 @@ def plot_frequency_recency_matrix(
172
174
173
175
Parameters
174
176
----------
175
- model: lifetimes model
176
- A fitted lifetimes model.
177
+ model: CLV model
178
+ A fitted CLV model.
177
179
t: float, optional
178
180
Next units of time to make predictions for
179
181
max_frequency: int, optional
@@ -197,27 +199,49 @@ def plot_frequency_recency_matrix(
197
199
axes: matplotlib.AxesSubplot
198
200
"""
199
201
if max_frequency is None :
200
- max_frequency = int (model .frequency .max ())
202
+ max_frequency = int (model .data [ " frequency" ] .max ())
201
203
202
204
if max_recency is None :
203
- max_recency = int (model .recency .max ())
205
+ max_recency = int (model .data [ " recency" ] .max ())
204
206
205
207
mesh_frequency , mesh_recency = _create_frequency_recency_meshes (
206
208
max_frequency = max_frequency ,
207
209
max_recency = max_recency ,
208
210
)
209
211
210
- Z = (
211
- model .expected_num_purchases (
212
- customer_id = np .arange (mesh_recency .size ), # placeholder
213
- t = t ,
214
- frequency = mesh_frequency .ravel (),
215
- recency = mesh_recency .ravel (),
216
- T = max_recency ,
212
+ # FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
213
+ # We should harmonize them!
214
+ if isinstance (model , ParetoNBDModel ):
215
+ transaction_data = pd .DataFrame (
216
+ {
217
+ "customer_id" : np .arange (mesh_recency .size ), # placeholder
218
+ "frequency" : mesh_frequency .ravel (),
219
+ "recency" : mesh_recency .ravel (),
220
+ "T" : max_recency ,
221
+ }
217
222
)
218
- .mean (("draw" , "chain" ))
219
- .values .reshape (mesh_recency .shape )
220
- )
223
+
224
+ Z = (
225
+ model .expected_purchases (
226
+ data = transaction_data ,
227
+ future_t = t ,
228
+ )
229
+ .mean (("draw" , "chain" ))
230
+ .values .reshape (mesh_recency .shape )
231
+ )
232
+ else :
233
+ Z = (
234
+ model .expected_num_purchases (
235
+ customer_id = np .arange (mesh_recency .size ), # placeholder
236
+ frequency = mesh_frequency .ravel (),
237
+ recency = mesh_recency .ravel (),
238
+ T = max_recency ,
239
+ t = t ,
240
+ )
241
+ .mean (("draw" , "chain" ))
242
+ .values .reshape (mesh_recency .shape )
243
+ )
244
+
221
245
if ax is None :
222
246
ax = plt .subplot (111 )
223
247
@@ -245,7 +269,7 @@ def plot_frequency_recency_matrix(
245
269
246
270
247
271
def plot_probability_alive_matrix (
248
- model ,
272
+ model : Union [ BetaGeoModel , ParetoNBDModel ] ,
249
273
max_frequency : Optional [int ] = None ,
250
274
max_recency : Optional [int ] = None ,
251
275
title : str = "Probability Customer is Alive,\n by Frequency and Recency of a Customer" ,
@@ -261,8 +285,8 @@ def plot_probability_alive_matrix(
261
285
262
286
Parameters
263
287
----------
264
- model: lifetimes model
265
- A fitted lifetimes model.
288
+ model: CLV model
289
+ A fitted CLV model.
266
290
max_frequency: int, optional
267
291
The maximum frequency to plot. Default is max observed frequency.
268
292
max_recency: int, optional
@@ -285,26 +309,46 @@ def plot_probability_alive_matrix(
285
309
"""
286
310
287
311
if max_frequency is None :
288
- max_frequency = int (model .frequency .max ())
312
+ max_frequency = int (model .data [ " frequency" ] .max ())
289
313
290
314
if max_recency is None :
291
- max_recency = int (model .recency .max ())
315
+ max_recency = int (model .data [ " recency" ] .max ())
292
316
293
317
mesh_frequency , mesh_recency = _create_frequency_recency_meshes (
294
318
max_frequency = max_frequency ,
295
319
max_recency = max_recency ,
296
320
)
321
+ # FIXME: This is a hotfix for ParetoNBDModel, as it has a different API from BetaGeoModel
322
+ # We should harmonize them!
323
+ if isinstance (model , ParetoNBDModel ):
324
+ transaction_data = pd .DataFrame (
325
+ {
326
+ "customer_id" : np .arange (mesh_recency .size ), # placeholder
327
+ "frequency" : mesh_frequency .ravel (),
328
+ "recency" : mesh_recency .ravel (),
329
+ "T" : max_recency ,
330
+ }
331
+ )
297
332
298
- Z = (
299
- model .expected_probability_alive (
300
- customer_id = np .arange (mesh_recency .size ), # placeholder
301
- frequency = mesh_frequency .ravel (),
302
- recency = mesh_recency .ravel (),
303
- T = max_recency ,
333
+ Z = (
334
+ model .expected_probability_alive (
335
+ data = transaction_data ,
336
+ future_t = 0 , # TODO: This can be a function parameter in the case of ParetoNBDModel
337
+ )
338
+ .mean (("draw" , "chain" ))
339
+ .values .reshape (mesh_recency .shape )
340
+ )
341
+ else :
342
+ Z = (
343
+ model .expected_probability_alive (
344
+ customer_id = np .arange (mesh_recency .size ), # placeholder
345
+ frequency = mesh_frequency .ravel (),
346
+ recency = mesh_recency .ravel (),
347
+ T = max_recency , # type: ignore
348
+ )
349
+ .mean (("draw" , "chain" ))
350
+ .values .reshape (mesh_recency .shape )
304
351
)
305
- .mean (("draw" , "chain" ))
306
- .values .reshape (mesh_recency .shape )
307
- )
308
352
309
353
interpolation = kwargs .pop ("interpolation" , "none" )
310
354
0 commit comments