
Commit

Implement metatensor.learn.nn module EquivariantTransformation (metatensor#753)

---------

Co-authored-by: Joseph W. Abbott <[email protected]>
Co-authored-by: Guillaume Fraux <[email protected]>
3 people authored Oct 23, 2024
1 parent 8dc9f9f commit 401c590
Showing 6 changed files with 356 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/src/learn/reference/nn/index.rst
@@ -43,3 +43,6 @@ Modules

.. autoclass:: metatensor.learn.nn.InvariantLayerNorm
:members:

.. autoclass:: metatensor.learn.nn.EquivariantTransformation
:members:
7 changes: 7 additions & 0 deletions python/metatensor-learn/CHANGELOG.md
@@ -6,6 +6,13 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/metatensor/metatensor/)

### Added

- Added `metatensor.learn.nn.EquivariantTransformation` to apply any
  `torch.nn.Module` to invariants computed from the norm over components of
  covariant blocks. The transformed invariants are then multiplied elementwise
  back into the covariant blocks. For invariant blocks, the `torch.nn.Module` is
  applied as is (#744)

<!-- Possible sections
### Added
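The mechanism described in the changelog entry above can be sketched on a plain tensor. This is a rough illustration with hypothetical shapes, using torch directly rather than TensorMap blocks:

import torch

# hypothetical covariant block values: (samples, components, properties)
x = torch.randn(2, 3, 4)
module = torch.nn.Tanh()  # any shape-preserving torch.nn.Module

invariant = x.norm(dim=1, keepdim=True)  # invariant: norm over components
y = module(invariant) * x                # multiplied elementwise back into x
assert y.shape == x.shape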
1 change: 1 addition & 0 deletions python/metatensor-learn/metatensor/learn/nn/__init__.py
@@ -1,3 +1,4 @@
from .equivariant_transformation import EquivariantTransformation # noqa: F401
from .layer_norm import InvariantLayerNorm, LayerNorm # noqa: F401
from .linear import EquivariantLinear, Linear # noqa: F401
from .module_map import ModuleMap # noqa: F401
211 changes: 211 additions & 0 deletions python/metatensor-learn/metatensor/learn/nn/equivariant_transformation.py
@@ -0,0 +1,211 @@
from typing import List, Optional, Union

import torch

from .._backend import Labels, TensorMap
from .._dispatch import int_array_like
from ._utils import _check_module_map_parameter
from .module_map import ModuleMap


class EquivariantTransformation(torch.nn.Module):
"""
A custom :py:class:`torch.nn.Module` that applies an arbitrary shape- and
equivariance-preserving transformation to an input :py:class:`TensorMap`.
For invariant blocks (specified with ``invariant_keys``), the respective
transformation contained in :param modules: is applied as is. For covariant blocks,
an invariant multiplier is created, applying the transformation to the norm of the
block over the component dimensions.
:param modules: a :py:class:`list` of :py:class:`torch.nn.Module` containing the
transformations to be applied to each block indexed by
:param in_keys:. Transformations for invariant and covariant blocks differ. See
above.
:param in_keys: :py:class:`Labels`, the keys that are assumed to be in the input
:py:class:`TensorMap` in the :py:meth:`forward` method.
:param in_features: :py:class:`list` of :py:class:`int`, the number of features in
the input tensor for each block indexed by the keys in :param in_keys:. If
passed as a single value, the same number of features is assumed for all blocks.
:param out_properties: :py:class:`list` of :py:class`Labels` (optional), the
properties labels
of the output. By default the output properties are relabeled using
Labels.range.
:param invariant_keys: a :py:class:`Labels` object that is used to select the
invariant keys from ``in_keys``. If not provided, the invariant keys are assumed
to be those where key dimensions ``["o3_lambda", "o3_sigma"]`` are equal to
``[0, 1]``.
>>> import torch
>>> import numpy as np
>>> from metatensor import Labels, TensorBlock, TensorMap
>>> from metatensor.learn.nn import EquivariantTransformation

Define a dummy invariant TensorBlock:

>>> block_1 = TensorBlock(
... values=torch.randn(2, 1, 3),
... samples=Labels(
... ["system", "atom"],
... np.array(
... [
... [0, 0],
... [0, 1],
... ]
... ),
... ),
... components=[Labels(["o3_mu"], np.array([[0]]))],
... properties=Labels(["properties"], np.array([[0], [1], [2]])),
... )

Define a dummy covariant TensorBlock:

>>> block_2 = TensorBlock(
... values=torch.randn(2, 3, 3),
... samples=Labels(
... ["system", "atom"],
... np.array(
... [
... [0, 0],
... [0, 1],
... ]
... ),
... ),
... components=[Labels(["o3_mu"], np.array([[-1], [0], [1]]))],
... properties=Labels(["properties"], np.array([[3], [4], [5]])),
... )

Create a TensorMap containing the dummy TensorBlocks:

>>> keys = Labels(names=["o3_lambda"], values=np.array([[0], [1]]))
>>> tensor = TensorMap(keys, [block_1, block_2])

Define the transformation to apply to the TensorMap:

>>> modules = [torch.nn.Tanh(), torch.nn.Tanh()]
>>> in_features = [len(tensor.block(key).properties) for key in tensor.keys]

Define the EquivariantTransformation module:

>>> transformation = EquivariantTransformation(
... modules,
... tensor.keys,
... in_features,
... out_properties=[tensor.block(key).properties for key in tensor.keys],
... invariant_keys=Labels(
... ["o3_lambda"], np.array([0], dtype=np.int64).reshape(-1, 1)
... ),
... )

The output metadata are the same as the input:

>>> transformation(tensor)
TensorMap with 2 blocks
keys: o3_lambda
0
1
>>> transformation(tensor)[0]
TensorBlock
samples (2): ['system', 'atom']
components (1): ['o3_mu']
properties (3): ['properties']
gradients: None
"""

def __init__(
self,
modules: List[torch.nn.Module],
in_keys: Labels,
in_features: Union[int, List[int]],
out_features: Optional[Union[int, List[int]]] = None,
out_properties: Optional[List[Labels]] = None,
invariant_keys: Optional[Labels] = None,
) -> None:
super().__init__()

# Set a default for invariant keys
if invariant_keys is None:
invariant_keys = Labels(
names=["o3_lambda", "o3_sigma"],
values=int_array_like([0, 1], like=in_keys.values).reshape(-1, 2),
)
invariant_key_idxs = in_keys.select(invariant_keys)

# Infer `out_features` if not provided
if out_features is None:
if out_properties is None:
raise ValueError(
"If `out_features` is not provided,"
" `out_properties` must be provided."
)
out_features = [len(p) for p in out_properties]

# Check input parameters, convert to lists (for each key) if necessary
in_features = _check_module_map_parameter(
in_features, "in_features", int, len(in_keys), "in_keys"
)

modules_for_map: List[torch.nn.Module] = []
for i in range(len(in_keys)):
if i in invariant_key_idxs:
module_i = modules[i]
else:
module_i = _CovariantTransform(
module=modules[i],
)
modules_for_map.append(module_i)

self.module_map = ModuleMap(in_keys, modules_for_map, out_properties)

def forward(self, tensor: TensorMap) -> TensorMap:
"""
Apply the transformation to the input tensor map ``tensor``.

:param tensor: :py:class:`TensorMap` with the input tensor to be transformed.
:return: :py:class:`TensorMap` corresponding to the transformed input
    ``tensor``.
"""
return self.module_map(tensor)


class _CovariantTransform(torch.nn.Module):
"""
Applies an arbitrary shape-preserving transformation defined in ``module`` to a
3-dimensional tensor in a way that preserves equivariance. The transformation is
applied to the norm of the :py:class:`torch.Tensor` over the component dimension,
and the result is multiplied elementwise with the original tensor, thus
preserving covariance. The shape of the tensor passed to :py:meth:`forward` is
preserved, so the input and output feature dimensions coincide.

:param module: :py:class:`torch.nn.Module` containing the transformation to be
    applied to the invariants constructed from the norms over the component
    dimension of the input :py:class:`torch.Tensor` passed to the
    :py:meth:`forward` method.
"""

def __init__(
self,
module: torch.nn.Module,
) -> None:
super().__init__()

self.module = module

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Creates an invariant block from the ``input`` covariant block and transforms it
by applying the torch ``module`` passed to the class constructor. The
transformed invariant is then used as an elementwise multiplier for the
``input`` block.

Transformations are applied consistently across components (axis 1) to
preserve equivariance.
"""
assert len(input.shape) == 3, "``input`` must be a three-dimensional tensor"
invariant = input.norm(dim=1, keepdim=True)
invariant_transformed = self.module(invariant)
tensor_out = invariant_transformed * input

return tensor_out
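Why this preserves covariance: the norm over the component axis is unchanged by any orthogonal action on that axis, so the multiplier is invariant and the output transforms exactly like the input. A quick numerical check along these lines (an illustration, not part of the commit; a generic random orthogonal matrix stands in for a proper rotation):

import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4)  # (samples, components, properties)
R, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix


def f(t: torch.Tensor) -> torch.Tensor:
    # same construction as _CovariantTransform, with a fixed Tanh module
    return torch.tanh(t.norm(dim=1, keepdim=True)) * t


Rfx = torch.einsum("ij,sjp->sip", R, f(x))  # R . f(x)
fRx = f(torch.einsum("ij,sjp->sip", R, x))  # f(R . x)
torch.testing.assert_close(Rfx, fRx)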
88 changes: 88 additions & 0 deletions python/metatensor-learn/tests/equivariant_transformation.py
@@ -0,0 +1,88 @@
import os

import numpy as np
import pytest

import metatensor


torch = pytest.importorskip("torch")

from metatensor.learn.nn.equivariant_transformation import (  # noqa: E402
    EquivariantTransformation,
)

from ._rotation_utils import WignerDReal # noqa: E402


DATA_ROOT = os.path.join(
os.path.dirname(__file__), "..", "..", "metatensor-operations", "tests", "data"
)


@pytest.fixture
def tensor():
tensor = metatensor.load(os.path.join(DATA_ROOT, "qm7-spherical-expansion.npz"))
tensor = tensor.to(arrays="torch")
tensor = metatensor.remove_gradients(tensor)
return tensor


@pytest.fixture
def wigner_d_real():
return WignerDReal(lmax=4, angles=(0.87641, 1.8729, 0.9187))


def module_wrapper(in_features, device, dtype):
"""
A sequential module with nonlinearities
"""
return torch.nn.Sequential(
torch.nn.Tanh(),
torch.nn.Linear(
in_features=in_features,
out_features=5,
device=device,
dtype=dtype,
),
torch.nn.ReLU(),
torch.nn.Linear(
in_features=5,
out_features=in_features,
device=device,
dtype=dtype,
),
)


def test_equivariance(tensor, wigner_d_real):
"""
Tests that application of the EquivariantTransformation layer is equivariant to O3
transformation of the input.
"""
# Define input and rotated input
x = tensor
Rx = wigner_d_real.transform_tensormap_o3(x)

in_features = [len(x.block(key).properties) for key in x.keys]
modules = [
module_wrapper(in_feat, device=x.device, dtype=x.block(0).values.dtype)
for in_feat in in_features
]

# Define the EquivariantTransformation module
f = EquivariantTransformation(
modules,
x.keys,
in_features,
out_properties=[x.block(key).properties for key in x.keys],
invariant_keys=metatensor.Labels(
["o3_lambda"], np.array([0], dtype=np.int64).reshape(-1, 1)
),
)

# Pass both the original and the rotated input through the transformation
Rfx = wigner_d_real.transform_tensormap_o3(f(x)) # R . f(x)
fRx = f(Rx) # f(R . x)

assert metatensor.allclose(fRx, Rfx, atol=1e-10, rtol=1e-10)
46 changes: 46 additions & 0 deletions python/metatensor-torch/tests/learn/torchscript.py
@@ -12,6 +12,7 @@
from metatensor.torch import allclose_raise
from metatensor.torch.learn.nn import (
EquivariantLinear,
EquivariantTransformation,
InvariantLayerNorm,
InvariantReLU,
InvariantSiLU,
@@ -190,3 +191,48 @@ def test_sequential(tensor):
)
check_module_torch_script(module, tensor)
check_module_save_load(module)


def test_equivariant_transform(tensor):
"""Tests module EquivariantTransformation"""

def module_wrapper(in_features, device, dtype):
"""
A sequential module with nonlinearities
"""
return torch.nn.Sequential(
torch.nn.Tanh(),
torch.nn.Linear(
in_features=in_features,
out_features=5,
device=device,
dtype=dtype,
),
torch.nn.ReLU(),
torch.nn.Linear(
in_features=5,
out_features=in_features,
device=device,
dtype=dtype,
),
)

in_keys = tensor.keys
in_features = [len(tensor.block(key).properties) for key in in_keys]

modules = [
module_wrapper(
in_feat, device=tensor.device, dtype=tensor.block(0).values.dtype
)
for in_feat in in_features
]

module = EquivariantTransformation(
modules,
in_keys,
in_features,
out_properties=[tensor.block(key).properties for key in tensor.keys],
invariant_keys=in_keys,
)
check_module_torch_script(module, tensor)
check_module_save_load(module)
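As a point of reference (not part of the commit), the scripting round-trip that `check_module_torch_script` presumably exercises can be reproduced stand-alone with plain `torch.jit` calls; a minimal sketch, assuming `module` and `tensor` as defined in the test above:

scripted = torch.jit.script(module)  # compile the module to TorchScript
allclose_raise(scripted(tensor), module(tensor))  # scripted output matches eager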
