Erase dtype and device #166

Merged
merged 10 commits into from
Feb 11, 2025

E-Rum
Contributor

@E-Rum E-Rum commented Feb 3, 2025

A couple of PRs ago, we decided to include dtype and device as explicit and obligatory parameters for both calculators and potentials.

Unfortunately, after thorough consideration of how typical pipelines are built, I concluded that we should abandon this design choice.

The main reason is that, in most cases, when working with an NN model, the preferred strategy is to first initialize the model and then move it to the desired device using model.to(device).

Since torch-pme is designed to be an internal part of the model, this creates a conflict. We initialize dtype and device once, but when we later move the model to a different device, it undermines our prior device-checking logic.

Luckily, since our entire pipeline is either a torch.nn.Module or its subclass, we can integrate it smoothly with models that change their device and dtype. The key idea is to thoroughly rewrite the pipeline so that all newly created tensors during calculations are registered as buffers using self.register_buffer.

This PR aims to achieve exactly that.
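A minimal sketch of the buffer idea, assuming a hypothetical `ToyPotential` class (it is not the torch-pme `Potential`, and the formula is only illustrative). The scalar is registered with `self.register_buffer`, so `model.to(...)` moves and casts it, and tensors created during the calculation take their dtype and device from that buffer:

```python
import torch


class ToyPotential(torch.nn.Module):
    """Hypothetical stand-in for a potential; not the torch-pme class."""

    def __init__(self, smearing: float):
        super().__init__()
        # Registered buffers follow the module through .to(device) / .to(dtype).
        self.register_buffer(
            "smearing", torch.tensor(smearing, dtype=torch.float64)
        )

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        # Tensors created on the fly inherit dtype and device from the buffer
        # instead of from constructor arguments.
        prefactor = torch.ones(
            1, dtype=self.smearing.dtype, device=self.smearing.device
        )
        return prefactor * torch.erfc(distances / self.smearing) / distances


pot = ToyPotential(smearing=1.0)
pot = pot.to(dtype=torch.float32)  # same pattern as model.to(device)

distances = torch.tensor([1.0, 2.0], dtype=torch.float32)
out = pot(distances)
print(pot.smearing.dtype, out.dtype)
```

With this pattern no `dtype` or `device` argument is needed at construction time; the module adapts to whatever the enclosing model is moved to.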


📚 Documentation preview 📚: https://torch-pme--166.org.readthedocs.build/en/166/

@E-Rum E-Rum marked this pull request as ready for review February 10, 2025 10:51
@E-Rum
Contributor Author

E-Rum commented Feb 10, 2025

Done! I dropped the "dtype" and "device" arguments from the initialization of the "Potential" and "Calculator" classes and rewrote all the tests and examples accordingly. By default, every torch.Tensor passed at initialization is now registered as a buffer with the "device" and "dtype" it was passed with. All floats that we register as buffers are explicitly registered as torch.float64, since Python floats are double precision (64-bit).
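A small illustration of why the explicit torch.float64 matters, assuming PyTorch's default dtype has not been changed from float32. The variable names are hypothetical and not taken from the torch-pme code:

```python
import torch

smearing = 0.1  # a Python float, i.e. IEEE 754 double precision

# Without an explicit dtype, torch.tensor uses torch.get_default_dtype(),
# which is float32 unless it has been changed.
default_buffer = torch.tensor(smearing)
explicit_buffer = torch.tensor(smearing, dtype=torch.float64)

# The float32 tensor no longer round-trips to the original Python float.
print(default_buffer.dtype, default_buffer.item() == smearing)
print(explicit_buffer.dtype, explicit_buffer.item() == smearing)
```

Registering with float64 keeps the value exactly as the user supplied it; a later `.to(dtype=...)` on the module can still cast it down when the model needs lower precision.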

Since I was not involved in the tuning code development, I would kindly ask you to pay special attention to the changes in that part to ensure I didn’t break anything.

@@ -32,6 +32,11 @@ Added
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in
``Calculator`` classes and tuning functions.

Changed
Contributor

We can probably change it to

Suggested change
- Changed
+ Removed

and also remove our statements in "Changed" about using dtypes everywhere...

Comment on lines +72 to +73
pot_1 = pot_1.to(dtype=dtype)
pot_2 = pot_2.to(dtype=dtype)
Contributor

Maybe add a comment explaining why you need this here.

Comment on lines +72 to +78
cell = torch.eye(
3,
device=self.potential.smearing.device,
dtype=self.potential.smearing.dtype,
)
ns_mesh = torch.ones(3, dtype=int, device=cell.device)

Contributor

When I apply .to to the class, is this correctly passed on to self.kspace_filter and self.mesh_interpolator?
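For reference, a sketch of how `torch.nn.Module.to` treats submodules in general. `ToyChild` and `ToyCalculator` are hypothetical classes; only the attribute names `kspace_filter` and `mesh_interpolator` are borrowed from the snippet above, and nothing here reproduces the actual torch-pme implementation:

```python
import torch


class ToyChild(torch.nn.Module):
    """Hypothetical submodule holding one float buffer and one int buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("grid", torch.zeros(3))  # default float dtype
        self.register_buffer("ns_mesh", torch.ones(3, dtype=torch.int64))
        # A plain tensor attribute is NOT tracked by the module.
        self.plain = torch.zeros(3)


class ToyCalculator(torch.nn.Module):
    """Hypothetical parent; assigning Modules as attributes registers them."""

    def __init__(self):
        super().__init__()
        self.kspace_filter = ToyChild()
        self.mesh_interpolator = ToyChild()


calc = ToyCalculator().to(dtype=torch.float64)

print(calc.kspace_filter.grid.dtype)         # float buffer follows .to
print(calc.mesh_interpolator.grid.dtype)     # same for the other child
print(calc.kspace_filter.ns_mesh.dtype)      # integer buffers keep their dtype
print(calc.kspace_filter.plain.dtype)        # unregistered tensor is untouched
```

So `.to` recurses into any child that is itself a `torch.nn.Module` and converts its registered floating-point buffers, while tensors stored as plain attributes stay where they were. Whether the real classes register everything this way is exactly what the question asks.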


# Calculate intervals
intervals = x[1:] - x[:-1]
dy = (y[1:] - y[:-1]) / intervals

# Create zero boundary conditions (natural spline)
d2y = torch.zeros_like(x, dtype=torch.float64)
torch.zeros_like(x)
Contributor

I think you can remove it?

Comment on lines +67 to +70
if potential.smearing is None:
raise ValueError(
"Must specify smearing to use a potential with P3MCalculator"
)
Contributor

I think there is no test for it. It should go into the workflow tests.

Also, I think more and more that we need an Ewald base class and another base class for the direct calculator.
In fact, the EwaldCalculator can serve as a base class, with PME and P3M only overriding the k-space method. But this is something for later...

@E-Rum E-Rum requested a review from PicoCentauri February 10, 2025 17:40
@PicoCentauri PicoCentauri merged commit fb760cd into main Feb 11, 2025
13 checks passed
@PicoCentauri PicoCentauri deleted the fix_device_dtype branch February 11, 2025 08:33

3 participants