Skip to content

Commit

Permalink
Make _check_xyz_tensor_map ~3x faster by accessing data once (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf authored Jan 22, 2025
1 parent bd72af6 commit f7b5b71
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
17 changes: 11 additions & 6 deletions python/src/sphericart/metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,21 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap:


def _check_xyz_tensor_map(xyz: TensorMap):
if len(xyz.blocks()) != 1:
blocks = xyz.blocks()
if len(blocks) != 1:
raise ValueError("`xyz` should have only one block")
if len(xyz.block().components) != 1:

block = blocks[0]
components = block.components
if len(components) != 1:
raise ValueError("`xyz` should have only one component")
if xyz.block().components[0].names != ["xyz"]:
if components[0].names != ["xyz"]:
raise ValueError("`xyz` should have only one component named 'xyz'")
if xyz.block().components[0].values.shape[0] != 3:

values_shape = block.values.shape
if values_shape[1] != 3:
raise ValueError("`xyz` should have 3 Cartesian coordinates")
if xyz.block().properties.values.shape[0] != 1:
if values_shape[2] != 1:
raise ValueError("`xyz` should have only one property")


Expand All @@ -251,7 +257,6 @@ def _wrap_into_tensor_map(
sh_gradients: Optional[np.ndarray] = None,
sh_hessians: Optional[np.ndarray] = None,
) -> TensorMap:

# infer l_max
l_max = len(components) - 1

Expand Down
17 changes: 11 additions & 6 deletions sphericart-torch/python/sphericart/torch/metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,21 @@ def _send_precomputed_labels_to_device(self, device):


def _check_xyz_tensor_map(xyz: TensorMap):
if len(xyz.blocks()) != 1:
blocks = xyz.blocks()
if len(blocks) != 1:
raise ValueError("`xyz` should have only one block")
if len(xyz.block().components) != 1:

block = blocks[0]
components = block.components
if len(components) != 1:
raise ValueError("`xyz` should have only one component")
if xyz.block().components[0].names != ["xyz"]:
if components[0].names != ["xyz"]:
raise ValueError("`xyz` should have only one component named 'xyz'")
if xyz.block().components[0].values.shape[0] != 3:

values_shape = block.values.shape
if values_shape[1] != 3:
raise ValueError("`xyz` should have 3 Cartesian coordinates")
if xyz.block().properties.values.shape[0] != 1:
if values_shape[2] != 1:
raise ValueError("`xyz` should have only one property")


Expand All @@ -273,7 +279,6 @@ def _wrap_into_tensor_map(
sh_gradients: Optional[torch.Tensor] = None,
sh_hessians: Optional[torch.Tensor] = None,
) -> TensorMap:

# infer l_max
l_max = len(components) - 1

Expand Down

0 comments on commit f7b5b71

Please sign in to comment.