Skip to content

Commit

Permalink
Fix random seed for "vi"
Browse files Browse the repository at this point in the history
Random seed was not properly passed when `inference_method="vi"`. This is now resolved.
  • Loading branch information
B-Deforce authored Jan 13, 2025
1 parent 089d8e9 commit fde5632
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def run(
**kwargs,
)
elif inference_method in self.pymc_methods["vi"]:
result = self._run_vi(**kwargs)
result = self._run_vi(random_seed, **kwargs)
elif inference_method == "laplace":
result = self._run_laplace(draws, omit_offsets, include_response_params)
else:
Expand Down Expand Up @@ -382,9 +382,9 @@ def _clean_results(self, idata, omit_offsets, include_response_params, idata_fro

return idata

def _run_vi(self, **kwargs):
def _run_vi(self, random_seed, **kwargs):
with self.model:
self.vi_approx = pm.fit(**kwargs)
self.vi_approx = pm.fit(random_seed=random_seed, **kwargs)
return self.vi_approx

def _run_laplace(self, draws, omit_offsets, include_response_params):
Expand Down

0 comments on commit fde5632

Please sign in to comment.