Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ComplexConvTransposeNd #26

Open
H320 opened this issue Dec 16, 2022 · 0 comments
Open

ComplexConvTransposeNd #26

H320 opened this issue Dec 16, 2022 · 0 comments

Comments

@H320
Copy link

H320 commented Dec 16, 2022

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L812

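Complex convolution
conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
where the weight W, input x and bias b are all complex, with real parts Wr, xr, br and imaginary parts Wi, xi, bi.
With Gauss's trick, only three real convolutions are needed instead of four:
a = conv(Wr, xr, br),
b = conv(Wi, xi, 0),
c = conv(Wr + Wi, xr + xi, br + bi)
conv(W, x, b) = a - b + i(c - a - b)
(Here a, b and c are intermediate results; this b is not the bias. Expanding c gives c - a - b = conv(Wi, xr, bi) + conv(Wr, xi, 0), which is exactly the imaginary part above.)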

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class ComplexConvTranspose1dn(nn.ConvTranspose1d):
    """Transposed 1-D convolution over complex-valued tensors.

    Uses Gauss's trick to compute the complex transposed convolution with
    three real transposed convolutions instead of four::

        a   = conv(x_r, W_r, b_r)
        b   = conv(x_i, W_i, 0)
        c   = conv(x_r + x_i, W_r + W_i, b_r + b_i)
        out = (a - b) + i * (c - a - b)

    ``input``, ``weight`` and (if present) ``bias`` are expected to be
    complex tensors. PyTorch's ``Tensor.imag`` raises for real-valued
    tensors. Constructor arguments are those of
    :class:`torch.nn.ConvTranspose1d`, including ``bias=False``.
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: complex tensor accepted by ``F.conv_transpose1d``.
            output_size: optional target spatial size, used only to derive
                ``output_padding`` via ``self._output_padding``.

        Returns:
            Complex tensor built with ``torch.complex(real, imag)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r, i_i = input.real, input.imag
        w_r, w_i = self.weight.real, self.weight.imag
        # With bias=False, self.bias is None: drop the bias terms entirely
        # instead of dereferencing None.
        if self.bias is None:
            b_r = b_sum = None
        else:
            b_r = self.bias.real
            b_sum = b_r + self.bias.imag

        # Positional arguments shared by all three real convolutions:
        # (stride, padding, output_padding, groups, dilation).
        conv_args = (self.stride, self.padding, output_padding, self.groups, self.dilation)
        a = F.conv_transpose1d(i_r, w_r, b_r, *conv_args)
        b = F.conv_transpose1d(i_i, w_i, None, *conv_args)
        c = F.conv_transpose1d(i_r + i_i, w_r + w_i, b_sum, *conv_args)

        # Real: conv(Wr, xr) - conv(Wi, xi) + br.
        # Imag: c - a - b = conv(Wi, xr) + conv(Wr, xi) + bi.
        return torch.complex(a - b, c - a - b)


class ComplexConvTranspose2dn(nn.ConvTranspose2d):
    """Transposed 2-D convolution over complex-valued tensors.

    Uses Gauss's trick to compute the complex transposed convolution with
    three real transposed convolutions instead of four::

        a   = conv(x_r, W_r, b_r)
        b   = conv(x_i, W_i, 0)
        c   = conv(x_r + x_i, W_r + W_i, b_r + b_i)
        out = (a - b) + i * (c - a - b)

    ``input``, ``weight`` and (if present) ``bias`` are expected to be
    complex tensors. PyTorch's ``Tensor.imag`` raises for real-valued
    tensors. Constructor arguments are those of
    :class:`torch.nn.ConvTranspose2d`, including ``bias=False``.
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: complex tensor accepted by ``F.conv_transpose2d``.
            output_size: optional target spatial size, used only to derive
                ``output_padding`` via ``self._output_padding``.

        Returns:
            Complex tensor built with ``torch.complex(real, imag)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r, i_i = input.real, input.imag
        w_r, w_i = self.weight.real, self.weight.imag
        # With bias=False, self.bias is None: drop the bias terms entirely
        # instead of dereferencing None.
        if self.bias is None:
            b_r = b_sum = None
        else:
            b_r = self.bias.real
            b_sum = b_r + self.bias.imag

        # Positional arguments shared by all three real convolutions:
        # (stride, padding, output_padding, groups, dilation).
        conv_args = (self.stride, self.padding, output_padding, self.groups, self.dilation)
        a = F.conv_transpose2d(i_r, w_r, b_r, *conv_args)
        b = F.conv_transpose2d(i_i, w_i, None, *conv_args)
        c = F.conv_transpose2d(i_r + i_i, w_r + w_i, b_sum, *conv_args)

        # Real: conv(Wr, xr) - conv(Wi, xi) + br.
        # Imag: c - a - b = conv(Wi, xr) + conv(Wr, xi) + bi.
        return torch.complex(a - b, c - a - b)


class ComplexConvTranspose3dn(nn.ConvTranspose3d):
    """Transposed 3-D convolution over complex-valued tensors.

    Uses Gauss's trick to compute the complex transposed convolution with
    three real transposed convolutions instead of four::

        a   = conv(x_r, W_r, b_r)
        b   = conv(x_i, W_i, 0)
        c   = conv(x_r + x_i, W_r + W_i, b_r + b_i)
        out = (a - b) + i * (c - a - b)

    ``input``, ``weight`` and (if present) ``bias`` are expected to be
    complex tensors. PyTorch's ``Tensor.imag`` raises for real-valued
    tensors. Constructor arguments are those of
    :class:`torch.nn.ConvTranspose3d`, including ``bias=False``.
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: complex tensor accepted by ``F.conv_transpose3d``.
            output_size: optional target spatial size, used only to derive
                ``output_padding`` via ``self._output_padding``.

        Returns:
            Complex tensor built with ``torch.complex(real, imag)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r, i_i = input.real, input.imag
        w_r, w_i = self.weight.real, self.weight.imag
        # With bias=False, self.bias is None: drop the bias terms entirely
        # instead of dereferencing None.
        if self.bias is None:
            b_r = b_sum = None
        else:
            b_r = self.bias.real
            b_sum = b_r + self.bias.imag

        # Positional arguments shared by all three real convolutions:
        # (stride, padding, output_padding, groups, dilation).
        conv_args = (self.stride, self.padding, output_padding, self.groups, self.dilation)
        a = F.conv_transpose3d(i_r, w_r, b_r, *conv_args)
        b = F.conv_transpose3d(i_i, w_i, None, *conv_args)
        c = F.conv_transpose3d(i_r + i_i, w_r + w_i, b_sum, *conv_args)

        # Real: conv(Wr, xr) - conv(Wi, xi) + br.
        # Imag: c - a - b = conv(Wi, xr) + conv(Wr, xi) + bi.
        return torch.complex(a - b, c - a - b)

@H320 H320 changed the title ComplexConvTranspose ComplexConvTransposeNd Dec 16, 2022
@H320 H320 changed the title ComplexConvTransposeNd ComplexConv[Transpose]Nd Dec 16, 2022
@H320 H320 changed the title ComplexConv[Transpose]Nd ComplexConvTransposeNd Dec 16, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant