diff --git a/falkon/kernels/distance_kernel.py b/falkon/kernels/distance_kernel.py index 6917041..147cbd2 100644 --- a/falkon/kernels/distance_kernel.py +++ b/falkon/kernels/distance_kernel.py @@ -552,9 +552,9 @@ def keops_mmv_impl(self, X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2): ) elif self.nu == 2.5: formula = ( - "(IntCst(1) + Sqrt(IntCst(5)) * Norm2(x1 / s - x2 / s) + " + "((IntCst(1) + Sqrt(IntCst(5)) * Norm2(x1 / s - x2 / s) + " "(IntInv(3) * IntCst(5)) * SqNorm2(x1 / s - x2 / s)) * " - "(Exp(-Sqrt(IntCst(5)) * Norm2(x1 / s - x2 / s)) * v)" + "Exp(-Sqrt(IntCst(5)) * Norm2(x1 / s - x2 / s))) * v" ) elif self.nu == float("inf"): formula = "Exp(IntInv(-2) * SqDist(x1 / s, x2 / s)) * v" diff --git a/falkon/tests/conftest.py b/falkon/tests/conftest.py index 81c1a54..f9d9cec 100644 --- a/falkon/tests/conftest.py +++ b/falkon/tests/conftest.py @@ -100,7 +100,7 @@ def fix_mats(*mats, order, device, dtype): dtype = make_tuple(dtype, len(mats)) for i, m in enumerate(mats): if isinstance(m, SparseTensor): - yield fix_sparse_mat(m, dtype=dtype[i], device=device[i]) + yield fix_sparse_mat(m, device=device[i], dtype=dtype[i]) else: yield fix_mat(m, order=order[i], device=device[i], dtype=dtype[i]) diff --git a/falkon/tests/test_kernels.py b/falkon/tests/test_kernels.py index a6fae31..8185505 100644 --- a/falkon/tests/test_kernels.py +++ b/falkon/tests/test_kernels.py @@ -158,6 +158,9 @@ def autogradcheck_mm(_m1, _m2, *_kernel_params): actual_noout, actual, rtol=rtol, atol=atol, msg="MMV with out and without return different stuff" ) expected_mmv = expected_mm @ v + print(f"{expected_mmv=}") + print(f"{actual=}") + print(f"Max Diff: {(expected_mmv - actual).square().max().sqrt()}") torch.testing.assert_close(expected_mmv, actual, rtol=rtol, atol=atol, msg="MMV result is incorrect") # 4. MMV gradients