Skip to content

Commit

Permalink
Add ComplexAvgPool2d
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Feb 15, 2021
1 parent 22137ab commit 8ccdd2d
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion complexLayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.nn import Module, Parameter, init
from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d
from torch.nn import ConvTranspose2d
from complexFunctions import complex_relu, complex_max_pool2d
from complexFunctions import complex_relu, complex_max_pool2d, complex_avg_pool2d
from complexFunctions import complex_dropout, complex_dropout2d

def apply_complex(fr, fi, input):
Expand Down Expand Up @@ -59,6 +59,25 @@ def forward(self,input):
stride = self.stride, padding = self.padding,
dilation = self.dilation, ceil_mode = self.ceil_mode,
return_indices = self.return_indices)


class ComplexAvgPool2d(Module):

def __init__(self,kernel_size, stride= None, padding = 0,
dilation = 1, return_indices = False, ceil_mode = False):
super(ComplexAvgPool2d,self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.ceil_mode = ceil_mode
self.return_indices = return_indices

def forward(self,input):
return complex_avg_pool2d(input,kernel_size = self.kernel_size,
stride = self.stride, padding = self.padding,
dilation = self.dilation, ceil_mode = self.ceil_mode,
return_indices = self.return_indices)

class ComplexReLU(Module):

Expand Down

0 comments on commit 8ccdd2d

Please sign in to comment.