Skip to content

Commit 1ad7ed3

Browse files
wd60622twiecki
authored andcommitted
model.fit doesn't remove prior samples (#741)
* type hint only * more informative errors * check for attr * remove type ignore * check for attr * check for attr * reduce indentation * new error names
1 parent 95bf0c3 commit 1ad7ed3

File tree

7 files changed

+87
-48
lines changed

7 files changed

+87
-48
lines changed

pymc_marketing/clv/models/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def _validate_cols(
5858
raise ValueError(f"Column {required_col} has duplicate entries")
5959

6060
def __repr__(self):
61-
if self.model is None:
61+
if not hasattr(self, "model"):
6262
return self._model_type
6363
else:
6464
return f"{self._model_type}\n{self.model.str_repr()}"

pymc_marketing/mmm/base.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,16 @@ def get_target_transformer(self) -> Pipeline:
267267
def prior(self) -> Dataset:
268268
if self.idata is None or "prior" not in self.idata:
269269
raise RuntimeError(
270-
"The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first"
270+
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
271271
)
272272
return self.idata["prior"]
273273

274274
@property
275-
def prior_predictive(self) -> az.InferenceData:
275+
def prior_predictive(self) -> Dataset:
276276
if self.idata is None or "prior_predictive" not in self.idata:
277-
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
277+
raise RuntimeError(
278+
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
279+
)
278280
return self.idata["prior_predictive"]
279281

280282
@property
@@ -286,7 +288,9 @@ def fit_result(self) -> Dataset:
286288
@property
287289
def posterior_predictive(self) -> Dataset:
288290
if self.idata is None or "posterior_predictive" not in self.idata:
289-
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
291+
raise RuntimeError(
292+
"The model hasn't been fit yet, call .sample_posterior_predictive() first"
293+
)
290294
return self.idata["posterior_predictive"]
291295

292296
def plot_prior_predictive(

pymc_marketing/mmm/delayed_saturated_mmm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1829,7 +1829,7 @@ def add_lift_test_measurements(
18291829
model.add_lift_test_measurements(df_lift_test)
18301830
18311831
"""
1832-
if self.model is None:
1832+
if not hasattr(self, "model"):
18331833
raise RuntimeError(
18341834
"The model has not been built yet. Please, build the model first."
18351835
)

pymc_marketing/model_builder.py

+40-30
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(
8686
self.model_config = (
8787
self.default_model_config | model_config
8888
) # parameters for priors etc.
89-
self.model: pm.Model | None = None # Set by build_model
89+
self.model: pm.Model
9090
self.idata: az.InferenceData | None = None # idata is generated during fitting
9191
self.is_fitted_ = False
9292

@@ -458,19 +458,22 @@ def fit(
458458
if self.X is None or self.y is None:
459459
raise ValueError("X and y must be set before calling build_model!")
460460

461-
if self.model is None:
461+
if not hasattr(self, "model"):
462462
self.build_model(self.X, self.y)
463463

464464
sampler_config = self.sampler_config.copy()
465465
sampler_config["progressbar"] = progressbar
466466
sampler_config["random_seed"] = random_seed
467467
sampler_config.update(**kwargs)
468468

469-
sampler_config.update(**kwargs)
470-
if self.model is not None:
471-
with self.model:
472-
sampler_args = {**self.sampler_config, **kwargs}
473-
self.idata = pm.sample(**sampler_args)
469+
sampler_args = {**self.sampler_config, **kwargs}
470+
with self.model:
471+
idata = pm.sample(**sampler_args)
472+
473+
if self.idata:
474+
self.idata.extend(idata, join="right")
475+
else:
476+
self.idata = idata
474477

475478
X_df = pd.DataFrame(X, columns=X.columns)
476479
combined_data = pd.concat([X_df, y_df], axis=1)
@@ -537,7 +540,7 @@ def sample_prior_predictive(
537540
X_pred,
538541
y_pred=None,
539542
samples: int | None = None,
540-
extend_idata: bool = False,
543+
extend_idata: bool = True,
541544
combined: bool = True,
542545
**kwargs,
543546
):
@@ -552,7 +555,7 @@ def sample_prior_predictive(
552555
Number of samples from the prior parameter distributions to generate.
553556
If not set, uses sampler_config['draws'] if that is available, otherwise defaults to 500.
554557
extend_idata : Boolean determining whether the predictions should be added to inference data object.
555-
Defaults to False.
558+
Defaults to True.
556559
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
557560
Defaults to True.
558561
**kwargs: Additional arguments to pass to pymc.sample_prior_predictive
@@ -567,21 +570,19 @@ def sample_prior_predictive(
567570
if samples is None:
568571
samples = self.sampler_config.get("draws", 500)
569572

570-
if self.model is None:
573+
if not hasattr(self, "model"):
571574
self.build_model(X_pred, y_pred)
572575

573576
self._data_setter(X_pred, y_pred)
574-
if self.model is not None:
575-
with self.model: # sample with new input data
576-
prior_pred: az.InferenceData = pm.sample_prior_predictive(
577-
samples, **kwargs
578-
)
579-
self.set_idata_attrs(prior_pred)
580-
if extend_idata:
581-
if self.idata is not None:
582-
self.idata.extend(prior_pred, join="right")
583-
else:
584-
self.idata = prior_pred
577+
with self.model: # sample with new input data
578+
prior_pred: az.InferenceData = pm.sample_prior_predictive(samples, **kwargs)
579+
self.set_idata_attrs(prior_pred)
580+
581+
if extend_idata:
582+
if self.idata is not None:
583+
self.idata.extend(prior_pred, join="right")
584+
else:
585+
self.idata = prior_pred
585586

586587
prior_predictive_samples = az.extract(
587588
prior_pred, "prior_predictive", combined=combined
@@ -590,7 +591,11 @@ def sample_prior_predictive(
590591
return prior_predictive_samples
591592

592593
def sample_posterior_predictive(
593-
self, X_pred, extend_idata: bool = True, combined: bool = True, **kwargs
594+
self,
595+
X_pred,
596+
extend_idata: bool = True,
597+
combined: bool = True,
598+
**sample_posterior_predictive_kwargs,
594599
):
595600
"""
596601
Sample from the model's posterior predictive distribution.
@@ -603,7 +608,7 @@ def sample_posterior_predictive(
603608
Defaults to True.
604609
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
605610
Defaults to True.
606-
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
611+
**sample_posterior_predictive_kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
607612
608613
Returns
609614
-------
@@ -612,16 +617,21 @@ def sample_posterior_predictive(
612617
"""
613618
self._data_setter(X_pred)
614619

615-
with self.model: # type: ignore
616-
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
617-
if extend_idata:
618-
self.idata.extend(post_pred, join="right") # type: ignore
620+
with self.model:
621+
post_pred = pm.sample_posterior_predictive(
622+
self.idata, **sample_posterior_predictive_kwargs
623+
)
624+
625+
if extend_idata:
626+
self.idata.extend(post_pred, join="right") # type: ignore
619627

620-
posterior_predictive_samples = az.extract(
621-
post_pred, "posterior_predictive", combined=combined
628+
variable_name = (
629+
"predictions"
630+
if sample_posterior_predictive_kwargs.get("predictions")
631+
else "posterior_predictive"
622632
)
623633

624-
return posterior_predictive_samples
634+
return az.extract(post_pred, variable_name, combined=combined)
625635

626636
def get_params(self, deep=True):
627637
"""

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def set_model_fit(model: CLVModel, fit: InferenceData | Dataset):
7777
assert "posterior" in fit.groups()
7878
else:
7979
fit = InferenceData(posterior=fit)
80-
if model.model is None:
80+
if not hasattr(model, "model"):
8181
model.build_model()
8282
model.idata = fit
8383
model.idata.add_groups(fit_data=model.data.to_xarray())

tests/mmm/test_base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,9 @@ def test_calling_prior_predictive_before_fit_raises_error(test_mmm, toy_X, toy_y
270270
test_mmm.idata = None
271271
with pytest.raises(
272272
RuntimeError,
273-
match=re.escape("The model hasn't been fit yet, call .fit() first"),
273+
match=re.escape(
274+
"The model hasn't been sampled yet, call .sample_prior_predictive() first"
275+
),
274276
):
275277
test_mmm.prior_predictive
276278

@@ -297,7 +299,7 @@ def test_calling_prior_before_sample_prior_predictive_raises_error(
297299
with pytest.raises(
298300
RuntimeError,
299301
match=re.escape(
300-
"The model hasn't been fit yet, call .sample_prior_predictive() with extend_idata=True first"
302+
"The model hasn't been sampled yet, call .sample_prior_predictive() first",
301303
),
302304
):
303305
test_mmm.prior

tests/model_builder/test_model_builder.py tests/test_model_builder.py

+32-9
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ def _save_input_params(self, idata):
135135
def output_var(self):
136136
return "output"
137137

138-
def _data_setter(self, X: pd.Series, y: pd.Series = None):
138+
def _data_setter(self, X: pd.DataFrame, y: pd.Series = None):
139139
with self.model:
140-
pm.set_data({"x": X.values})
140+
pm.set_data({"x": X["input"].values})
141141
if y is not None:
142142
y = y.values if isinstance(y, pd.Series) else y
143143
pm.set_data({"y_data": y})
@@ -195,8 +195,8 @@ def test_save_load(fitted_model_instance):
195195
assert fitted_model_instance.id == test_builder2.id
196196
x_pred = rng.uniform(low=0, high=1, size=100)
197197
prediction_data = pd.DataFrame({"input": x_pred})
198-
pred1 = fitted_model_instance.predict(prediction_data["input"])
199-
pred2 = test_builder2.predict(prediction_data["input"])
198+
pred1 = fitted_model_instance.predict(prediction_data)
199+
pred2 = test_builder2.predict(prediction_data)
200200
assert pred1.shape == pred2.shape
201201
temp.close()
202202

@@ -230,9 +230,9 @@ def test_fit(fitted_model_instance):
230230
assert fitted_model_instance.idata.posterior.dims["draw"] == 100
231231

232232
prediction_data = pd.DataFrame({"input": rng.uniform(low=0, high=1, size=100)})
233-
fitted_model_instance.predict(prediction_data["input"])
233+
fitted_model_instance.predict(prediction_data)
234234
post_pred = fitted_model_instance.sample_posterior_predictive(
235-
prediction_data["input"], extend_idata=True, combined=True
235+
prediction_data, extend_idata=True, combined=True
236236
)
237237
assert (
238238
post_pred[fitted_model_instance.output_var].shape[0]
@@ -256,7 +256,7 @@ def test_predict(fitted_model_instance):
256256
rng = np.random.default_rng(42)
257257
x_pred = rng.uniform(low=0, high=1, size=100)
258258
prediction_data = pd.DataFrame({"input": x_pred})
259-
pred = fitted_model_instance.predict(prediction_data["input"])
259+
pred = fitted_model_instance.predict(prediction_data)
260260
# Perform elementwise comparison using numpy
261261
assert type(pred) == np.ndarray
262262
assert len(pred) > 0
@@ -269,7 +269,7 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
269269
x_pred = rng.uniform(low=0, high=1, size=n_pred)
270270
prediction_data = pd.DataFrame({"input": x_pred})
271271
pred = fitted_model_instance.sample_posterior_predictive(
272-
prediction_data["input"], combined=combined, extend_idata=True
272+
prediction_data, combined=combined, extend_idata=True
273273
)
274274
chains = fitted_model_instance.idata.sample_stats.dims["chain"]
275275
draws = fitted_model_instance.idata.sample_stats.dims["draw"]
@@ -313,7 +313,7 @@ def test_sample_xxx_predictive_keeps_second(
313313
method_name = f"sample_{name}"
314314
method = getattr(fitted_model_instance, method_name)
315315

316-
X_pred = toy_X["input"]
316+
X_pred = toy_X
317317

318318
kwargs = {
319319
"X_pred": X_pred,
@@ -329,3 +329,26 @@ def test_sample_xxx_predictive_keeps_second(
329329

330330
sample = getattr(fitted_model_instance.idata, name)
331331
xr.testing.assert_allclose(sample, second_sample)
332+
333+
334+
def test_prediction_kwarg(fitted_model_instance, toy_X):
335+
result = fitted_model_instance.sample_posterior_predictive(
336+
toy_X,
337+
extend_idata=True,
338+
predictions=True,
339+
)
340+
assert "predictions" in fitted_model_instance.idata
341+
assert "predictions_constant_data" in fitted_model_instance.idata
342+
343+
assert isinstance(result, xr.Dataset)
344+
345+
346+
def test_fit_after_prior_keeps_prior(toy_X, toy_y):
347+
model = ModelBuilderTest()
348+
model.sample_prior_predictive(toy_X)
349+
assert "prior" in model.idata
350+
assert "prior_predictive" in model.idata
351+
352+
model.fit(X=toy_X, y=toy_y, chains=1, draws=100, tune=100)
353+
assert "prior" in model.idata
354+
assert "prior_predictive" in model.idata

0 commit comments

Comments
 (0)