From 04d16ba1a92f6611eb0645e5633e1827b1a09506 Mon Sep 17 00:00:00 2001 From: E-Rum Date: Wed, 29 Jan 2025 14:36:59 +0000 Subject: [PATCH] Fix tests for the case when CUDA is available on the system --- tests/metatensor/test_workflow_metatensor.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 6beed927..bb76f0ad 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -94,7 +94,9 @@ def system(self, device=None, dtype=None): properties=mts_torch.Labels.range("distance", 1), ) - return system.to(device=device), neighbors.to(device=device) + return system.to(device=device, dtype=dtype), neighbors.to( + device=device, dtype=dtype + ) def check_operation(self, calculator, device, dtype): """Make sure computation runs and returns a metatensor.TensorMap.""" @@ -107,12 +109,16 @@ def check_operation(self, calculator, device, dtype): def test_operation_as_python(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a normal python script""" - calculator = CalculatorClass(**params) + params["potential"].device = device + params["potential"].dtype = dtype + calculator = CalculatorClass(**params, device=device, dtype=dtype) self.check_operation(calculator=calculator, device=device, dtype=dtype) def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype): """Run `check_operation` as a compiled torch script module.""" - calculator = CalculatorClass(**params) + params["potential"].device = device + params["potential"].dtype = dtype + calculator = CalculatorClass(**params, device=device, dtype=dtype) scripted = torch.jit.script(calculator) self.check_operation(calculator=scripted, device=device, dtype=dtype)