Skip to content

Commit

Permalink
Helper: to_torch
Browse files Browse the repository at this point in the history
Add helper methods to generate PyTorch tensors.
  • Loading branch information
ax3l committed Sep 27, 2023
1 parent 56d5c98 commit 4ee4476
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/amrex/Array4.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def array4_to_cupy(self, copy=False, order="F"):
raise ValueError("The order argument must be F or C.")


# torch


def register_Array4_extension(amr):
"""Array4 helper methods"""
import inspect
Expand Down
3 changes: 3 additions & 0 deletions src/amrex/ArrayOfStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def aos_to_cupy(self, copy=False):
return cp.array(self, copy=copy)


# torch


def register_AoS_extension(amr):
"""ArrayOfStructs helper methods"""
import inspect
Expand Down
5 changes: 5 additions & 0 deletions src/amrex/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def mf_to_cupy(self, copy=False, order="F"):
return views


# torch


def register_MultiFab_extension(amr):
"""MultiFab helper methods"""

Expand All @@ -99,3 +102,5 @@ def register_MultiFab_extension(amr):
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__

amr.MultiFab.to_cupy = mf_to_cupy

# torch
14 changes: 14 additions & 0 deletions src/amrex/PODVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,19 @@ def podvector_to_cupy(self, copy=False):
raise ValueError("Vector is empty.")


def podvector_to_torch(self, copy=False):
"""
Provide PyTorch tensor views into a PODVector (e.g., RealVector, IntVector).
...
"""
import torch

# if CUDA else ...
# pick right device (context? device number?)
return torch.as_tensor(self.to_cupy(copy), device="cuda")


def register_PODVector_extension(amr):
"""PODVector helper methods"""
import inspect
Expand All @@ -82,3 +95,4 @@ def register_PODVector_extension(amr):
):
POD_type.to_numpy = podvector_to_numpy
POD_type.to_cupy = podvector_to_cupy
POD_type.to_torch = podvector_to_torch
14 changes: 14 additions & 0 deletions src/amrex/StructOfArrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ def soa_to_cupy(self, copy=False):
return soa_view


def soa_to_torch(self, copy=False):
"""
Provide PyTorch tensor views into a StructOfArrays.
...
"""
import torch

# if CUDA else ...
# pick right device (context? device number?)
return torch.as_tensor(self.to_cupy(copy), device="cuda")


def register_SoA_extension(amr):
"""StructOfArrays helper methods"""
import inspect
Expand All @@ -97,3 +110,4 @@ def register_SoA_extension(amr):
):
SoA_type.to_numpy = soa_to_numpy
SoA_type.to_cupy = soa_to_cupy
SoA_type.to_torch = soa_to_torch

0 comments on commit 4ee4476

Please sign in to comment.