Skip to content

Commit

Permalink
FullyBayesianSingleTaskGP.train should not return None
Browse files Browse the repository at this point in the history
Summary: This is for consistency with the signature of `Module.train`.

Differential Revision: D68710923
  • Loading branch information
esantorella authored and facebook-github-bot committed Jan 27, 2025
1 parent 851df1f commit 59961a2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
12 changes: 10 additions & 2 deletions botorch/models/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import math
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from typing import Any, TypeVar

import pyro
import torch
Expand Down Expand Up @@ -67,6 +67,11 @@
from pyro.ops.integrator import register_exception_handler
from torch import Tensor

# Can replace with Self type once 3.11 is the minimum version
TFullyBayesianSingleTaskGP = TypeVar(
"TFullyBayesianSingleTaskGP", bound="FullyBayesianSingleTaskGP"
)

_sqrt5 = math.sqrt(5)


Expand Down Expand Up @@ -623,13 +628,16 @@ def _aug_batch_shape(self) -> torch.Size:
aug_batch_shape += torch.Size([self.num_outputs])
return aug_batch_shape

def train(self, mode: bool = True) -> None:
def train(
self: TFullyBayesianSingleTaskGP, mode: bool = True
) -> TFullyBayesianSingleTaskGP:
r"""Puts the model in `train` mode."""
super().train(mode=mode)
if mode:
self.mean_module = None
self.covar_module = None
self.likelihood = None
return self

def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
r"""Load the MCMC hyperparameter samples into the model.
Expand Down
3 changes: 2 additions & 1 deletion test/models/test_fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ def test_fit_model(self):
# Make sure the model shapes are set correctly
self.assertEqual(model.pyro_model.train_X.shape, torch.Size([n, d]))
self.assertAllClose(model.pyro_model.train_X, train_X)
model.train() # Put the model in train mode
trained_model = model.train() # Put the model in train mode
self.assertIs(trained_model, model)
self.assertAllClose(train_X, model.pyro_model.train_X)
self.assertIsNone(model.mean_module)
self.assertIsNone(model.covar_module)
Expand Down

0 comments on commit 59961a2

Please sign in to comment.