diff --git a/falkon/optim/conjgrad.py b/falkon/optim/conjgrad.py index a6b8157..2ea425f 100644 --- a/falkon/optim/conjgrad.py +++ b/falkon/optim/conjgrad.py @@ -282,7 +282,7 @@ def solve(self, X, M, Y, _lambda, initial_solution, max_iter, callback=None): stream = torch.cuda.current_stream(device) # Note that if we don't have CUDA this still works with stream=None. - with ExitStack() as stack, TicToc("ConjGrad preparation", False), torch.inference_mode(): + with ExitStack() as stack, TicToc("ConjGrad preparation", False): if cuda_inputs: stack.enter_context(torch.cuda.device(device)) stack.enter_context(torch.cuda.stream(stream))