From 2bbc800f90232762420d1eca61ad678de5a5f25b Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 29 Jul 2023 07:00:49 +0000 Subject: [PATCH 1/2] add BF16 datatype for argsort --- paddle/phi/kernels/gpu/argsort_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/argsort_kernel.cu | 8 +++- python/paddle/tensor/search.py | 1 + test/legacy_test/test_argsort_op.py | 47 +++++++++++++++++++ 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu index b8d9df64c23efb..695044c095735e 100644 --- a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu @@ -222,4 +222,5 @@ PD_REGISTER_KERNEL(argsort_grad, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu index 5102594f98d1e0..3b502a567a499f 100644 --- a/paddle/phi/kernels/gpu/argsort_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_kernel.cu @@ -61,6 +61,11 @@ namespace cub { template <> struct NumericTraits : BaseTraits {}; + +template <> +struct NumericTraits + : BaseTraits { +}; } // namespace cub #endif @@ -328,6 +333,7 @@ PD_REGISTER_KERNEL(argsort, double, int, int64_t, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->OutputAt(1).SetDataType(phi::DataType::INT64); } diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 6172c0247554d6..5f62f2cc539d36 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -109,6 +109,7 @@ def argsort(x, axis=-1, descending=False, name=None): 'int32', 'int64', 'uint8', + 'uint16', ], 'argsort', ) diff --git a/test/legacy_test/test_argsort_op.py b/test/legacy_test/test_argsort_op.py index ec6db2f6651e99..8dd751a9a926f2 100644 --- a/test/legacy_test/test_argsort_op.py +++ b/test/legacy_test/test_argsort_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from eager_op_test import OpTest, convert_float_to_uint16 import paddle from paddle import fluid @@ -513,5 +514,51 @@ def test_fp16(self): out = exe.run(feed={'x': x_np}, fetch_list=[out]) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the bfloat16", +) +class TestArgsortBF16OP(OpTest): + def setUp(self): + self.init() + self.op_type = "argsort" + self.python_api = paddle.argsort + self.public_python_api = paddle.argsort + self.dtype = np.uint16 + self.descending = False + self.attrs = {"axis": self.axis, "descending": self.descending} + self.x = np.random.rand(*self.input_shape).astype(np.float32) + self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis).astype( + np.float32 + ) + self.indices = np.argsort( + self.x, kind='heapsort', axis=self.axis + ).astype(np.float32) + self.inputs = {'X': convert_float_to_uint16(self.x)} + self.outputs = { + 'Out': convert_float_to_uint16(self.sorted_x), + "Indices": convert_float_to_uint16(self.indices), + } + + def init(self): + self.input_shape = [ + 1000, + ] + self.axis = 0 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + ['X'], + 'Out', + ) + + if __name__ == "__main__": unittest.main() From c2b316501bc5ade49425290e26b829163e994048 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sun, 30 Jul 2023 06:21:22 +0000 Subject: [PATCH 2/2] fix bugs --- test/legacy_test/test_argsort_op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/legacy_test/test_argsort_op.py b/test/legacy_test/test_argsort_op.py index 8dd751a9a926f2..3a5dff216af4ab 100644 --- a/test/legacy_test/test_argsort_op.py +++ b/test/legacy_test/test_argsort_op.py @@ -525,6 +525,9 @@ def setUp(self): self.op_type = "argsort" self.python_api = paddle.argsort self.public_python_api = paddle.argsort + self.python_out_sig = [ + "Out" + ] # python out sig is customized output signature. self.dtype = np.uint16 self.descending = False self.attrs = {"axis": self.axis, "descending": self.descending}