Skip to content

Commit

Permalink
[Divide by 0 Error] add pinv check (PaddlePaddle#49951)
Browse files Browse the repository at this point in the history
* add pinv check

* add unitest

* update unitest

* roll back

* fix not call stupid bug

* use context
  • Loading branch information
DrRyanHuang authored and pangengzheng committed Feb 2, 2023
1 parent eb40e7d commit 7ccc7dc
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
8 changes: 8 additions & 0 deletions paddle/phi/kernels/cpu/svd_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ void SvdKernel(const Context& dev_ctx,
// int k = std::min(rows, cols);
// int col_u = full ? rows : k;
// int col_v = full ? cols : k;
PADDLE_ENFORCE_LT(
0,
rows,
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(
0,
cols,
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
int batches = numel / (rows * cols);
auto* U_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
auto* VH_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/kernels/gpu/svd_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,15 @@ void SvdKernel(const Context& dev_ctx,
int m = dims[rank - 2];
int n = dims[rank - 1];

PADDLE_ENFORCE_LT(
0,
m,
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(
0,
n,
errors::InvalidArgument("The col of Input(X) should be greater than 0."));

auto* u_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
auto* vh_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
auto* s_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
Expand Down
22 changes: 22 additions & 0 deletions python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,5 +280,27 @@ def init_config(self):
self.hermitian = True


class TestDivByZero(unittest.TestCase):
def pinv_zero_input_static(self):

paddle.enable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.linalg.pinv(x)

def pinv_zero_input_dynamic(self):

paddle.disable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.linalg.pinv(x)

def test_div_by_zero(self):

with self.assertRaises(ValueError):
self.pinv_zero_input_dynamic()
self.pinv_zero_input_static()


if __name__ == '__main__':
unittest.main()

0 comments on commit 7ccc7dc

Please sign in to comment.