diff --git a/test/models/test_model_list_gp_regression.py b/test/models/test_model_list_gp_regression.py index 0f13e4e259..9b0c7f003c 100644 --- a/test/models/test_model_list_gp_regression.py +++ b/test/models/test_model_list_gp_regression.py @@ -479,6 +479,27 @@ def test_fantasize(self): (3, 2), 0.3, dtype=x1.dtype, device=x1.device ) observation_noise[:, 1] = 0.4 + + # check observation noise without mask + fm = modellist.fantasize( + torch.rand(3, 2), + sampler=ListSampler(sampler1, sampler2), + observation_noise=observation_noise, + ) + for i in range(2): + fm_i = fm.models[i] + self.assertIsInstance(fm_i, SingleTaskGP) + self.assertIsInstance(fm_i.likelihood, FixedNoiseGaussianLikelihood) + self.assertEqual(fm_i.train_inputs[0].shape, torch.Size([2, 8, 2])) + self.assertEqual(fm_i.train_targets.shape, torch.Size([2, 8])) + # check observation_noise + self.assertTrue( + torch.equal( + fm_i.likelihood.noise[..., -3:], observation_noise[:, i] + ) + ) + + # check masked noise for obs_noise in (None, observation_noise): fm = modellist.fantasize( torch.rand(3, 2),