Skip to content

Commit

Permalink
Changes to make test cases work (gwastro#1172)
Browse files Browse the repository at this point in the history
* Changes to make test cases work

* Remove debugging print statement
  • Loading branch information
spxiwh authored and ahnitz committed Dec 5, 2016
1 parent 545b80a commit 63159d7
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 79 deletions.
1 change: 1 addition & 0 deletions pycbc/fft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from .parser_support import insert_fft_option_group, verify_fft_options, from_cli
from .func_api import fft, ifft
from .class_api import FFT, IFFT
from .backend_support import get_backend_names
45 changes: 24 additions & 21 deletions test/fft_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _test_fft(test_case,inarr,expec,tol):
if hasattr(outarr,'delta_f'):
outarr._delta_f *= 5*tol
with tc.context:
pycbc.fft.fft(inarr,outarr,tc.backends)
pycbc.fft.fft(inarr, outarr)
# First, verify that the input hasn't been overwritten
emsg = 'FFT overwrote input array'
tc.assertEqual(inarr,in_pristine,emsg)
Expand Down Expand Up @@ -168,7 +168,7 @@ def _test_ifft(test_case,inarr,expec,tol):
if hasattr(outarr,'delta_f'):
outarr._delta_f *= 5*tol
with tc.context:
pycbc.fft.ifft(inarr,outarr,tc.backends)
pycbc.fft.ifft(inarr, outarr)
# First, verify that the input hasn't been overwritten
emsg = 'Inverse FFT overwrote input array'
tc.assertEqual(inarr,in_pristine,emsg)
Expand Down Expand Up @@ -204,8 +204,8 @@ def _test_random(test_case,inarr,outarr,tol):
if type(inarr) == pycbc.types.Array:
incopy *= len(inarr)
with tc.context:
pycbc.fft.fft(inarr,outarr,tc.backends)
pycbc.fft.ifft(outarr,inarr,tc.backends)
pycbc.fft.fft(inarr, outarr)
pycbc.fft.ifft(outarr, inarr)
emsg="IFFT(FFT(random)) did not reproduce original array to within tolerance {0}".format(tol)
if isinstance(incopy,ts) or isinstance(incopy,fs):
tc.assertTrue(incopy.almost_equal_norm(inarr,tol=tol,dtol=tol),
Expand All @@ -231,8 +231,8 @@ def _test_random(test_case,inarr,outarr,tol):
if type(outarr) == pycbc.types.Array:
outcopy *= len(inarr)
with tc.context:
pycbc.fft.ifft(outarr,inarr,tc.backends)
pycbc.fft.fft(inarr,outarr,tc.backends)
pycbc.fft.ifft(outarr, inarr)
pycbc.fft.fft(inarr, outarr)
emsg="FFT(IFFT(random)) did not reproduce original array to within tolerance {0}".format(tol)
if isinstance(outcopy,ts) or isinstance(outcopy,fs):
tc.assertTrue(outcopy.almost_equal_norm(outarr,tol=tol,dtol=tol),
Expand All @@ -256,7 +256,7 @@ def _test_lal_tf_fft(test_case,inarr,outarr,tol):
outlal = outarr.lal()
# Calculate the pycbc fft:
with tc.context:
pycbc.fft.fft(inarr,outarr,tc.backends)
pycbc.fft.fft(inarr, outarr)
fwdplan = _fwd_plan_dict[dtype(inarr).type](len(inarr),0)
# Call the lal function directly (see above for dict). Note that
# lal functions want *output* given first.
Expand Down Expand Up @@ -287,7 +287,7 @@ def _test_lal_tf_ifft(test_case,inarr,outarr,tol):
outlal = outarr.lal()
# Calculate the pycbc fft:
with tc.context:
pycbc.fft.ifft(inarr,outarr,tc.backends)
pycbc.fft.ifft(inarr, outarr)
revplan = _rev_plan_dict[dtype(outarr).type](len(outarr),0)
# Call the lal function directly (see above for dict). Note that
# lal functions want *output* given first.
Expand Down Expand Up @@ -318,25 +318,25 @@ def _test_raise_excep_fft(test_case,inarr,outarr,other_args=None):
outzer = pycbc.types.zeros(len(outarr))
# If we give an output array that is wrong only in length, raise ValueError:
out_badlen = outty(pycbc.types.zeros(len(outarr)+1),dtype=outarr.dtype,**other_args)
args = [inarr,out_badlen,tc.backends]
tc.assertRaises(ValueError,pycbc.fft.fft,*args)
args = [inarr, out_badlen]
tc.assertRaises(ValueError, pycbc.fft.fft, *args)
# If we give an output array that has the wrong precision, raise ValueError:
out_badprec = outty(outzer,dtype=_other_prec[dtype(outarr).type],**other_args)
args = [inarr,out_badprec,tc.backends]
args = [inarr, out_badprec]
tc.assertRaises(ValueError,pycbc.fft.fft,*args)
# If we give an output array that has the wrong kind (real or complex) but
# correct precision, then raise a ValueError. This only makes sense if we try
# to do either C2R or R2R.
out_badkind = outty(outzer,dtype=_bad_dtype[dtype(inarr).type],**other_args)
args = [inarr,out_badkind,tc.backends]
args = [inarr, out_badkind]
tc.assertRaises(ValueError,pycbc.fft.fft,*args)
# If we give an output array that isn't a PyCBC type, raise TypeError:
out_badtype = numpy.zeros(len(outarr),dtype=outarr.dtype)
args = [inarr,out_badtype,tc.backends]
args = [inarr, out_badtype]
tc.assertRaises(TypeError,pycbc.fft.fft,*args)
# If we give an input array that isn't a PyCBC type, raise TypeError:
in_badtype = numpy.zeros(len(inarr),dtype=inarr.dtype)
args = [in_badtype,outarr,tc.backends]
args = [in_badtype, outarr]
tc.assertRaises(TypeError,pycbc.fft.fft,*args)

def _test_raise_excep_ifft(test_case,inarr,outarr,other_args=None):
Expand All @@ -356,11 +356,11 @@ def _test_raise_excep_ifft(test_case,inarr,outarr,other_args=None):
outzer = pycbc.types.zeros(len(outarr))
# If we give an output array that is wrong only in length, raise ValueError:
out_badlen = outty(pycbc.types.zeros(len(outarr)+1),dtype=outarr.dtype,**other_args)
args = [inarr,out_badlen,tc.backends]
args = [inarr, out_badlen]
tc.assertRaises(ValueError,pycbc.fft.ifft,*args)
# If we give an output array that has the wrong precision, raise ValueError:
out_badprec = outty(outzer,dtype=_other_prec[dtype(outarr).type],**other_args)
args = [inarr,out_badprec,tc.backends]
args = [inarr,out_badprec]
tc.assertRaises(ValueError,pycbc.fft.ifft,*args)
# If we give an output array that has the wrong kind (real or complex) but
# correct precision, then raise a ValueError. Here we must adjust the kind
Expand All @@ -374,17 +374,20 @@ def _test_raise_excep_ifft(test_case,inarr,outarr,other_args=None):
except KeyError:
delta = new_args.pop('delta_f')
new_args.update({'delta_t' : delta})
in_badkind = type(inarr)(pycbc.types.zeros(len(inarr)),dtype=_bad_dtype[dtype(outarr).type],
in_badkind = type(inarr)(pycbc.types.zeros(len(inarr)),
dtype=_bad_dtype[dtype(outarr).type],
**new_args)
args = [in_badkind,outarr,tc.backends]
tc.assertRaises(ValueError,pycbc.fft.ifft,*args)
args = [in_badkind, outarr]
#pycbc.fft.ifft(in_badkind, outarr)
if str(outarr.dtype) not in ['complex64', 'complex128']:
tc.assertRaises((ValueError, KeyError), pycbc.fft.ifft, *args)
# If we give an output array that isn't a PyCBC type, raise TypeError:
out_badtype = numpy.zeros(len(outarr),dtype=outarr.dtype)
args = [inarr,out_badtype,tc.backends]
args = [inarr,out_badtype]
tc.assertRaises(TypeError,pycbc.fft.ifft,*args)
# If we give an input array that isn't a PyCBC type, raise TypeError:
in_badtype = numpy.zeros(len(inarr),dtype=inarr.dtype)
args = [in_badtype,outarr,tc.backends]
args = [in_badtype,outarr]
tc.assertRaises(TypeError,pycbc.fft.ifft,*args)

# The following isn't a helper function, called by several test functions, but it only applies
Expand Down
52 changes: 30 additions & 22 deletions test/test_autochisq.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def setUp(self):
hpt = TimeSeries(h, self.del_t)
self.htilde = make_frequency_series(hpt)



# generate sin-gaussian signal
time = np.arange(0, len(hp))*self.del_t
Expand All @@ -103,28 +102,31 @@ def test_chirp(self):
flow = self.low_frequency_cutoff

with _context:
hautocor, hacorfr, hnrm = matched_filter_core(self.htilde, self.htilde, psd=psd, \
hautocor, hacorfr, hnrm = matched_filter_core(self.htilde, self.htilde, psd=psd, \
low_frequency_cutoff=flow, high_frequency_cutoff=self.fmax)

snr, cor, nrm = matched_filter_core(self.htilde, sig_tilde, psd=psd, \
hautocor = hautocor * float(np.real(1./hautocor[0]))

snr, cor, nrm = matched_filter_core(self.htilde, sig_tilde, psd=psd, \
low_frequency_cutoff=flow, high_frequency_cutoff=self.fmax)

hacor = Array(hautocor.real(), copy=True)
hacor = Array(hautocor, copy=True)

indx = Array(np.array([352250, 352256, 352260]))
indx = np.array([352250, 352256, 352260])

snr = snr*nrm


with _context:
dof, achi_list = autochisq_from_precomputed(snr, cor, hacor, stride=3, num_points=20, \
indices=indx)
dof, achisq, indices= \
autochisq_from_precomputed(snr, snr, hacor, indx, stride=3,
num_points=20)

obt_snr = achi_list[1,1]
obt_ach = achi_list[1,2]
obt_snr = abs(snr[indices[1]])
obt_ach = achisq[1]
self.assertTrue(obt_snr > 10.0 and obt_snr < 12.0)
self.assertTrue(obt_ach < 1.e-3)
self.assertTrue(achi_list[0,2] > 20.0)
self.assertTrue(achi_list[2,2] > 20.0)
self.assertTrue(obt_ach < 2.e-3)
self.assertTrue(achisq[0] > 20.0)
self.assertTrue(achisq[2] > 20.0)


#with _context:
Expand All @@ -150,6 +152,8 @@ def test_sg(self):
with _context:
hautocor, hacorfr, hnrm = matched_filter_core(self.htilde, self.htilde, psd=psd, \
low_frequency_cutoff=flow, high_frequency_cutoff=self.fmax)
hautocor = hautocor * float(np.real(1./hautocor[0]))


snr, cor, nrm = matched_filter_core(self.htilde, sig_tilde, psd=psd, \
low_frequency_cutoff=flow, high_frequency_cutoff=self.fmax)
Expand All @@ -158,18 +162,22 @@ def test_sg(self):

hacor = Array(hautocor.real(), copy=True)

indx = Array(np.array([301440, 301450, 301460]))
indx = np.array([301440, 301450, 301460])

snr = snr*nrm

with _context:
dof, achi_list = autochisq_from_precomputed(snr, cor, hacor, stride=3, num_points=20, \
indices=indx)
obt_snr = achi_list[1,1]
obt_ach = achi_list[1,2]
self.assertTrue(obt_snr > 12.0 and obt_snr < 15.0)
self.assertTrue(obt_ach > 6.8e3)
self.assertTrue(achi_list[0,2] > 6.8e3)
self.assertTrue(achi_list[2,2] > 6.8e3)
dof, achisq, indices= \
autochisq_from_precomputed(snr, snr, hacor, indx, stride=3,
num_points=20)

obt_snr = abs(snr[indices[1]])
obt_ach = achisq[1]
self.assertTrue(obt_snr > 12.0 and obt_snr < 15.0)
self.assertTrue(obt_ach > 6.8e3)
self.assertTrue(achisq[0] > 6.8e3)
self.assertTrue(achisq[2] > 6.8e3)


# with _context:
# dof, achi_list = autochisq(self.htilde, sig_tilde, psd, stride=3, num_points=20, \
Expand Down
2 changes: 1 addition & 1 deletion test/test_fft_unthreaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

# Get our list of backends:

backends = pycbc.fft._all_backends_list
backends = pycbc.fft.get_backend_names()

FFTTestClasses = []
for backend in backends:
Expand Down
2 changes: 1 addition & 1 deletion test/test_fftw_openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

# See if we can get set the FFTW backend to 'openmp'; if not, say so and exit.

if 'fftw' in pycbc.fft._all_backends_list:
if 'fftw' in pycbc.fft.get_backend_names():
import pycbc.fft.fftw
try:
pycbc.fft.fftw.set_threads_backend('openmp')
Expand Down
2 changes: 1 addition & 1 deletion test/test_fftw_pthreads.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

# See if we can get set the FFTW backend to 'pthreads'; if not, say so and exit.

if 'fftw' in pycbc.fft._all_backends_list:
if 'fftw' in pycbc.fft.get_backend_names():
import pycbc.fft.fftw
try:
pycbc.fft.fftw.set_threads_backend('pthreads')
Expand Down
Loading

0 comments on commit 63159d7

Please sign in to comment.