Skip to content

Commit

Permalink
fix stackoverflow case13 gather (#50243)
Browse files Browse the repository at this point in the history
  • Loading branch information
enkilee authored Feb 10, 2023
1 parent fb228c4 commit bf80664
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
21 changes: 21 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,27 @@ void GatherInferMeta(const MetaTensor& x,

auto input_dim = x.dims();
auto axis_v = axis.to<int>();
if (axis_v < 0) axis_v += input_dim.size();

PADDLE_ENFORCE_GE(
axis_v,
(0 - input_dim.size()),
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [%d, %d]. But received Attr(axis) = %d.",
-input_dim.size(),
input_dim.size() - 1,
axis_v));
PADDLE_ENFORCE_LT(
axis_v,
input_dim.size(),
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [%d, %d]. But received Attr(axis) = %d.",
-input_dim.size(),
input_dim.size() - 1,
axis_v));

if (index_dims.size() == 0) {
// 0D index will decrease the dimension
if (input_dim.size() == 1) {
Expand Down
23 changes: 23 additions & 0 deletions python/paddle/fluid/tests/unittests/test_gather_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,29 @@ def test_index_type():

self.assertRaises(TypeError, test_index_type)

def test_error3(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):

shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='int32', name='x')
axis = paddle.fluid.data(shape=[1], dtype='int32', name='axis')
index = paddle.fluid.data(shape=shape, dtype='int32', name='index')
index_float = paddle.fluid.data(
shape=shape, dtype='float32', name='index_float'
)

def test_axis_minsize():
paddle.gather(x, index, axis=-1)

self.assertRaises(ValueError, test_axis_minsize)

def test_axis_maxsize():
paddle.gather(x, index, axis=512)

self.assertRaises(ValueError, test_axis_maxsize)


class TestCheckOutType(unittest.TestCase):
def test_out_type(self):
Expand Down

0 comments on commit bf80664

Please sign in to comment.