@@ -304,6 +304,11 @@ def log_model_derived_info(model: Model) -> None:
304
304
- The model representation (str).
305
305
- The model coordinates (coords.json).
306
306
307
+ Parameters
308
+ ----------
309
+ model : Model
310
+ The PyMC model object.
311
+
307
312
"""
308
313
log_types_of_parameters (model )
309
314
@@ -321,6 +326,7 @@ def log_model_derived_info(model: Model) -> None:
321
326
322
327
def log_sample_diagnostics (
323
328
idata : az .InferenceData ,
329
+ tune : int | None = None ,
324
330
) -> None :
325
331
"""Log sample diagnostics to MLflow.
326
332
@@ -336,6 +342,14 @@ def log_sample_diagnostics(
336
342
- The version of the inference library
337
343
- The version of ArviZ
338
344
345
+ Parameters
346
+ ----------
347
+ idata : az.InferenceData
348
+ The InferenceData object returned by the sampling method.
349
+ tune : int, optional
350
+ The number of tuning steps used in sampling. Derived from the
351
+ inference data if not provided.
352
+
339
353
"""
340
354
if "posterior" not in idata :
341
355
raise KeyError ("InferenceData object does not contain the group posterior." )
@@ -348,19 +362,28 @@ def log_sample_diagnostics(
348
362
349
363
diverging = sample_stats ["diverging" ]
350
364
365
+ chains = posterior .sizes ["chain" ]
366
+ draws = posterior .sizes ["draw" ]
367
+ posterior_samples = chains * draws
368
+
369
+ tuning_step = sample_stats .attrs .get ("tuning_steps" , tune )
370
+ if tuning_step is not None :
371
+ tuning_samples = tuning_step * chains
372
+ mlflow .log_param ("tuning_steps" , tuning_step )
373
+ mlflow .log_param ("tuning_samples" , tuning_samples )
374
+
351
375
total_divergences = diverging .sum ().item ()
352
376
mlflow .log_metric ("total_divergences" , total_divergences )
353
377
if sampling_time := sample_stats .attrs .get ("sampling_time" ):
354
378
mlflow .log_metric ("sampling_time" , sampling_time )
355
379
mlflow .log_metric (
356
380
"time_per_draw" ,
357
- sampling_time / ( posterior . sizes [ "draw" ] * posterior . sizes [ "chain" ]) ,
381
+ sampling_time / posterior_samples ,
358
382
)
359
383
360
- if tuning_step := sample_stats .attrs .get ("tuning_steps" ):
361
- mlflow .log_param ("tuning_steps" , tuning_step )
362
- mlflow .log_param ("draws" , posterior .sizes ["draw" ])
363
- mlflow .log_param ("chains" , posterior .sizes ["chain" ])
384
+ mlflow .log_param ("draws" , draws )
385
+ mlflow .log_param ("chains" , chains )
386
+ mlflow .log_param ("posterior_samples" , posterior_samples )
364
387
365
388
if inference_library := posterior .attrs .get ("inference_library" ):
366
389
mlflow .log_param ("inference_library" , inference_library )
@@ -382,8 +405,7 @@ def log_inference_data(
382
405
idata : az.InferenceData
383
406
The InferenceData object returned by the sampling method.
384
407
save_file : str | Path
385
- The path to save the InferenceData object as a net
386
- CDF file.
408
+ The path to save the InferenceData object as a netCDF file.
387
409
388
410
"""
389
411
idata .to_netcdf (str (save_file ))
@@ -516,8 +538,11 @@ def new_sample(*args, **kwargs):
516
538
mlflow .log_param ("pymc_version" , pm .__version__ )
517
539
mlflow .log_param ("nuts_sampler" , kwargs .get ("nuts_sampler" , "pymc" ))
518
540
541
+ # Align with the default values in pymc.sample
542
+ tune = kwargs .get ("tune" , 1000 )
543
+
519
544
if log_sampler_info :
520
- log_sample_diagnostics (idata )
545
+ log_sample_diagnostics (idata , tune = tune )
521
546
log_arviz_summary (
522
547
idata ,
523
548
"summary.html" ,
0 commit comments