From c8270b101cc9ca2fc68a613879bb6148f93a0748 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Sun, 22 Jan 2023 16:36:09 +0000 Subject: [PATCH 1/3] add channel_num check for paddle.static.nn.batch_norm --- .../paddle/fluid/tests/unittests/test_batch_norm_op.py | 7 +++++++ python/paddle/static/nn/common.py | 10 ++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index c2a6c468e5c8f1..5b5caf91de7bac 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -768,6 +768,13 @@ def test_errors(self): ) self.assertRaises(TypeError, paddle.static.nn.batch_norm, x2) + def test_channel_num(): + paddle.enable_static() + input = paddle.static.data("", shape=[0], dtype="float32") + paddle.static.nn.batch_norm(input) + + # the channel_num should not be null in NCHW. + self.assertRaises(ValueError, test_channel_num) class TestDygraphBatchNormAPIError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 53954f49f343a5..1b1d657dfbe078 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -515,7 +515,10 @@ def data_norm( input_shape = input.shape if data_layout == 'NCHW': - channel_num = input_shape[1] + if len(input_shape) > 1: + channel_num = input_shape[1] + else: + raise ValueError("The channel_num shoule not be null in this data layout:" + data_layout) else: if data_layout == 'NHWC': channel_num = input_shape[-1] @@ -2727,7 +2730,10 @@ def batch_norm( input_shape = input.shape if data_layout == 'NCHW': - channel_num = input_shape[1] + if len(input_shape) > 1: + channel_num = input_shape[1] + else: + raise ValueError("The channel_num shoule not be null in this data layout:" + data_layout) else: if data_layout == 'NHWC': channel_num = input_shape[-1] From 9bc8b2fa301818cbf0aa4bfadb41c187023d338c Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Mon, 23 Jan 2023 14:31:26 +0000 Subject: [PATCH 2/3] fix bugs --- .../paddle/fluid/tests/unittests/test_batch_norm_op.py | 5 +++-- python/paddle/static/nn/common.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 5b5caf91de7bac..b5e8fcf3b1da35 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -768,13 +768,14 @@ def test_errors(self): ) self.assertRaises(TypeError, paddle.static.nn.batch_norm, x2) - def test_channel_num(): + def test_channel_num_is_null(): paddle.enable_static() input = paddle.static.data("", shape=[0], dtype="float32") paddle.static.nn.batch_norm(input) # the channel_num should not be null in NCHW. - self.assertRaises(ValueError, test_channel_num) + self.assertRaises(ValueError, test_channel_num_is_null) + class TestDygraphBatchNormAPIError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 1b1d657dfbe078..5f714da39996e0 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -515,10 +515,7 @@ def data_norm( input_shape = input.shape if data_layout == 'NCHW': - if len(input_shape) > 1: - channel_num = input_shape[1] - else: - raise ValueError("The channel_num shoule not be null in this data layout:" + data_layout) + channel_num = input_shape[1] else: if data_layout == 'NHWC': channel_num = input_shape[-1] @@ -2733,7 +2730,10 @@ def batch_norm( if len(input_shape) > 1: channel_num = input_shape[1] else: - raise ValueError("The channel_num shoule not be null in this data layout:" + data_layout) + raise ValueError( + "The channel_num shoule not be null in this data layout:" + + data_layout + ) else: if data_layout == 'NHWC': channel_num = input_shape[-1] From 95fc76957d88164b0ad0a5e5e74ca2bba43d636c Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Wed, 25 Jan 2023 16:05:04 +0000 Subject: [PATCH 3/3] fix bugs --- .../fluid/tests/unittests/test_batch_norm_op.py | 10 +++------- .../paddle/fluid/tests/unittests/test_fold_op.py | 2 +- python/paddle/static/nn/common.py | 14 +++++++------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index b5e8fcf3b1da35..02171db3fca75a 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -768,13 +768,9 @@ def test_errors(self): ) self.assertRaises(TypeError, paddle.static.nn.batch_norm, x2) - def test_channel_num_is_null(): - paddle.enable_static() - input = paddle.static.data("", shape=[0], dtype="float32") - paddle.static.nn.batch_norm(input) - - # the channel_num should not be null in NCHW. - self.assertRaises(ValueError, test_channel_num_is_null) + # the first dimension of input for batch_norm must between [2d, 5d]. + x3 = paddle.static.data("", shape=[0], dtype="float32") + self.assertRaises(ValueError, paddle.static.nn.batch_norm, x3) class TestDygraphBatchNormAPIError(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fold_op.py b/python/paddle/fluid/tests/unittests/test_fold_op.py index 1f3193fa1fd494..a86161cc450230 100644 --- a/python/paddle/fluid/tests/unittests/test_fold_op.py +++ b/python/paddle/fluid/tests/unittests/test_fold_op.py @@ -179,7 +179,7 @@ def test_errors(self): with program_guard(Program(), Program()): def test_input_shape(): - # input_shpae must be 3-D + # input_shape must be 3-D x = paddle.randn(shape=[2, 3, 6, 7], dtype="float32") out = fold(x, output_sizes=[2, 3], kernel_sizes=[2, 2]) diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 5f714da39996e0..777d40aae82075 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -2726,14 +2726,14 @@ def batch_norm( dtype = core.VarDesc.VarType.FP32 input_shape = input.shape - if data_layout == 'NCHW': - if len(input_shape) > 1: - channel_num = input_shape[1] - else: - raise ValueError( - "The channel_num shoule not be null in this data layout:" - + data_layout + if len(input.shape) < 2 or len(input.shape) > 5: + raise ValueError( + 'expected 2D or 3D or 4D or 5D input (got {}D input, input shape is: {})'.format( + len(input.shape), input_shape ) + ) + if data_layout == 'NCHW': + channel_num = input_shape[1] else: if data_layout == 'NHWC': channel_num = input_shape[-1]