Skip to content

Commit

Permalink
Update sPM to allow for noise_prior_dist
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Feb 25, 2024
1 parent 8ba52ae commit b1d516a
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions gpax/models/spm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Created by Maxim Ziatdinov (email: [email protected])
"""

import warnings
from typing import Callable, Optional, Tuple, Type, Dict

import jax
Expand Down Expand Up @@ -41,13 +42,20 @@ class sPM:
def __init__(self,
model: model_type,
model_prior: prior_type,
noise_prior: Optional[prior_type] = None) -> None:
noise_prior: Optional[prior_type] = None,
noise_prior_dist: Optional[dist.Distribution] = None,) -> None:
self._model = model
self.model_prior = model_prior
if noise_prior is None:
self.noise_prior = lambda: numpyro.sample("sig", dist.LogNormal(0, 1))
else:
self.noise_prior = noise_prior
if noise_prior is not None:
warnings.warn(
"`noise_prior` is deprecated and will be removed in a future version. "
"Please use `noise_prior_dist` instead, which accepts an instance of a "
"numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, "
"rather than a function that calls `numpyro.sample`.",
FutureWarning,
)
self.noise_prior = noise_prior
self.noise_prior_dist = noise_prior_dist
self.mcmc = None

def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
Expand All @@ -59,10 +67,20 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None:
# Compute the function's value
mu = numpyro.deterministic("mu", self._model(X, params))
# Sample observational noise
sig = self.noise_prior()
if self.noise_prior: # this will be removed in the future releases
sig = self.noise_prior()
else:
sig = self._sample_noise()
# Score against the observed data points
numpyro.sample("y", dist.Normal(mu, sig), obs=y)

def _sample_noise(self) -> jnp.ndarray:
if self.noise_prior_dist is not None:
noise_dist = self.noise_prior_dist
else:
noise_dist = dist.LogNormal(0, 1)
return numpyro.sample("noise", noise_dist)

def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
num_warmup: int = 2000, num_samples: int = 2000,
num_chains: int = 1, chain_method: str = 'sequential',
Expand Down

0 comments on commit b1d516a

Please sign in to comment.