diff --git a/falkon/tests/test_conjgrad.py b/falkon/tests/test_conjgrad.py index b4be47f..fc129cd 100644 --- a/falkon/tests/test_conjgrad.py +++ b/falkon/tests/test_conjgrad.py @@ -154,8 +154,8 @@ def test_restarts(self, data, centers, kernel, preconditioner, knm, kmm, vec_rhs def test_precomputed_kernel(self, data, centers, kernel, preconditioner, knm, kmm, vec_rhs, device): preconditioner = preconditioner.to(device) options = dataclasses.replace(self.basic_opt, use_cpu=device == "cpu") - knm = move_tensor(knm, device) - calc_kernel = PrecomputedKernel(knm, options) + knm_dev = move_tensor(knm, device) + calc_kernel = PrecomputedKernel(knm_dev, options) opt = FalkonConjugateGradient(calc_kernel, preconditioner, opt=options) # Solve (knm.T @ knm + lambda*n*kmm) x = knm.T @ b