Skip to content

Commit

Permalink
Fix Python IndexError of case13: paddle.static.nn.batch_norm (#50011)
Browse files Browse the repository at this point in the history
* add channel_num check for paddle.static.nn.batch_norm

* fix bugs

* fix bugs
  • Loading branch information
ccsuzzh authored Jan 31, 2023
1 parent 0d32f55 commit da11aa4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
4 changes: 4 additions & 0 deletions python/paddle/fluid/tests/unittests/test_batch_norm_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,10 @@ def test_errors(self):
)
self.assertRaises(TypeError, paddle.static.nn.batch_norm, x2)

# the first dimension of input for batch_norm must between [2d, 5d].
x3 = paddle.static.data("", shape=[0], dtype="float32")
self.assertRaises(ValueError, paddle.static.nn.batch_norm, x3)


class TestDygraphBatchNormAPIError(unittest.TestCase):
def test_errors(self):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_fold_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_errors(self):
with program_guard(Program(), Program()):

def test_input_shape():
# input_shpae must be 3-D
# input_shape must be 3-D
x = paddle.randn(shape=[2, 3, 6, 7], dtype="float32")
out = fold(x, output_sizes=[2, 3], kernel_sizes=[2, 2])

Expand Down
6 changes: 6 additions & 0 deletions python/paddle/static/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2731,6 +2731,12 @@ def batch_norm(
dtype = core.VarDesc.VarType.FP32

input_shape = input.shape
if len(input.shape) < 2 or len(input.shape) > 5:
raise ValueError(
'expected 2D or 3D or 4D or 5D input (got {}D input, input shape is: {})'.format(
len(input.shape), input_shape
)
)
if data_layout == 'NCHW':
channel_num = input_shape[1]
else:
Expand Down

0 comments on commit da11aa4

Please sign in to comment.