Skip to content

Commit

Permalink
Fix device initialization for CUDA in Calculator and Potential classes
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Jan 29, 2025
1 parent a65511a commit 1513c2d
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 35 deletions.
6 changes: 5 additions & 1 deletion src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def __init__(
f"Potential must be an instance of Potential, got {type(potential)}"
)

self.device = torch.get_default_device() if device is None else torch.device(device)
self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")
self.dtype = torch.get_default_dtype() if dtype is None else dtype

if self.dtype != potential.dtype:
Expand Down
6 changes: 5 additions & 1 deletion src/torchpme/potentials/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def __init__(
):
super().__init__()
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = torch.get_default_device() if device is None else torch.device(device)
self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")
if smearing is not None:
self.register_buffer(
"smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
Expand Down
12 changes: 10 additions & 2 deletions src/torchpme/tuning/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def __init__(
f"Only exponent = 1 is supported but got {exponent}."
)

self.device = torch.get_default_device() if device is None else torch.device(device)
self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")
self.dtype = torch.get_default_dtype() if dtype is None else dtype

_validate_parameters(
Expand Down Expand Up @@ -295,7 +299,11 @@ def __init__(
super().__init__()

self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = torch.get_default_device() if device is None else torch.device(device)
self.device = (
torch.get_default_device() if device is None else torch.device(device)
)
if self.device.type == "cuda" and self.device.index is None:
self.device = torch.device("cuda:0")

_validate_parameters(
charges=charges,
Expand Down
67 changes: 46 additions & 21 deletions tests/calculators/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,35 @@
(
Calculator,
{
"potential": CoulombPotential(smearing=None),
"potential": lambda dtype, device: CoulombPotential(
smearing=None, dtype=dtype, device=device
),
},
),
(
EwaldCalculator,
{
"potential": CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"lr_wavelength": LR_WAVELENGTH,
},
),
(
PMECalculator,
{
"potential": CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"mesh_spacing": MESH_SPACING,
},
),
(
P3MCalculator,
{
"potential": CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"mesh_spacing": MESH_SPACING,
},
),
Expand All @@ -75,6 +83,7 @@ def cscl_system(self, device=None, dtype=None):

def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"] = params["potential"](dtype, device)
match = r"`smearing` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator, PMECalculator]:
params["smearing"] = 0
Expand All @@ -86,6 +95,7 @@ def test_smearing_non_positive(self, CalculatorClass, params, device, dtype):

def test_interpolation_order_error(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"] = params["potential"](dtype, device)
if type(CalculatorClass) in [PMECalculator]:
match = "Only `interpolation_nodes` from 1 to 5"
params["interpolation_nodes"] = 10
Expand All @@ -94,6 +104,7 @@ def test_interpolation_order_error(self, CalculatorClass, params, device, dtype)

def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"] = params["potential"](dtype, device)
match = r"`lr_wavelength` .* has to be positive"
if type(CalculatorClass) in [EwaldCalculator]:
params["lr_wavelength"] = 0
Expand All @@ -106,8 +117,7 @@ def test_lr_wavelength_non_positive(self, CalculatorClass, params, device, dtype
def test_dtype_device(self, CalculatorClass, params, device, dtype):
"""Test that the output dtype and device are the same as the input."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
potential = calculator.forward(*self.cscl_system(device=device, dtype=dtype))
Expand All @@ -127,26 +137,23 @@ def check_operation(self, calculator, device, dtype):
def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
self.check_operation(calculator=calculator, device=device, dtype=dtype)

def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
self.check_operation(calculator=scripted, device=device, dtype=dtype)

def test_save_load(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
Expand All @@ -158,8 +165,7 @@ def test_save_load(self, CalculatorClass, params, device, dtype):
def test_prefactor(self, CalculatorClass, params, device, dtype):
"""Test if the prefactor is applied correctly."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

prefactor = 2.0
calculator1 = CalculatorClass(**params, device=device, dtype=dtype)
Expand All @@ -175,8 +181,7 @@ def test_prefactor(self, CalculatorClass, params, device, dtype):
def test_not_nan(self, CalculatorClass, params, device, dtype):
"""Make sure derivatives are not NaN."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
system = self.cscl_system(device=device, dtype=dtype)
Expand Down Expand Up @@ -212,9 +217,7 @@ def test_dtype_and_device_incompatability(
params = params.copy()

other_dtype = torch.float32 if dtype == torch.float64 else torch.float64

params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

match = (
rf"dtype of `potential` \({params['potential'].dtype}\) must be same as "
Expand All @@ -239,11 +242,33 @@ def test_potential_and_calculator_incompatability(
):
"""Test that the calculator raises an error if the potential and calculator are incompatible."""
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

params["potential"] = torch.jit.script(params["potential"])
with pytest.raises(
TypeError, match="Potential must be an instance of Potential, got.*"
):
CalculatorClass(**params, device=device, dtype=dtype)

def test_device_string_compatability(self, CalculatorClass, params, dtype, device):
"""Test that the calculator works with device strings."""
params = params.copy()
params["potential"] = params["potential"](dtype, "cpu")
calculator = CalculatorClass(
**params,
device=torch.device("cpu"),
dtype=dtype,
)

assert calculator.device == params["potential"].device

def test_device_index_compatability(self, CalculatorClass, params, dtype, device):
"""Test that the calculator works with no index on the device."""
if torch.cuda.is_available():
params = params.copy()
params["potential"] = params["potential"](dtype, "cuda")
calculator = CalculatorClass(
**params, device=torch.device("cuda:0"), dtype=dtype
)

assert calculator.device == params["potential"].device
27 changes: 17 additions & 10 deletions tests/metatensor/test_workflow_metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,35 @@
(
torchpme.metatensor.Calculator,
{
"potential": torchpme.CoulombPotential(smearing=None),
"potential": lambda dtype, device: torchpme.CoulombPotential(
smearing=None, dtype=dtype, device=device
),
},
),
(
torchpme.metatensor.EwaldCalculator,
{
"potential": torchpme.CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: torchpme.CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"lr_wavelength": LR_WAVELENGTH,
},
),
(
torchpme.metatensor.PMECalculator,
{
"potential": torchpme.CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: torchpme.CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"mesh_spacing": MESH_SPACING,
},
),
(
torchpme.metatensor.P3MCalculator,
{
"potential": torchpme.CoulombPotential(smearing=SMEARING),
"potential": lambda dtype, device: torchpme.CoulombPotential(
smearing=SMEARING, dtype=dtype, device=device
),
"mesh_spacing": MESH_SPACING,
},
),
Expand Down Expand Up @@ -109,23 +117,22 @@ def check_operation(self, calculator, device, dtype):

def test_operation_as_python(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a normal python script"""
params["potential"].device = device
params["potential"].dtype = dtype
params = params.copy()
params["potential"] = params["potential"](dtype, device)
calculator = CalculatorClass(**params, device=device, dtype=dtype)
self.check_operation(calculator=calculator, device=device, dtype=dtype)

def test_operation_as_torch_script(self, CalculatorClass, params, device, dtype):
"""Run `check_operation` as a compiled torch script module."""
params["potential"].device = device
params["potential"].dtype = dtype
params = params.copy()
params["potential"] = params["potential"](dtype, device)
calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
self.check_operation(calculator=scripted, device=device, dtype=dtype)

def test_save_load(self, CalculatorClass, params, device, dtype):
params = params.copy()
params["potential"].device = device
params["potential"].dtype = dtype
params["potential"] = params["potential"](dtype, device)

calculator = CalculatorClass(**params, device=device, dtype=dtype)
scripted = torch.jit.script(calculator)
Expand Down

0 comments on commit 1513c2d

Please sign in to comment.