You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Complex Convolution
conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
where W, x and b are all complex inputs.
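With the Gauss trick (three real convolutions instead of four), define the intermediate results:
a = conv(Wr, xr, br),
b' = conv(Wi, xi, 0),
c = conv(Wr + Wi, xr + xi, br + bi)
so that conv(W, x, b) = (a - b') + i(c - a - b'), where b' is an intermediate result (named `b` in the code) and is not the bias b.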
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
class ComplexConvTranspose1dn(nn.ConvTranspose1d):
    """1-D transposed convolution for complex tensors via the Gauss trick.

    The complex product is evaluated with three real-valued transposed
    convolutions instead of four:

        a = conv_t(x_r, W_r, b_r)
        b = conv_t(x_i, W_i)
        c = conv_t(x_r + x_i, W_r + W_i, b_r + b_i)
        out = (a - b) + i * (c - a - b)

    The layer's ``weight`` (and ``bias``, when present) are read through
    ``.real`` / ``.imag``, so they are expected to be complex parameters.
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: Complex tensor; its ``.real`` and ``.imag`` parts are
                convolved separately.
            output_size: Optional requested output size, forwarded to
                ``_output_padding`` to derive the output padding.

        Returns:
            A complex tensor built with ``torch.complex`` from the real and
            imaginary parts of the result.

        Raises:
            ValueError: If ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r = input.real
        i_i = input.imag
        w_r = self.weight.real
        w_i = self.weight.imag

        # A layer created with bias=False has self.bias set to None; reading
        # .real/.imag on it would raise, so run the convolutions bias-free.
        if self.bias is None:
            b_r = None
            b_sum = None
        else:
            b_r = self.bias.real
            b_sum = b_r + self.bias.imag

        a = F.conv_transpose1d(i_r, w_r, b_r, self.stride, self.padding, output_padding, self.groups, self.dilation)
        b = F.conv_transpose1d(i_i, w_i, None, self.stride, self.padding, output_padding, self.groups, self.dilation)
        c = F.conv_transpose1d(i_r + i_i, w_r + w_i, b_sum, self.stride, self.padding, output_padding, self.groups, self.dilation)

        # Real part: a - b; imaginary part: c - a - b (the real bias cancels).
        return torch.complex(a - b, c - a - b)
class ComplexConvTranspose2dn(nn.ConvTranspose2d):
    """2-D transposed convolution for complex tensors via the Gauss trick.

    The complex product is evaluated with three real-valued transposed
    convolutions instead of four:

        a = conv_t(x_r, W_r, b_r)
        b = conv_t(x_i, W_i)
        c = conv_t(x_r + x_i, W_r + W_i, b_r + b_i)
        out = (a - b) + i * (c - a - b)

    The layer's ``weight`` (and ``bias``, when present) are read through
    ``.real`` / ``.imag``, so they are expected to be complex parameters.
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: Complex tensor; its ``.real`` and ``.imag`` parts are
                convolved separately.
            output_size: Optional requested output size, forwarded to
                ``_output_padding`` to derive the output padding.

        Returns:
            A complex tensor built with ``torch.complex`` from the real and
            imaginary parts of the result.

        Raises:
            ValueError: If ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r = input.real
        i_i = input.imag
        w_r = self.weight.real
        w_i = self.weight.imag

        # A layer created with bias=False has self.bias set to None; reading
        # .real/.imag on it would raise, so run the convolutions bias-free.
        if self.bias is None:
            b_r = None
            b_sum = None
        else:
            b_r = self.bias.real
            b_sum = b_r + self.bias.imag

        a = F.conv_transpose2d(i_r, w_r, b_r, self.stride, self.padding, output_padding, self.groups, self.dilation)
        b = F.conv_transpose2d(i_i, w_i, None, self.stride, self.padding, output_padding, self.groups, self.dilation)
        c = F.conv_transpose2d(i_r + i_i, w_r + w_i, b_sum, self.stride, self.padding, output_padding, self.groups, self.dilation)

        # Real part: a - b; imaginary part: c - a - b (the real bias cancels).
        return torch.complex(a - b, c - a - b)
class ComplexConvTranspose3dn(nn.ConvTranspose3d):
    """3-D transposed convolution for complex tensors via the Gauss trick.

    The complex product is evaluated with three real-valued transposed
    convolutions instead of four:

        a = conv_t(x_r, W_r, b_r)
        b = conv_t(x_i, W_i)
        c = conv_t(x_r + x_i, W_r + W_i, b_r + b_i)
        out = (a - b) + i * (c - a - b)

    The layer's ``weight`` (and ``bias``, when present) are read through
    ``.real`` / ``.imag``, so they are expected to be complex parameters.
    """

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: Complex tensor; its ``.real`` and ``.imag`` parts are
                convolved separately.
            output_size: Optional requested output size, forwarded to
                ``_output_padding`` to derive the output padding.

        Returns:
            A complex tensor built with ``torch.complex`` from the real and
            imaginary parts of the result.

        Raises:
            ValueError: If ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r = input.real
        i_i = input.imag
        w_r = self.weight.real
        w_i = self.weight.imag

        # A layer created with bias=False has self.bias set to None; reading
        # .real/.imag on it would raise, so run the convolutions bias-free.
        if self.bias is None:
            b_r = None
            b_sum = None
        else:
            b_r = self.bias.real
            b_sum = b_r + self.bias.imag

        a = F.conv_transpose3d(i_r, w_r, b_r, self.stride, self.padding, output_padding, self.groups, self.dilation)
        b = F.conv_transpose3d(i_i, w_i, None, self.stride, self.padding, output_padding, self.groups, self.dilation)
        c = F.conv_transpose3d(i_r + i_i, w_r + w_i, b_sum, self.stride, self.padding, output_padding, self.groups, self.dilation)

        # Real part: a - b; imaginary part: c - a - b (the real bias cancels).
        return torch.complex(a - b, c - a - b)
The text was updated successfully, but these errors were encountered:
H320
changed the title
ComplexConvTranspose
ComplexConvTransposeNd
Dec 16, 2022
H320
changed the title
ComplexConvTransposeNd
ComplexConv[Transpose]Nd
Dec 16, 2022
H320
changed the title
ComplexConv[Transpose]Nd
ComplexConvTransposeNd
Dec 16, 2022
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L812
Complex Convolution
conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
where W, x and b are all complex inputs.
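With the Gauss trick (three real convolutions instead of four), define the intermediate results:
a = conv(Wr, xr, br),
b' = conv(Wi, xi, 0),
c = conv(Wr + Wi, xr + xi, br + bi)
so that conv(W, x, b) = (a - b') + i(c - a - b'), where b' is an intermediate result and is not the bias b.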
The text was updated successfully, but these errors were encountered: