Skip to content

Commit 8d116b5

Browse files
authored
Log number of posterior & tuning samples (#943)
* helper command to view the artifacts from test
* pass tune from kwargs
* test for support of all samplers
* add mlflow as a mock import
* actual import as autolog is missing from docs
1 parent 809a079 commit 8d116b5

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

Makefile

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ uml: ## Install documentation dependencies and generate UML diagrams
5050
pyreverse pymc_marketing/mmm -d docs/source/uml -f 'ALL' -o png -p mmm
5151
pyreverse pymc_marketing/clv -d docs/source/uml -f 'ALL' -o png -p clv
5252

53+
mlflow_server: ## Start MLflow server on port 5000
54+
mlflow server --backend-store-uri sqlite:///mlruns.db --default-artifact-root ./mlruns
55+
5356

5457
#################################################################################
5558
# Self Documenting Commands #

pymc_marketing/mlflow.py

+33-8
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,11 @@ def log_model_derived_info(model: Model) -> None:
304304
- The model representation (str).
305305
- The model coordinates (coords.json).
306306
307+
Parameters
308+
----------
309+
model : Model
310+
The PyMC model object.
311+
307312
"""
308313
log_types_of_parameters(model)
309314

@@ -321,6 +326,7 @@ def log_model_derived_info(model: Model) -> None:
321326

322327
def log_sample_diagnostics(
323328
idata: az.InferenceData,
329+
tune: int | None = None,
324330
) -> None:
325331
"""Log sample diagnostics to MLflow.
326332
@@ -336,6 +342,14 @@ def log_sample_diagnostics(
336342
- The version of the inference library
337343
- The version of ArviZ
338344
345+
Parameters
346+
----------
347+
idata : az.InferenceData
348+
The InferenceData object returned by the sampling method.
349+
tune : int, optional
350+
The number of tuning steps used in sampling. Derived from the
351+
inference data if not provided.
352+
339353
"""
340354
if "posterior" not in idata:
341355
raise KeyError("InferenceData object does not contain the group posterior.")
@@ -348,19 +362,28 @@ def log_sample_diagnostics(
348362

349363
diverging = sample_stats["diverging"]
350364

365+
chains = posterior.sizes["chain"]
366+
draws = posterior.sizes["draw"]
367+
posterior_samples = chains * draws
368+
369+
tuning_step = sample_stats.attrs.get("tuning_steps", tune)
370+
if tuning_step is not None:
371+
tuning_samples = tuning_step * chains
372+
mlflow.log_param("tuning_steps", tuning_step)
373+
mlflow.log_param("tuning_samples", tuning_samples)
374+
351375
total_divergences = diverging.sum().item()
352376
mlflow.log_metric("total_divergences", total_divergences)
353377
if sampling_time := sample_stats.attrs.get("sampling_time"):
354378
mlflow.log_metric("sampling_time", sampling_time)
355379
mlflow.log_metric(
356380
"time_per_draw",
357-
sampling_time / (posterior.sizes["draw"] * posterior.sizes["chain"]),
381+
sampling_time / posterior_samples,
358382
)
359383

360-
if tuning_step := sample_stats.attrs.get("tuning_steps"):
361-
mlflow.log_param("tuning_steps", tuning_step)
362-
mlflow.log_param("draws", posterior.sizes["draw"])
363-
mlflow.log_param("chains", posterior.sizes["chain"])
384+
mlflow.log_param("draws", draws)
385+
mlflow.log_param("chains", chains)
386+
mlflow.log_param("posterior_samples", posterior_samples)
364387

365388
if inference_library := posterior.attrs.get("inference_library"):
366389
mlflow.log_param("inference_library", inference_library)
@@ -382,8 +405,7 @@ def log_inference_data(
382405
idata : az.InferenceData
383406
The InferenceData object returned by the sampling method.
384407
save_file : str | Path
385-
The path to save the InferenceData object as a net
386-
CDF file.
408+
The path to save the InferenceData object as a netCDF file.
387409
388410
"""
389411
idata.to_netcdf(str(save_file))
@@ -516,8 +538,11 @@ def new_sample(*args, **kwargs):
516538
mlflow.log_param("pymc_version", pm.__version__)
517539
mlflow.log_param("nuts_sampler", kwargs.get("nuts_sampler", "pymc"))
518540

541+
# Align with the default values in pymc.sample
542+
tune = kwargs.get("tune", 1000)
543+
519544
if log_sampler_info:
520-
log_sample_diagnostics(idata)
545+
log_sample_diagnostics(idata, tune=tune)
521546
log_arviz_summary(
522547
idata,
523548
"summary.html",

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ docs = [
6262
"sphinx",
6363
"sphinxext-opengraph",
6464
"watermark",
65+
"mlflow>=2.0.0",
6566
]
6667
lint = ["mypy", "pandas-stubs", "pre-commit>=2.19.0", "ruff>=0.1.4"]
6768
test = [

tests/test_mlflow.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,13 @@ def metric_checks(metrics, nuts_sampler) -> None:
231231
def param_checks(params, draws: int, chains: int, tune: int, nuts_sampler: str) -> None:
232232
assert params["draws"] == str(draws)
233233
assert params["chains"] == str(chains)
234+
assert params["posterior_samples"] == str(draws * chains)
235+
234236
if nuts_sampler not in ["numpyro", "blackjax"]:
235237
assert params["inference_library"] == nuts_sampler
236-
if nuts_sampler not in ["numpyro", "nutpie", "blackjax"]:
237-
assert params["tuning_steps"] == str(tune)
238+
239+
assert params["tuning_steps"] == str(tune)
240+
assert params["tuning_samples"] == str(tune * chains)
238241

239242
assert params["pymc_marketing_version"] == __version__
240243

0 commit comments

Comments
 (0)