Skip to content

Commit

Permalink
【complex op】No.25 add complex support for fold (#56914)
Browse files Browse the repository at this point in the history
* [complex] add complex support for fold

* fix ut
  • Loading branch information
BeingGod authored Sep 11, 2023
1 parent 7fe2442 commit 5919c7a
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 10 deletions.
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/fold_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fold_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
fold_grad, CPU, ALL_LAYOUT, phi::FoldGradKernel, float, double) {}
PD_REGISTER_KERNEL(fold_grad,
CPU,
ALL_LAYOUT,
phi::FoldGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
9 changes: 8 additions & 1 deletion paddle/phi/kernels/cpu/fold_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fold_kernel_impl.h"

PD_REGISTER_KERNEL(fold, CPU, ALL_LAYOUT, phi::FoldKernel, float, double) {}
PD_REGISTER_KERNEL(fold,
CPU,
ALL_LAYOUT,
phi::FoldKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
24 changes: 24 additions & 0 deletions paddle/phi/kernels/funcs/im2col.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,24 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
phi::dtype::complex<float>>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
phi::dtype::complex<double>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
phi::dtype::complex<float>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::CPUContext,
phi::dtype::complex<double>>;

/*
* im = [input_channels, input_height, input_width]
Expand Down Expand Up @@ -331,11 +343,23 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
phi::dtype::complex<float>>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
phi::dtype::complex<double>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
phi::dtype::complex<float>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::CPUContext,
phi::dtype::complex<double>>;
} // namespace funcs
} // namespace phi
24 changes: 24 additions & 0 deletions paddle/phi/kernels/funcs/im2col.cu
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,12 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::complex<float>>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::complex<double>>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::float16>;
Expand All @@ -322,6 +328,12 @@ template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::complex<float>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::complex<double>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::float16>;
Expand Down Expand Up @@ -573,6 +585,12 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::complex<float>>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::complex<double>>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::float16>;
Expand All @@ -585,6 +603,12 @@ template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::complex<float>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::complex<double>>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::float16>;
Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/fold_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fold_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
fold_grad, GPU, ALL_LAYOUT, phi::FoldGradKernel, float, double) {}
PD_REGISTER_KERNEL(fold_grad,
GPU,
ALL_LAYOUT,
phi::FoldGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
9 changes: 8 additions & 1 deletion paddle/phi/kernels/gpu/fold_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fold_kernel_impl.h"

PD_REGISTER_KERNEL(fold, GPU, ALL_LAYOUT, phi::FoldKernel, float, double) {}
PD_REGISTER_KERNEL(fold,
GPU,
ALL_LAYOUT,
phi::FoldKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
6 changes: 4 additions & 2 deletions python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2280,7 +2280,7 @@ def fold(
Parameters:
x(Tensor): 3-D Tensor, input tensor of format [N, C, L],
data type can be float32 or float64
data type can be float32, float64, complex64 or complex128
output_sizes(int|list|tuple): The size of the output, should be [output_size_h, output_size_w]
or an integer o treated as [o, o].
kernel_sizes(int|list|tuple): The size of convolution kernel, should be [k_h, k_w]
Expand Down Expand Up @@ -2325,7 +2325,9 @@ def fold(

helper = LayerHelper("fold", **locals())

check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'fold')
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'complex64', 'complex128'], 'fold'
)

assert len(x.shape) == 3, "input should be the format of [N, C, L]"

Expand Down
23 changes: 21 additions & 2 deletions test/legacy_test/test_fold_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@ def init_data(self):
self.dilations = [1, 1]
self.output_sizes = [4, 5]
input_shape = [self.batch_size, self.input_channels, self.length]
self.x = np.random.rand(*input_shape).astype(np.float64)
self.x = np.random.rand(*input_shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.x = (
np.random.uniform(-1, 1, input_shape)
+ 1j * np.random.uniform(-1, 1, input_shape)
).astype(self.dtype)

def init_dtype(self):
self.dtype = np.float64

def calc_fold(self):
output_shape = [0] * 4
Expand Down Expand Up @@ -75,7 +83,7 @@ def calc_fold(self):
)
+ 1
)
output = np.zeros(output_shape).astype(np.float64)
output = np.zeros(output_shape).astype(self.dtype)
# ------------- calculate output ------------- #
for b in range(output_shape[0]):
for c in range(self.input_channels):
Expand Down Expand Up @@ -106,6 +114,7 @@ def calc_fold(self):
self.outputs = output

def set_data(self):
self.init_dtype()
self.init_data()
self.calc_fold()
self.inputs = {'X': OpTest.np_dtype_to_base_dtype(self.x)}
Expand All @@ -130,6 +139,16 @@ def test_check_grad(self):
self.check_grad(['X'], 'Y')


class TestFold_Complex64(TestFoldOp):
    # Variant of TestFoldOp that runs the fold op with complex64 input.
    # init_data in the base class detects the complex dtype and builds x as
    # uniform(-1, 1) real + 1j * uniform(-1, 1) imaginary parts.
    def init_dtype(self):
        self.dtype = np.complex64


class TestFold_Complex128(TestFoldOp):
    # Variant of TestFoldOp that runs the fold op with complex128 input.
    # Reuses all forward/grad checks from the base class; only the dtype
    # used by init_data/calc_fold changes.
    def init_dtype(self):
        self.dtype = np.complex128


class TestFoldshape(TestFoldOp):
def init_data(self):
self.batch_size = 8
Expand Down

0 comments on commit 5919c7a

Please sign in to comment.