diff --git a/gpax/models/dkl.py b/gpax/models/dkl.py index 2905480..73efeb0 100644 --- a/gpax/models/dkl.py +++ b/gpax/models/dkl.py @@ -159,7 +159,7 @@ def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarra def sample_biases(name: str, channels: int) -> jnp.ndarray: """Sampling bias vector""" - b = numpyro.sample(name=name, fn=dist.Normal( + b = numpyro.sample(name=name, fn=dist.Cauchy( loc=jnp.zeros((channels)), scale=jnp.ones((channels)))) return b