Reformatted patch: the original was whitespace-mangled (all newlines collapsed
onto one line), which makes it unapplyable; the diff structure below is
reconstructed from the surviving +/- markers and the hunk headers, with the
patch payload unchanged. NOTE(review): the added ComplexAvgPool2d mirrors
ComplexMaxPool2d and forwards `dilation` and `return_indices` to
complex_avg_pool2d; torch.nn.functional.avg_pool2d accepts neither parameter,
so confirm complex_avg_pool2d's signature before merging.

diff --git a/complexLayers.py b/complexLayers.py
index 5afd606..e223aab 100755
--- a/complexLayers.py
+++ b/complexLayers.py
@@ -13,7 +13,7 @@
 from torch.nn import Module, Parameter, init
 from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d
 from torch.nn import ConvTranspose2d
-from complexFunctions import complex_relu, complex_max_pool2d
+from complexFunctions import complex_relu, complex_max_pool2d, complex_avg_pool2d
 from complexFunctions import complex_dropout, complex_dropout2d
 
 def apply_complex(fr, fi, input):
@@ -59,6 +59,25 @@ def forward(self,input):
                                 stride = self.stride, padding = self.padding,
                                 dilation = self.dilation, ceil_mode = self.ceil_mode,
                                 return_indices = self.return_indices)
+
+
+class ComplexAvgPool2d(Module):
+
+    def __init__(self,kernel_size, stride= None, padding = 0,
+                 dilation = 1, return_indices = False, ceil_mode = False):
+        super(ComplexAvgPool2d,self).__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.ceil_mode = ceil_mode
+        self.return_indices = return_indices
+
+    def forward(self,input):
+        return complex_avg_pool2d(input,kernel_size = self.kernel_size,
+                                 stride = self.stride, padding = self.padding,
+                                 dilation = self.dilation, ceil_mode = self.ceil_mode,
+                                 return_indices = self.return_indices)
 
 
 class ComplexReLU(Module):