-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathFFTW.py
52 lines (38 loc) · 1.38 KB
/
FFTW.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#
import numpy as np
import sys, os
# pyfftw is an optional accelerator; ``hasfftw`` records whether it imported.
try:
    import pyfftw
except ImportError:
    print('PYFFTW not found, please run "pip install pyfftw" for up to 20x speedup')
    hasfftw = False
else:
    hasfftw = True
class WrapFFTW(object):
    """Reusable pyfftw forward/backward transform pair for a fixed shape.

    Two aligned buffers (``data`` and ``data_k``) and two FFTW plans are
    built once in ``__init__``. Every ``fft``/``ifft`` call copies its input
    into the appropriate buffer and executes the stored plan, so inputs must
    match (or broadcast to) ``shape``.

    Parameters
    ----------
    shape : tuple of int
        Shape of the arrays to be transformed.
    **kwargs
        flags : sequence of str, default ['FFTW_MEASURE']
            Planner flags passed to ``pyfftw.FFTW``.
        threads : int, default 8
            Thread count passed to ``pyfftw.FFTW``.
        axes : tuple of int, default (0, 1)
            Axes to transform over. The default reproduces the previous
            hard-coded behaviour. Pass e.g. ``(0,)`` for 1-D data.
        dtype : str or numpy dtype, default 'complex64'
            Complex dtype of the internal buffers and therefore of the
            returned arrays. Input is cast to it on assignment.

    Notes
    -----
    Calls write into buffers shared by the instance, so one instance should
    not be used from several threads at once.

    NOTE(review): the NumPy fallback class (``WrapFFTW_NUMPY``) transforms
    all axes via ``fftn``. With the default ``axes=(0, 1)`` the two backends
    only agree for 2-D input. Confirm which behaviour callers expect.
    """

    def __init__(self, shape, **kwargs):
        self.shape = shape
        self._flags = kwargs.get('flags', ['FFTW_MEASURE'])
        self._threads = kwargs.get('threads', 8)
        # Optional knobs; the defaults equal the values that used to be hard-coded.
        self._axes = kwargs.get('axes', (0, 1))
        self._dtype = kwargs.get('dtype', 'complex64')
        # 16-byte aligned buffers; the forward plan maps data -> data_k and
        # the backward plan maps data_k -> data.
        self.data = pyfftw.empty_aligned(self.shape, n=16, dtype=self._dtype)
        self.data_k = pyfftw.empty_aligned(self.shape, n=16, dtype=self._dtype)
        self.fft_object = pyfftw.FFTW(self.data, self.data_k,
                                      axes=self._axes, flags=self._flags,
                                      threads=self._threads)
        self.ifft_object = pyfftw.FFTW(self.data_k, self.data,
                                       direction='FFTW_BACKWARD',
                                       axes=self._axes, flags=self._flags,
                                       threads=self._threads)

    def fft(self, inp):
        """Return the forward FFT of ``inp`` as a freshly allocated array.

        ``inp`` is assigned into the shared input buffer (numpy casting and
        broadcasting rules apply). The plan's output buffer is copied so that
        later calls cannot overwrite the returned result.
        """
        # ``[...]`` fills the whole buffer for any rank; ``[:, :]`` required ndim >= 2.
        self.data[...] = inp
        return self.fft_object().copy()

    def ifft(self, inp):
        """Return the backward FFT of ``inp`` as a freshly allocated array.

        NOTE(review): scaling is whatever ``pyfftw.FFTW.__call__`` applies by
        default (believed to be 1/N via ``normalise_idft``, matching
        ``numpy.fft.ifftn``). Confirm for the installed pyfftw version.
        """
        self.data_k[...] = inp
        return self.ifft_object().copy()
class WrapFFTW_NUMPY(object):
    """NumPy-based stand-in with the same interface as ``WrapFFTW``.

    Parameters
    ----------
    shape : tuple of int
        Stored for interface compatibility; the transforms do not use it.
    **kwargs
        axes : sequence of int or None, default None
            Axes passed to ``numpy.fft.fftn``/``ifftn``. ``None`` transforms
            every axis, which is the previous behaviour. Other keywords
            accepted by ``WrapFFTW`` (``flags``, ``threads``, ``dtype``) are
            ignored.

    NOTE(review): this class does not cast results to complex64 as the
    pyfftw path does, and by default it transforms every axis whereas
    ``WrapFFTW`` transforms axes (0, 1). The two backends agree on axes only
    for 2-D input. Confirm whether callers depend on either difference.
    """

    def __init__(self, shape, **kwargs):
        self.shape = shape
        self._axes = kwargs.get('axes', None)

    def fft(self, inp):
        """Return the forward n-D FFT of ``inp`` over the configured axes."""
        return np.fft.fftn(inp, axes=self._axes)

    def ifft(self, inp):
        """Return the inverse n-D FFT of ``inp`` over the configured axes."""
        return np.fft.ifftn(inp, axes=self._axes)
# Keep the public name ``WrapFFTW`` usable in both cases: it is the pyfftw
# class when pyfftw imported, otherwise the NumPy-based replacement.
WrapFFTW = WrapFFTW if hasfftw else WrapFFTW_NUMPY