diff --git a/python/paddle/fft.py b/python/paddle/fft.py index 74b3bb23fc683..b8939d08b588b 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -1275,6 +1275,8 @@ def fftfreq(n, d=1.0, dtype=None, name=None): # Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True, # [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001]) """ + if d * n == 0: + raise ValueError("d or n should not be 0.") dtype = paddle.framework.get_default_dtype() val = 1.0 / (n * d) diff --git a/python/paddle/fluid/tests/unittests/fft/test_fft.py b/python/paddle/fluid/tests/unittests/fft/test_fft.py index 1b42badd1481a..8a57fa81b5729 100644 --- a/python/paddle/fluid/tests/unittests/fft/test_fft.py +++ b/python/paddle/fluid/tests/unittests/fft/test_fft.py @@ -1823,6 +1823,23 @@ def test_fftfreq(self): ) +@place(DEVICES) +@parameterize( + (TEST_CASE_NAME, 'n', 'd', 'dtype', 'expect_exception'), + [ + ('test_with_0_0', 0, 0, 'float32', ValueError), + ('test_with_n_0', 20, 0, 'float32', ValueError), + ('test_with_0_d', 0, 20, 'float32', ValueError), + ], +) +class TestFftFreqException(unittest.TestCase): + def test_fftfreq2(self): + """Test fftfreq with d = 0""" + with paddle.fluid.dygraph.guard(self.place): + with self.assertRaises(self.expect_exception): + paddle.fft.fftfreq(self.n, self.d, self.dtype) + + @place(DEVICES) @parameterize( (TEST_CASE_NAME, 'n', 'd', 'dtype'),