Merge pull request #339 from april-tools/tucker-tensor-train
Add Tucker and Tensor-Train / MPS factorization templates
loreloc authored Feb 5, 2025
2 parents c446634 + c733a96 commit 503dc96
Showing 9 changed files with 502 additions and 97 deletions.
29 changes: 16 additions & 13 deletions cirkit/backend/torch/layers/inner.py
@@ -155,12 +155,8 @@ def __init__(
num_folds: The number of folds.
Raises:
NotImplementedError: If the arity is not 2.
ValueError: If the number of input units is not the same as the number of output units.
"""
# TODO: generalize kronecker layer as to support a greater arity
if arity != 2:
raise NotImplementedError("Kronecker only implemented for binary product units.")
super().__init__(
num_input_units,
num_input_units**arity,
@@ -177,18 +173,25 @@ def config(self) -> Mapping[str, Any]:
}

def forward(self, x: Tensor) -> Tensor:
x0 = x[:, 0].unsqueeze(dim=-1) # shape (F, B, Ki, 1).
x1 = x[:, 1].unsqueeze(dim=-2) # shape (F, B, 1, Ki).
# shape (F, B, Ki, Ki) -> (F, B, Ko=Ki**2).
return self.semiring.mul(x0, x1).flatten(start_dim=-2)
# x: (F, H, B, Ki)
y0 = x[:, 0]
for i in range(1, x.shape[1]):
y0 = y0.unsqueeze(dim=-1) # (F, B, K, 1).
y1 = x[:, i].unsqueeze(dim=-2) # (F, B, 1, Ki).
# y0: (F, B, K=K * Ki).
y0 = torch.flatten(self.semiring.mul(y0, y1), start_dim=-2)
# y0: (F, B, Ko=Ki ** arity)
return y0

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
# x: (F, H, C, K, num_samples, D)
x0 = x[:, 0].unsqueeze(dim=3) # (F, C, Ki, 1, num_samples, D)
x1 = x[:, 1].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
# shape (F, C, Ki, Ki, num_samples, D) -> (F, C, Ko=Ki**2, num_samples, D)
x = x0 + x1
return torch.flatten(x, start_dim=2, end_dim=3), None
y0 = x[:, 0]
for i in range(1, x.shape[1]):
y0 = y0.unsqueeze(dim=3) # (F, C, K, 1, num_samples, D)
y1 = x[:, i].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
y0 = torch.flatten(y0 + y1, start_dim=2, end_dim=3)
# y0: (F, C, Ko=Ki ** arity, num_samples, D)
return y0, None


class TorchSumLayer(TorchInnerLayer):
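The generalized `forward` above replaces the fixed binary outer product with a loop over all `arity` inputs, flattening after each step so the running result always has shape (F, B, Ki**i). A minimal sketch of the same idea outside the diff, with ordinary multiplication standing in for `semiring.mul` and purely illustrative sizes, checks the loop against `torch.kron`:

```python
import torch

# Illustrative sizes only: F folds, H = arity inputs, B batch, Ki input units.
F, H, B, Ki = 2, 3, 4, 5
x = torch.rand(F, H, B, Ki)

# Iterated outer-product-and-flatten, as in the generalized forward() above.
y = x[:, 0]
for i in range(1, H):
    # (F, B, K, 1) * (F, B, 1, Ki) -> (F, B, K, Ki) -> (F, B, K * Ki)
    y = torch.flatten(y.unsqueeze(-1) * x[:, i].unsqueeze(-2), start_dim=-2)
# y: (F, B, Ki ** H)

# Reference for H == 3: per fold and batch element, this is the Kronecker
# product of the three input slices.
ref = torch.stack(
    [
        torch.kron(torch.kron(x[f, 0, b], x[f, 1, b]), x[f, 2, b])
        for f in range(F)
        for b in range(B)
    ]
).view(F, B, Ki**H)
assert torch.allclose(y, ref)
```

The `sample` method follows the same loop, but combines the per-sample indices with addition instead of `semiring.mul`.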
31 changes: 20 additions & 11 deletions cirkit/backend/torch/layers/optimized.py
@@ -28,20 +28,18 @@ def __init__(
Args:
num_input_units: The number of input units.
num_output_units: The number of output units.
arity: The arity of the layer, must be 2. Defaults to 2.
arity: The arity of the layer. Defaults to 2.
weight: The weight parameter, which must have shape $(F, K_o, K_i^2)$,
where $F$ is the number of folds, $K_o$ is the number of output units,
and $K_i$ is the number of input units.
Raises:
NotImplementedError: If the arity is not equal to 2. Future versions of cirkit
will support Tucker layers having arity greater than 2.
ValueError: If the arity is less than two.
ValueError: If the number of input and output units are incompatible with the
shape of the weight parameter.
"""
# TODO: Generalize Tucker layer to have any arity greater or equal 2
if arity != 2:
raise NotImplementedError("The Tucker layer is only implemented with arity=2")
if arity < 2:
raise ValueError("The arity should be at least 2")
super().__init__(
num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
)
@@ -52,6 +50,16 @@ def __init__(
f"{weight.num_folds} and {weight.shape}, respectively"
)
self.weight = weight
# Construct the einsum expression that the Tucker layer computes
# For instance, if arity == 2 then we have that
# self._einsum = ((0, 1, 2), (0, 1, 3), (0, 4, 2, 3), (0, 1, 4))
# Also, if arity == 3 then we have that
# self._einsum = ((0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 5, 2, 3, 4), (0, 1, 5))
self._einsum = (
tuple((0, 1, i + 2) for i in range(arity))
+ ((0, arity + 2, *tuple(i + 2 for i in range(arity))),)
+ ((0, 1, arity + 2),)
)
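For `arity == 2`, the sublists built above are the index form of the string equation the layer previously hard-coded (`"fbi,fbj,foij->fbo"`). A short sketch with made-up sizes, checking that equivalence through `torch.einsum`'s sublist interface:

```python
import torch

# Illustrative sizes only: F folds, B batch, Ki input units, Ko output units.
F, B, Ki, Ko = 2, 3, 4, 5
x0, x1 = torch.rand(F, B, Ki), torch.rand(F, B, Ki)
w = torch.rand(F, Ko, Ki, Ki)  # the weight already reshaped to (F, Ko, Ki, Ki)

# Sublist form for arity == 2:
# inputs -> (0, 1, 2), (0, 1, 3); weight -> (0, 4, 2, 3); output -> (0, 1, 4)
y_sublists = torch.einsum(x0, [0, 1, 2], x1, [0, 1, 3], w, [0, 4, 2, 3], [0, 1, 4])

# The same contraction written as the old arity-2 string equation.
y_string = torch.einsum("fbi,fbj,foij->fbo", x0, x1, w)
assert torch.allclose(y_sublists, y_string)
```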

def _valid_weight_shape(self, w: TorchParameter) -> bool:
if w.num_folds != self.num_folds:
@@ -60,7 +68,7 @@ def _valid_weight_shape(self, w: TorchParameter) -> bool:

@property
def _weight_shape(self) -> tuple[int, ...]:
return self.num_output_units, self.num_input_units * self.num_input_units
return self.num_output_units, self.num_input_units**self.arity

@property
def config(self) -> Mapping[str, Any]:
@@ -75,14 +83,15 @@ def params(self) -> Mapping[str, TorchParameter]:
return {"weight": self.weight}

def forward(self, x: Tensor) -> Tensor:
# weight: (F, Ko, Ki * Ki) -> (F, Ko, Ki, Ki)
# x: (F, H, B, Ki)
# weight: (F, Ko, Ki ** arity) -> (F, Ko, Ki, ..., Ki)
weight = self.weight().view(
-1, self.num_output_units, self.num_input_units, self.num_input_units
-1, self.num_output_units, *(self.num_input_units for _ in range(self.arity))
)
return self.semiring.einsum(
"fbi,fbj,foij->fbo",
self._einsum,
inputs=x.unbind(dim=1),
operands=(weight,),
inputs=(x[:, 0], x[:, 1]),
dim=-1,
keepdim=True,
)
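Putting the pieces together for an arity greater than 2: a hedged sketch (not part of the diff) of the contraction the generalized `forward` performs for `arity == 3`, with plain `torch.einsum` standing in for the semiring version and illustrative sizes:

```python
import torch

# Illustrative sizes only: F folds, arity 3, B batch, Ki input units, Ko output units.
F, B, Ki, Ko = 2, 3, 4, 5
x = torch.rand(F, 3, B, Ki)  # stacked inputs, shape (F, H=3, B, Ki)
weight = torch.rand(F, Ko, Ki**3).view(F, Ko, Ki, Ki, Ki)

# Sublists for arity == 3, matching the example in the constructor comment:
# ((0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 5, 2, 3, 4), (0, 1, 5))
y = torch.einsum(
    x[:, 0], [0, 1, 2],
    x[:, 1], [0, 1, 3],
    x[:, 2], [0, 1, 4],
    weight, [0, 5, 2, 3, 4],
    [0, 1, 5],
)
assert y.shape == (F, B, Ko)  # one Tucker contraction per fold and batch element
```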
40 changes: 31 additions & 9 deletions cirkit/backend/torch/semiring.py
@@ -1,4 +1,5 @@
import functools
import itertools
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from typing import ClassVar, TypeVar, cast
@@ -153,19 +154,40 @@ def __new__(cls) -> "SemiringImpl":
@classmethod
def einsum(
cls,
equation: str,
equation: str | Sequence[Sequence[int]],
*,
inputs: tuple[Tensor, ...],
operands: tuple[Tensor, ...],
inputs: tuple[Tensor, ...] | None = None,
operands: tuple[Tensor, ...] | None = None,
dim: int,
keepdim: bool,
) -> Tensor:
operands = tuple(cls.cast(opd) for opd in operands)

def _einsum_func(*xs: Tensor) -> Tensor:
return torch.einsum(equation, *xs, *operands)

return cls.apply_reduce(_einsum_func, *inputs, dim=dim, keepdim=keepdim)
# TODO (LL): We need to remove this super general yet extremely complicated and hard
# to maintain einsum definition, which depends on the semiring. A future version of the
# compiler in cirkit will be able to emit pytorch code for every layer at compile time
match equation:
case str():

def _einsum_str_func(*xs: Tensor) -> Tensor:
opds = tuple(cls.cast(opd) for opd in operands)
return torch.einsum(equation, *xs, *opds)

einsum_func = _einsum_str_func
case Sequence():

def _einsum_seq_func(*xs: Tensor) -> Tensor:
opds = tuple(cls.cast(opd) for opd in operands)
einsum_args = tuple(
itertools.chain.from_iterable(zip(xs + opds, equation[:-1]))
)
return torch.einsum(*einsum_args, equation[-1])

einsum_func = _einsum_seq_func
case _:
raise ValueError(
"The einsum expression must be either a string or a sequence of int sequences"
)

return cls.apply_reduce(einsum_func, *inputs, dim=dim, keepdim=keepdim)

# NOTE: Subclasses should not touch any of the above final static methods but should implement
# all the following abstract class methods, and subclasses should be @final.
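The sequence branch above interleaves each tensor with its index sublist before calling `torch.einsum`. A small sketch of that interleaving, with the semiring casting step omitted for brevity and made-up tensors matching the arity-2 Tucker contraction:

```python
import itertools

import torch

# Made-up sizes: F folds, B batch, Ki input units, Ko output units.
F, B, Ki, Ko = 2, 3, 4, 5
xs = (torch.rand(F, B, Ki), torch.rand(F, B, Ki))  # the layer inputs
opds = (torch.rand(F, Ko, Ki, Ki),)  # the reshaped Tucker weight
equation = ((0, 1, 2), (0, 1, 3), (0, 4, 2, 3), (0, 1, 4))

# Interleave (tensor, sublist) pairs -> (x0, (0,1,2), x1, (0,1,3), w, (0,4,2,3)),
# then append the output sublist, as _einsum_seq_func does.
args = tuple(itertools.chain.from_iterable(zip(xs + opds, equation[:-1])))
y = torch.einsum(*args, equation[-1])
assert y.shape == (F, B, Ko)
```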