forked from metatensor/metatensor
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…etatensor#753) --------- Co-authored-by: Joseph W. Abbott <[email protected]> Co-authored-by: Guillaume Fraux <[email protected]>
- Loading branch information
1 parent
8dc9f9f
commit 401c590
Showing
6 changed files
with
356 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
211 changes: 211 additions & 0 deletions
211
python/metatensor-learn/metatensor/learn/nn/equivariant_transformation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
|
||
from .._backend import Labels, TensorMap | ||
from .._dispatch import int_array_like | ||
from ._utils import _check_module_map_parameter | ||
from .module_map import ModuleMap | ||
|
||
|
||
class EquivariantTransformation(torch.nn.Module): | ||
""" | ||
A custom :py:class:`torch.nn.Module` that applies an arbitrary shape- and | ||
equivariance-preserving transformation to an input :py:class:`TensorMap`. | ||
For invariant blocks (specified with ``invariant_keys``), the respective | ||
transformation contained in :param modules: is applied as is. For covariant blocks, | ||
an invariant multiplier is created, applying the transformation to the norm of the | ||
block over the component dimensions. | ||
:param modules: a :py:class:`list` of :py:class:`torch.nn.Module` containing the | ||
transformations to be applied to each block indexed by | ||
:param in_keys:. Transformations for invariant and covariant blocks differ. See | ||
above. | ||
:param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input | ||
:py:class:`TensorMap` in the :py:meth:`forward` method. | ||
:param in_features: :py:class:`list` of :py:class:`int`, the number of features in | ||
the input tensor for each block indexed by the keys in :param in_keys:. If | ||
passed as a single value, the same number of features is assumed for all blocks. | ||
:param out_properties: :py:class:`list` of :py:class`Labels` (optional), the | ||
properties labels | ||
of the output. By default the output properties are relabeled using | ||
Labels.range. | ||
:param invariant_keys: a :py:class:`Labels` object that is used to select the | ||
invariant keys from ``in_keys``. If not provided, the invariant keys are assumed | ||
to be those where key dimensions ``["o3_lambda", "o3_sigma"]`` are equal to | ||
``[0, 1]``. | ||
>>> import torch | ||
>>> import numpy as np | ||
>>> from metatensor import Labels, TensorBlock, TensorMap | ||
>>> from metatensor.learn.nn import EquivariantTransformation | ||
Define a dummy invariant TensorBlock | ||
>>> block_1 = TensorBlock( | ||
... values=torch.randn(2, 1, 3), | ||
... samples=Labels( | ||
... ["system", "atom"], | ||
... np.array( | ||
... [ | ||
... [0, 0], | ||
... [0, 1], | ||
... ] | ||
... ), | ||
... ), | ||
... components=[Labels(["o3_mu"], np.array([[0]]))], | ||
... properties=Labels(["properties"], np.array([[0], [1], [2]])), | ||
... ) | ||
Define a dummy covariant TensorBlock | ||
>>> block_2 = TensorBlock( | ||
... values=torch.randn(2, 3, 3), | ||
... samples=Labels( | ||
... ["system", "atom"], | ||
... np.array( | ||
... [ | ||
... [0, 0], | ||
... [0, 1], | ||
... ] | ||
... ), | ||
... ), | ||
... components=[Labels(["o3_mu"], np.array([[-1], [0], [1]]))], | ||
... properties=Labels(["properties"], np.array([[3], [4], [5]])), | ||
... ) | ||
Create a TensorMap containing the dummy TensorBlocks | ||
>>> keys = Labels(names=["o3_lambda"], values=np.array([[0], [1]])) | ||
>>> tensor = TensorMap(keys, [block_1, block_2]) | ||
Define the transformation to apply to the TensorMap | ||
>>> modules = [torch.nn.Tanh(), torch.nn.Tanh()] | ||
>>> in_features = [len(tensor.block(key).properties) for key in tensor.keys] | ||
Define the EquivariantTransformation module | ||
>>> transformation = EquivariantTransformation( | ||
... modules, | ||
... tensor.keys, | ||
... in_features, | ||
... out_properties=[tensor.block(key).properties for key in tensor.keys], | ||
... invariant_keys=Labels( | ||
... ["o3_lambda"], np.array([0], dtype=np.int64).reshape(-1, 1) | ||
... ), | ||
... ) | ||
The output metadata are the same as the input | ||
>>> transformation(tensor) | ||
TensorMap with 2 blocks | ||
keys: o3_lambda | ||
0 | ||
1 | ||
>>> transformation(tensor)[0] | ||
TensorBlock | ||
samples (2): ['system', 'atom'] | ||
components (1): ['o3_mu'] | ||
properties (3): ['properties'] | ||
gradients: None | ||
""" | ||
|
||
def __init__( | ||
self, | ||
modules: List[torch.nn.Module], | ||
in_keys: Labels, | ||
in_features: Union[int, List[int]], | ||
out_features: Optional[Union[int, List[int]]] = None, | ||
out_properties: Optional[List[Labels]] = None, | ||
invariant_keys: Optional[Labels] = None, | ||
) -> None: | ||
super().__init__() | ||
|
||
# Set a default for invariant keys | ||
if invariant_keys is None: | ||
invariant_keys = Labels( | ||
names=["o3_lambda", "o3_sigma"], | ||
values=int_array_like([0, 1], like=in_keys.values).reshape(-1, 2), | ||
) | ||
invariant_key_idxs = in_keys.select(invariant_keys) | ||
|
||
# Infer `out_features` if not provided | ||
if out_features is None: | ||
if out_properties is None: | ||
raise ValueError( | ||
"If `out_features` is not provided," | ||
" `out_properties` must be provided." | ||
) | ||
out_features = [len(p) for p in out_properties] | ||
|
||
# Check input parameters, convert to lists (for each key) if necessary | ||
in_features = _check_module_map_parameter( | ||
in_features, "in_features", int, len(in_keys), "in_keys" | ||
) | ||
|
||
modules_for_map: List[torch.nn.Module] = [] | ||
for i in range(len(in_keys)): | ||
if i in invariant_key_idxs: | ||
module_i = modules[i] | ||
else: | ||
module_i = _CovariantTransform( | ||
module=modules[i], | ||
) | ||
modules_for_map.append(module_i) | ||
|
||
self.module_map = ModuleMap(in_keys, modules_for_map, out_properties) | ||
|
||
def forward(self, tensor: TensorMap) -> TensorMap: | ||
""" | ||
Apply the transformation to the input tensor map `tensor`. | ||
:param tensor: :py:class:`TensorMap` with the input tensor to be transformed. | ||
:return: :py:class:`TensorMap` corresponding to the transformed input | ||
``tensor``. | ||
""" | ||
return self.module_map(tensor) | ||
|
||
|
||
class _CovariantTransform(torch.nn.Module): | ||
""" | ||
Applies an arbitrary shape-preserving transformation defined in ``module`` to a | ||
3-dimensional tensor in a way that preserves equivariance. The transformation is | ||
applied to the norm of the :py:class:`torch.Tensor` over the component dimension. | ||
The resulting :py:class:`torch.Tensor` is elementwise multiplied back to the | ||
original one, thus preserving covariance. | ||
:param in_features: a :py:class:`int`, the input feature dimension. This also | ||
corresponds to the output feature size as the shape of the tensor passed to | ||
:py:meth:`forward` is preserved. | ||
:param module: :py:class:`torch.nn.Module` containing the transformation to be | ||
applied to the invariants constructed from the norms over the component | ||
dimension of the input :py:class:`torch.Tensor` passed to the :py:meth:`forward` | ||
method. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
module: torch.nn.Module, | ||
) -> None: | ||
super().__init__() | ||
|
||
self.module = module | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Creates an invariant block from the ``input`` covariant, and transforms it by | ||
applying the torch ``module`` passed to the class constructor. Then uses the | ||
transformed invariant as an elementwise multiplier for the ``input`` block. | ||
Transformations are applied consistently to components (axis 1) to preserve | ||
equivariance. | ||
""" | ||
assert len(input.shape) == 3, "``input`` must be a three-dimensional tensor" | ||
invariant = input.norm(dim=1, keepdim=True) | ||
invariant_transformed = self.module(invariant) | ||
tensor_out = invariant_transformed * input | ||
|
||
return tensor_out |
88 changes: 88 additions & 0 deletions
88
python/metatensor-learn/tests/equivariant_transformation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import os | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import metatensor | ||
|
||
|
||
torch = pytest.importorskip("torch") | ||
|
||
from metatensor.learn.nn.equivariant_transformation import ( # noqa: E402 | ||
EquivariantTransformation, # noqa: E402 | ||
) # noqa:E402 | ||
|
||
from ._rotation_utils import WignerDReal # noqa: E402 | ||
|
||
|
||
DATA_ROOT = os.path.join( | ||
os.path.dirname(__file__), "..", "..", "metatensor-operations", "tests", "data" | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def tensor(): | ||
tensor = metatensor.load(os.path.join(DATA_ROOT, "qm7-spherical-expansion.npz")) | ||
tensor = tensor.to(arrays="torch") | ||
tensor = metatensor.remove_gradients(tensor) | ||
return tensor | ||
|
||
|
||
@pytest.fixture | ||
def wigner_d_real(): | ||
return WignerDReal(lmax=4, angles=(0.87641, 1.8729, 0.9187)) | ||
|
||
|
||
def module_wrapper(in_features, device, dtype): | ||
""" | ||
A sequential module with nonlinearities | ||
""" | ||
return torch.nn.Sequential( | ||
torch.nn.Tanh(), | ||
torch.nn.Linear( | ||
in_features=in_features, | ||
out_features=5, | ||
device=device, | ||
dtype=dtype, | ||
), | ||
torch.nn.ReLU(), | ||
torch.nn.Linear( | ||
in_features=5, | ||
out_features=in_features, | ||
device=device, | ||
dtype=dtype, | ||
), | ||
) | ||
|
||
|
||
def test_equivariance(tensor, wigner_d_real): | ||
""" | ||
Tests that application of the EquivariantTransformation layer is equivariant to O3 | ||
transformation of the input. | ||
""" | ||
# Define input and rotated input | ||
x = tensor | ||
Rx = wigner_d_real.transform_tensormap_o3(x) | ||
|
||
in_features = [len(x.block(key).properties) for key in x.keys] | ||
modules = [ | ||
module_wrapper(in_feat, device=x.device, dtype=x.block(0).values.dtype) | ||
for in_feat in in_features | ||
] | ||
|
||
# Define the EquiLayerNorm module | ||
f = EquivariantTransformation( | ||
modules, | ||
x.keys, | ||
in_features, | ||
out_properties=[x.block(key).properties for key in x.keys], | ||
invariant_keys=metatensor.Labels( | ||
["o3_lambda"], np.array([0], dtype=np.int64).reshape(-1, 1) | ||
), | ||
) | ||
|
||
# Pass both through the linear layer | ||
Rfx = wigner_d_real.transform_tensormap_o3(f(x)) # R . f(x) | ||
fRx = f(Rx) # f(R . x) | ||
|
||
assert metatensor.allclose(fRx, Rfx, atol=1e-10, rtol=1e-10) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters