diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 43741f8e876..db261646eb9 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -317,7 +317,6 @@ def sample_blackjax_nuts( postprocessing_backend: Optional[str] = None, postprocessing_chunks: Optional[int] = None, idata_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, ) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``blackjax`` library. @@ -530,7 +529,6 @@ def sample_numpyro_nuts( postprocessing_chunks: Optional[int] = None, idata_kwargs: Optional[Dict] = None, nuts_kwargs: Optional[Dict] = None, - **kwargs, ) -> az.InferenceData: """ Draw samples from the posterior using the NUTS method from the ``numpyro`` library.