Skip to content

Commit

Permalink
Get pycbc_inspiral and chisq.py to vary number of chisq bins
Browse files Browse the repository at this point in the history
Allow input of a general formula in terms of template parameters
which will be rounded down to an integer
Also rework the interface from pycbc_inspiral to event manager to
make it easier for field names and types to line up correctly

Conflicts:
	pycbc/vetoes/chisq.py
  • Loading branch information
tdent committed Feb 10, 2015
1 parent 75af3c4 commit c71e3ff
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 51 deletions.
60 changes: 42 additions & 18 deletions bin/pycbc_inspiral
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,18 @@ parser.add_argument("--cluster-method", choices=["template", "window"],
help="FIXME: ADD")
parser.add_argument("--cluster-window", type=float, default = -1,
help="Length of clustering window in seconds")

parser.add_argument("--maximization-interval", type=float, default=0,
help="Maximize triggers over the template bank (ms)")

parser.add_argument("--bank-veto-bank-file", type=str, help="FIXME: ADD")

parser.add_argument("--chisq-bins", type=int, default=0, help="FIXME: ADD")
parser.add_argument("--chisq-bins", default=0, help=
"Number of frequency bins to use for power chisq. Specify"
" an integer for a constant number of bins, or a function "
"of template params. Math functions are allowed, ex. "
"'10./math.sqrt((params.mass1+params.mass2)/100.)'. "
"Non-integer values will be rounded down.")
parser.add_argument("--chisq-threshold", type=float, default=0,
help="FIXME: ADD")
parser.add_argument("--chisq-delta", type=float, default=0, help="FIXME: ADD")
Expand Down Expand Up @@ -107,7 +113,7 @@ strain_segments = strain.StrainSegments.from_cli(opt, gwstrain)
with ctx:
fft.from_cli(opt)

flow = opt.low_frequency_cutoff
flow = opt.low_frequency_cutoff
flen = strain_segments.freq_len
tlen = strain_segments.time_len
delta_f = strain_segments.delta_f
Expand Down Expand Up @@ -137,9 +143,27 @@ with ctx:
for seg in segments:
seg /= psd

names = ['time_index', 'snr', 'chisq', 'bank_chisq', 'cont_chisq']
# storage for values and types to be passed to event manager
out_types = {
'time_index' : int,
'snr' : complex64,
'chisq' : float32,
'chisq_dof' : int,
'bank_chisq' : float32,
'cont_chisq' : float32
}
out_vals = {
'time_index' : None,
'snr' : None,
'chisq' : None,
'chisq_dof' : None,
'bank_chisq' : None,
'cont_chisq' : None
}
names = sorted(out_vals.keys())

event_mgr = events.EventManager(opt, names,
[int, complex64, float32, float32, float32, float32], psd=psd)
[out_types[n] for n in names], psd=psd)

logging.info("Read in template bank")
bank = waveform.FilterBank(opt.bank_file, opt.approximant, flen, delta_f,
Expand All @@ -158,32 +182,32 @@ with ctx:
(t_num + 1, len(bank), s_num + 1, len(segments)))

snr, norm, corr, idx, snrv = \
matched_filter.matched_filter_and_cluster(template, stilde, cluster_window)
matched_filter.matched_filter_and_cluster(template, stilde,
cluster_window)

if not len(idx):
continue

bank_chisqv = bank_chisq.values(template, s_num, snr, norm,
idx+stilde.analyze.start)
power_chisqv = power_chisq.values(corr, snr, norm, psd, snrv,
idx+stilde.analyze.start,
template, bank, flow)
out_vals['bank_chisq'] = bank_chisq.values(template, s_num, snr,
norm, idx+stilde.analyze.start)
out_vals['chisq'], out_vals['chisq_dof'] = \
power_chisq.values(corr, snr, norm, psd, snrv,
idx+stilde.analyze.start, template, bank, flow)

snrv *= norm

if opt.autochi_number_points:
auto_chisqv = autochisq.values(snr*norm, corr, hautocorr=None,
indices=idx+stilde.analyze.start, template=template, psd=psd,
low_frequency_cutoff=flow)
auto_chisqv = autochisq.values(snr*norm, corr, hautocorr=None,
indices=idx+stilde.analyze.start, template=template,
psd=psd, low_frequency_cutoff=flow)
out_vals['cont_chisq'] = auto_chisqv[:,2]

idx += stilde.cumulative_index

if (opt.autochi_number_points):
vals = [idx, snrv, power_chisqv, bank_chisqv, auto_chisqv[:, 2]]
else:
vals = [idx, snrv, power_chisqv, bank_chisqv, None]
out_vals['time_index'] = idx
out_vals['snr'] = snrv

event_mgr.add_template_events(names, vals)
event_mgr.add_template_events(names, [out_vals[n] for n in names])

event_mgr.cluster_template_events("time_index", "snr", cluster_window)
event_mgr.finalize_template_events()
Expand Down
39 changes: 29 additions & 10 deletions pycbc/events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def findchirp_cluster_over_window(times, values, window_length):
return indices[0:j+1]

def newsnr(snr, reduced_x2):
"""Calculate the re-weighted SNR statistic known as NewSNR from given
SNR and reduced chi-squared values. See http://arxiv.org/abs/1208.3491
for a definition of NewSNR.
"""Calculate the re-weighted SNR statistic ('newSNR') from given SNR and
reduced chi-squared values. See http://arxiv.org/abs/1208.3491 for
definition.
"""
newsnr = numpy.array(snr, ndmin=1)
reduced_x2 = numpy.array(reduced_x2, ndmin=1)
Expand Down Expand Up @@ -159,17 +159,18 @@ def __init__(self, opt, column, column_types, **kwds):
def chisq_threshold(self, value, num_bins, delta=0):
remove = []
for i, event in enumerate(self.events):
xi = event['chisq'] / (num_bins + delta * event['snr'].conj() * event['snr'])
xi = event['chisq'] / (event['chisq_dof'] + delta * event['snr'].conj() * event['snr'])
if xi > value:
remove.append(i)
self.events = numpy.delete(self.events, remove)

def newsnr_threshold(self, threshold):
"Remove events with newsnr smaller than given threshold"
""" Remove events with newsnr smaller than given threshold
"""
if not self.opt.chisq_bins:
raise RuntimeError('Chi-square test must be enabled in order to use newsnr threshold')
x2_dof = 2 * self.opt.chisq_bins - 2
remove = [i for i, e in enumerate(self.events) if newsnr(abs(e['snr']), e['chisq'] / x2_dof) < threshold]
remove = [i for i, e in enumerate(self.events) if \
newsnr(abs(e['snr']), e['chisq'] / (2 * e['chisq_dof'] - 2)) < threshold]
self.events = numpy.delete(self.events, remove)

def maximize_over_bank(self, tcolumn, column, window):
Expand Down Expand Up @@ -217,7 +218,15 @@ def maximize_over_bank(self, tcolumn, column, window):

def add_template_events(self, columns, vectors):
""" Add a vector indexed """
new_events = numpy.zeros(len(vectors[0]), dtype=self.event_dtype)
# initialize with zeros - since vectors can be None, look for the
# first one that isn't
new_events = None
for v in vectors:
if v is not None:
new_events = numpy.zeros(len(v), dtype=self.event_dtype)
break
# they shouldn't all be None
assert new_events is not None
new_events['template_id'] = self.template_index
for c, v in zip(columns, vectors):
if v is not None:
Expand Down Expand Up @@ -294,10 +303,20 @@ def write_events(self, outname):
if self.opt.chisq_bins != 0:
# FIXME: This is *not* the dof!!!
# but is needed for later programs not to fail
row.chisq_dof = self.opt.chisq_bins
row.chisq = event['chisq']
try:
# if the options specify an integer, use it and check
# that the value can be cast as an int without changing
# numerically
row.chisq_dof = int(self.opt.chisq_bins)
assert row.chisq_dof == self.opt.chisq_bins
except:
# fail through: copy the value from the trigger
row.chisq_dof = event['chisq_dof']
#row.chisq_dof = self.opt.chisq_bins
row.chisq = event['chisq']

if hasattr(self.opt, 'bank_veto_bank_file') and self.opt.bank_veto_bank_file:
# EXPLAINME - is this a hard-coding? Certainly looks like one
row.bank_chisq_dof = 10
row.bank_chisq = event['bank_chisq']
else:
Expand Down
63 changes: 42 additions & 21 deletions pycbc/vetoes/chisq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#
# =============================================================================
#
import numpy
import logging
import numpy, logging, math, pycbc.fft

import pycbc.fft

from pycbc.types import Array, zeros, real_same_precision_as, TimeSeries, complex_same_precision_as, FrequencySeries
Expand Down Expand Up @@ -85,11 +85,11 @@ def power_chisq_bins(htilde, num_bins, psd, low_frequency_cutoff=None,
bins: List of ints
A list of the edges of the chisq bins is returned.
"""
sigma_vec = sigmasq_series(htilde, psd, low_frequency_cutoff,
high_frequency_cutoff).numpy()
kmin, kmax = get_cutoff_indices(low_frequency_cutoff,
sigma_vec = sigmasq_series(htilde, psd, low_frequency_cutoff,
high_frequency_cutoff).numpy()
kmin, kmax = get_cutoff_indices(low_frequency_cutoff,
high_frequency_cutoff,
htilde.delta_f,
htilde.delta_f,
(len(htilde)-1)*2)
return power_chisq_bins_from_sigmasq_series(sigma_vec, num_bins, kmin, kmax)

Expand Down Expand Up @@ -130,7 +130,7 @@ def power_chisq_at_points_from_precomputed(corr, snr, snr_norm, bins, indices):
An array containing only the chisq at the selected points.
"""
logging.info('doing fast point chisq')
snr = Array(snr, copy=False)
snr = Array(snr, copy=False)
num_bins = len(bins) - 1

chisq = shift_sum(corr, indices, bins)
Expand All @@ -154,7 +154,7 @@ def power_chisq_from_precomputed(corr, snr, snr_norm, bins, indices=None):
snr: TimeSeries
The unnormalized snr time series.
snr_norm:
The snr normalization factor (true snr = snr * snr_norm) EXPLAINME - define 'true snr'? refer to FindChirp ?
The snr normalization factor (true snr = snr * snr_norm) EXPLAINME - define 'true snr'?
bins: List of integers
The edges of the chisq bins.
indices: {Array, None}, optional
Expand Down Expand Up @@ -208,12 +208,12 @@ def power_chisq_from_precomputed(corr, snr, snr_norm, bins, indices=None):
return chisq
else:
return TimeSeries(chisq, delta_t=snr.delta_t, epoch=snr.start_time, copy=False)

def fastest_power_chisq_at_points(corr, snr, snrv, snr_norm, bins, indices):
"""Calculate the chisq values for only selected points.
This function looks at the number of points to be evaluated and selects
the fastest method (FFT, or direct time shift and sum). In either case,
This function looks at the number of points to be evaluated and selects
the fastest method (FFT, or direct time shift and sum). In either case,
only the selected points are returned.
Parameters
Expand All @@ -238,7 +238,7 @@ def fastest_power_chisq_at_points(corr, snr, snrv, snr_norm, bins, indices):
import pycbc.scheme
if isinstance(pycbc.scheme.mgr.state, pycbc.scheme.CPUScheme):
# We don't have that many points so do the direct time shift.
return power_chisq_at_points_from_precomputed(corr, snrv,
return power_chisq_at_points_from_precomputed(corr, snrv,
snr_norm, bins, indices)
else:
# We have a lot of points so it is faster to use the fourier transform
Expand All @@ -256,15 +256,15 @@ def power_chisq(template, data, num_bins, psd,
data: FrequencySeries or TimeSeries
A time or frequency series that contains the data to filter. The length
must be commensurate with the template.
EXPLAINME - does this mean 'the same as' or something else?
--- EXPLAINME - does this mean 'the same as' or something else?
num_bins: int
The number of bins in the chisq. Note that the dof goes as 2*num_bins-2.
psd: FrequencySeries
The psd of the data.
low_frequency_cutoff: {None, float}, optional
The low frequency cutoff to apply.
The low frequency cutoff for the filter
high_frequency_cutoff: {None, float}, optional
The high frequency cutoff to apply.
The high frequency cutoff for the filter
Returns
-------
Expand All @@ -278,7 +278,7 @@ def power_chisq(template, data, num_bins, psd,
high_frequency_cutoff)
corra = zeros((len(htilde)-1)*2, dtype=htilde.dtype)
total_snr, corr, tnorm = matched_filter_core(htilde, stilde, psd,
low_frequency_cutoff, high_frequency_cutoff, aa
low_frequency_cutoff, high_frequency_cutoff,
corr_out=corra)

return power_chisq_from_precomputed(corr, total_snr, tnorm, bins)
Expand All @@ -288,25 +288,45 @@ class SingleDetPowerChisq(object):
"""Class that handles precomputation and memory management for efficiently
running the power chisq in a single detector inspiral analysis.
"""
def __init__(self, num_bins):
if num_bins > 0:
def __init__(self, num_bins=0):
if not (num_bins == 0):
self.do = True
self.column_name = "chisq"
self.table_dof_name = "chisq_dof"
self.dof = num_bins * 2 - 2
self.num_bins = num_bins

self._num_bins = num_bins
# internal values to store parameters between computations
self._num_bins = None
self._bins = None
self._template = None
else:
self.do = False

@staticmethod
def parse_option(row, arg):
safe_dict = {}
safe_dict.update(row.__dict__)
safe_dict.update(math.__dict__)
return eval(arg, {"__builtins__":None}, safe_dict)

def values(self, corr, snr, snr_norm, psd, snrv, indices, template, bank,
low_frequency_cutoff):
"""FIXME: Document this function more fully
Returns
-------
chisq: Array
Chisq values, one for each sample index
chisq_dof: Array
Numbers of frequency bins corresponding to the chisq values
"""
if self.do:
# Compute the chisq bins if we haven't already
# Only recompute the bins if the template changes
if self._template is None or self._template != template:
# determine number of bins by parsing the option
self._num_bins = int(self.parse_option(template, self.num_bins))
if bank.sigmasq_vec is not None:
logging.info("...Calculating fast power chisq bins")
kmin = int(low_frequency_cutoff / corr.delta_f)
Expand All @@ -321,4 +341,5 @@ def values(self, corr, snr, snr_norm, psd, snrv, indices, template, bank,
self._bins = bins

logging.info("...Doing power chisq")
return fastest_power_chisq_at_points(corr, snr, snrv, snr_norm, self._bins, indices)
return (fastest_power_chisq_at_points(corr, snr, snrv, snr_norm,
self._bins, indices), self._num_bins * numpy.ones_like(indices))
2 changes: 1 addition & 1 deletion test/long/pycbc_inspiral/pycbc_inspiral
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ with ctx:

logging.info("%s points above threshold" % str(len(idx)))
bank_chisqv = bank_chisq.values(template, s_num, snr, norm, idx+snr_start)
power_chisqv = power_chisq.values(corr, snr, norm, psd, idx+snr_start, template, bank, opt.low_frequency_cutoff)
power_chisqv, chisqdof = power_chisq.values(corr, snr, norm, psd, idx+snr_start, template, bank, opt.low_frequency_cutoff)

snrv *= norm
idx += (cumulative_index - seg_width_idx + (snr_start - opt.segment_start_pad * opt.sample_rate))
Expand Down
2 changes: 1 addition & 1 deletion test/long/pycbc_inspiral_triggers/pycbc_inspiral
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ with ctx:

logging.info("%s points above threshold" % str(len(idx)))
bank_chisqv = bank_chisq.values(template, s_num, snr, norm, idx+snr_start)
power_chisqv = power_chisq.values(corr, snr, norm, psd, idx+snr_start, template, bank, opt.low_frequency_cutoff)
power_chisqv, chisqdof = power_chisq.values(corr, snr, norm, psd, idx+snr_start, template, bank, opt.low_frequency_cutoff)

snrv *= norm
idx += (cumulative_index - seg_width_idx + (snr_start - opt.segment_start_pad * opt.sample_rate))
Expand Down

0 comments on commit c71e3ff

Please sign in to comment.