@@ -86,7 +86,7 @@ def __init__(
86
86
self .model_config = (
87
87
self .default_model_config | model_config
88
88
) # parameters for priors etc.
89
- self .model : pm .Model | None = None # Set by build_model
89
+ self .model : pm .Model
90
90
self .idata : az .InferenceData | None = None # idata is generated during fitting
91
91
self .is_fitted_ = False
92
92
@@ -458,19 +458,22 @@ def fit(
458
458
if self .X is None or self .y is None :
459
459
raise ValueError ("X and y must be set before calling build_model!" )
460
460
461
- if self . model is None :
461
+ if not hasattr ( self , "model" ) :
462
462
self .build_model (self .X , self .y )
463
463
464
464
sampler_config = self .sampler_config .copy ()
465
465
sampler_config ["progressbar" ] = progressbar
466
466
sampler_config ["random_seed" ] = random_seed
467
467
sampler_config .update (** kwargs )
468
468
469
- sampler_config .update (** kwargs )
470
- if self .model is not None :
471
- with self .model :
472
- sampler_args = {** self .sampler_config , ** kwargs }
473
- self .idata = pm .sample (** sampler_args )
469
+ sampler_args = {** self .sampler_config , ** kwargs }
470
+ with self .model :
471
+ idata = pm .sample (** sampler_args )
472
+
473
+ if self .idata :
474
+ self .idata .extend (idata , join = "right" )
475
+ else :
476
+ self .idata = idata
474
477
475
478
X_df = pd .DataFrame (X , columns = X .columns )
476
479
combined_data = pd .concat ([X_df , y_df ], axis = 1 )
@@ -537,7 +540,7 @@ def sample_prior_predictive(
537
540
X_pred ,
538
541
y_pred = None ,
539
542
samples : int | None = None ,
540
- extend_idata : bool = False ,
543
+ extend_idata : bool = True ,
541
544
combined : bool = True ,
542
545
** kwargs ,
543
546
):
@@ -552,7 +555,7 @@ def sample_prior_predictive(
552
555
Number of samples from the prior parameter distributions to generate.
553
556
If not set, uses sampler_config['draws'] if that is available, otherwise defaults to 500.
554
557
extend_idata : Boolean determining whether the predictions should be added to inference data object.
555
- Defaults to False .
558
+ Defaults to True .
556
559
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
557
560
Defaults to True.
558
561
**kwargs: Additional arguments to pass to pymc.sample_prior_predictive
@@ -567,21 +570,19 @@ def sample_prior_predictive(
567
570
if samples is None :
568
571
samples = self .sampler_config .get ("draws" , 500 )
569
572
570
- if self . model is None :
573
+ if not hasattr ( self , "model" ) :
571
574
self .build_model (X_pred , y_pred )
572
575
573
576
self ._data_setter (X_pred , y_pred )
574
- if self .model is not None :
575
- with self .model : # sample with new input data
576
- prior_pred : az .InferenceData = pm .sample_prior_predictive (
577
- samples , ** kwargs
578
- )
579
- self .set_idata_attrs (prior_pred )
580
- if extend_idata :
581
- if self .idata is not None :
582
- self .idata .extend (prior_pred , join = "right" )
583
- else :
584
- self .idata = prior_pred
577
+ with self .model : # sample with new input data
578
+ prior_pred : az .InferenceData = pm .sample_prior_predictive (samples , ** kwargs )
579
+ self .set_idata_attrs (prior_pred )
580
+
581
+ if extend_idata :
582
+ if self .idata is not None :
583
+ self .idata .extend (prior_pred , join = "right" )
584
+ else :
585
+ self .idata = prior_pred
585
586
586
587
prior_predictive_samples = az .extract (
587
588
prior_pred , "prior_predictive" , combined = combined
@@ -590,7 +591,11 @@ def sample_prior_predictive(
590
591
return prior_predictive_samples
591
592
592
593
def sample_posterior_predictive (
593
- self , X_pred , extend_idata : bool = True , combined : bool = True , ** kwargs
594
+ self ,
595
+ X_pred ,
596
+ extend_idata : bool = True ,
597
+ combined : bool = True ,
598
+ ** sample_posterior_predictive_kwargs ,
594
599
):
595
600
"""
596
601
Sample from the model's posterior predictive distribution.
@@ -603,7 +608,7 @@ def sample_posterior_predictive(
603
608
Defaults to True.
604
609
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
605
610
Defaults to True.
606
- **kwargs : Additional arguments to pass to pymc.sample_posterior_predictive
611
+ **sample_posterior_predictive_kwargs : Additional arguments to pass to pymc.sample_posterior_predictive
607
612
608
613
Returns
609
614
-------
@@ -612,16 +617,21 @@ def sample_posterior_predictive(
612
617
"""
613
618
self ._data_setter (X_pred )
614
619
615
- with self .model : # type: ignore
616
- post_pred = pm .sample_posterior_predictive (self .idata , ** kwargs )
617
- if extend_idata :
618
- self .idata .extend (post_pred , join = "right" ) # type: ignore
620
+ with self .model :
621
+ post_pred = pm .sample_posterior_predictive (
622
+ self .idata , ** sample_posterior_predictive_kwargs
623
+ )
624
+
625
+ if extend_idata :
626
+ self .idata .extend (post_pred , join = "right" ) # type: ignore
619
627
620
- posterior_predictive_samples = az .extract (
621
- post_pred , "posterior_predictive" , combined = combined
628
+ variable_name = (
629
+ "predictions"
630
+ if sample_posterior_predictive_kwargs .get ("predictions" )
631
+ else "posterior_predictive"
622
632
)
623
633
624
- return posterior_predictive_samples
634
+ return az . extract ( post_pred , variable_name , combined = combined )
625
635
626
636
def get_params (self , deep = True ):
627
637
"""
0 commit comments