Skip to content

Commit

Permalink
Allow for observation noise without mask in ModelListGP (#2735)
Browse files Browse the repository at this point in the history
Summary:
<!--
Thank you for sending the PR! We appreciate you spending the time to make BoTorch better.

Help us understand your motivation by explaining why you decided to make this change.

You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md
-->

## Motivation

Fixes #2734

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?
Yes

Pull Request resolved: #2735

Test Plan:
I've added a test to the unit test, in addition, a separate example:

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood

model = ModelListGP(
    *[SingleTaskGP(train_X=torch.tensor([[0.0], [1.0]]), train_Y=torch.tensor([[0.0], [1.0]])) for _ in range(2)]
)
mll = SumMarginalLogLikelihood(model.likelihood, model)
mll = fit_gpytorch_mll(mll)
sampler = ListSampler(*[SobolQMCNormalSampler(sample_shape=torch.Size([1])) for _ in range(2)])

# Works:
model.fantasize(
    torch.tensor([[0.5]]),
    sampler=sampler,
    observation_noise=torch.tensor([[0.2, 0.2]]),
    evaluation_mask=torch.tensor([[1.0, 1.0]], dtype=torch.bool),
)

# Now also works:
model.fantasize(
    torch.tensor([[0.5]]),
    sampler=sampler,
    observation_noise=torch.tensor([[0.2, 0.2]]),
)
```

Reviewed By: esantorella

Differential Revision: D69261331

Pulled By: saitcakmak

fbshipit-source-id: f51ab84144ecfcf1b142ad7b447fc6c56d518202
  • Loading branch information
swierh authored and facebook-github-bot committed Feb 7, 2025
1 parent f413275 commit 8770fa4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
2 changes: 2 additions & 0 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ def fantasize(
sampler_i = (
sampler.samplers[i] if isinstance(sampler, ListSampler) else sampler
)
if observation_noise is not None:
observation_noise_i = observation_noise[..., i : i + 1]

fant_model = self.models[i].fantasize(
X=X_i,
Expand Down
21 changes: 21 additions & 0 deletions test/models/test_model_list_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,27 @@ def test_fantasize(self):
(3, 2), 0.3, dtype=x1.dtype, device=x1.device
)
observation_noise[:, 1] = 0.4

# check observation noise without mask
fm = modellist.fantasize(
torch.rand(3, 2),
sampler=ListSampler(sampler1, sampler2),
observation_noise=observation_noise,
)
for i in range(2):
fm_i = fm.models[i]
self.assertIsInstance(fm_i, SingleTaskGP)
self.assertIsInstance(fm_i.likelihood, FixedNoiseGaussianLikelihood)
self.assertEqual(fm_i.train_inputs[0].shape, torch.Size([2, 8, 2]))
self.assertEqual(fm_i.train_targets.shape, torch.Size([2, 8]))
# check observation_noise
self.assertTrue(
torch.equal(
fm_i.likelihood.noise[..., -3:], observation_noise[:, i]
)
)

# check masked noise
for obs_noise in (None, observation_noise):
fm = modellist.fantasize(
torch.rand(3, 2),
Expand Down

0 comments on commit 8770fa4

Please sign in to comment.