Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Arthur Corstanje committed Nov 16, 2023
1 parent b2b134b commit e4fb8d0
Showing 1 changed file with 38 additions and 101 deletions.
139 changes: 38 additions & 101 deletions NuRadioReco/modules/LOFAR/stationRFIFilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def FindRFI_LOFAR(
target_trace_length=65536,
rfi_cleaning_trace_length=8192,
flagged_antenna_ids=[],
num_dbl_z=1000,
):
"""
Reads the given LOFAR TBB H5 file and returns an array of dirty channels.
Expand Down Expand Up @@ -122,94 +123,20 @@ def FindRFI_LOFAR(
# FIXME: -- should be fixed, needs explicit testing -- what if one bad antenna in Station? Currently FindRFI crashes

logger.info(f"Running find RFI with {num_blocks} blocks")
output = FindRFI(tbb_file, rfi_cleaning_trace_length, 0, int(num_blocks),
lower_frequency=lower_frequency_bound, upper_frequency=upper_frequency_bound,
num_dbl_z=1000, flagged_antenna_ids=flagged_antenna_ids)

tbb_file.close_file()

# Calculate the required output variables
avg_power_spectrum = np.sum(output["ave_spectrum_magnitude"], axis=0)
avg_antenna_power = output["ave_spectrum_magnitude"]
cleaned_power = output["cleaned_power"]
antenna_names = output["antenna_names"]
dirty_channels = output["dirty_channels"][0]
multiplied_channels = []
multiplied_blocks = target_trace_length // rfi_cleaning_trace_length
for ch in dirty_channels: # TODO: could be done more efficiently...?
this_channel_list = np.arange(multiplied_blocks * ch, multiplied_blocks * ch + multiplied_blocks, 1)
this_channel_list = list(this_channel_list)
multiplied_channels.extend(this_channel_list)

dirty_channels = np.sort(np.array(multiplied_channels))
dirty_channels_block_size = target_trace_length

return avg_power_spectrum, dirty_channels, dirty_channels_block_size, antenna_names, avg_antenna_power, cleaned_power

############

# TODO: add description of algorithm to notes section
def FindRFI(
tbb_file,
block_size,
initial_block,
num_blocks,
max_blocks=None,
lower_frequency=10e6*units.Hz,
upper_frequency=90e6*units.Hz,
num_dbl_z=100,
flagged_antenna_ids = []
):
"""
A function to find RFI in data using phase-variance.
Parameters
----------
tbb_file : MultiFile_Dal1
File pointer to TBB encompassing the data for one station.
block_size : int
The size of a single block of which to calculate the spectrum. Should be around 65536 (2^16).
initial_block : int
Initial block should be such that there is no lightning in the max_blocks number of blocks.
num_blocks : int
Number of blocks to use for phase-variance detection. Should be at least 20.
max_blocks : int, default=None
Sometimes a block needs to be skipped, so max_blocks sets the maximum number of blocks
(after the initial block) that may be used to find num_blocks good blocks. If max_blocks is None (default),
it is set to `num_blocks`.
lower_frequency : float, default=10e6
The lower end of the frequency band to consider.
upper_frequency : float, default=90e6
The higher end of the frequency band to consider.
num_dbl_z : int, default=100
The number of double zeros allowed in a block; if there are too many, there could be data loss.
Returns
-------
output_dict : dict
A dictionary with the following key-value pairs:
* "ave_spectrum_magnitude": array that contains the average of the magnitude of the frequency spectrum
* "ave_spectrum_phase": array containing the average of the phase of the frequency spectrum
* "phase_variance": array containing the phase variance of each frequency channel
* "dirty_channels": array of indices indicating the channels that are contaminated with RFI
Notes
-----
The algorithm compares the phase stability of each frequency channel between a reference antenna and every
other antenna in the station. If the phase is stable, this indicates a constant source contaminating the data.
More information can be found in Section 3.2.2 of `this paper <https://arxiv.org/pdf/1311.1399.pdf>`_ .
"""
initial_block = 0
num_blocks = int(num_blocks) # make sure the number of blocks is an integer (as it is used as a shape parameter)
max_blocks = num_blocks

if max_blocks is None:
max_blocks = num_blocks

window_function = half_hann_window(block_size, 0.1)
window_function = half_hann_window(rfi_cleaning_trace_length, 0.1)
antenna_ids = tbb_file.get_antenna_names()
antenna_ids = [id for id in antenna_ids if id not in flagged_antenna_ids]
num_antennas = len(antenna_ids)

# step one: find which blocks are good, and find average power
oneAnt_data = np.zeros(block_size, dtype=np.double) # initialize at zero
oneAnt_data = np.zeros(rfi_cleaning_trace_length, dtype=np.double) # initialize at zero

logger.info("finding good blocks")
blocks_good = np.zeros((num_antennas, max_blocks), dtype=bool)
Expand All @@ -221,13 +148,13 @@ def FindRFI(
for ant_i in range(num_antennas):
try:
oneAnt_data[:] = tbb_file.get_data(
block_size * block, block_size, antenna_ID=antenna_ids[ant_i]
rfi_cleaning_trace_length * block, rfi_cleaning_trace_length, antenna_ID=antenna_ids[ant_i]
)
except: # TODO: more specific exception
logger.warning('Could not read data for antenna %s block %d' % (antenna_ids[ant_i], block_i))
# proceed with zeros in the block
#oneAnt_data[:] = tbb_file.get_data(
# block_size * block, block_size, antenna_index=ant_i
# rfi_cleaning_trace_length * block, rfi_cleaning_trace_length, antenna_index=ant_i
#)
if (
num_double_zeros(oneAnt_data) < num_dbl_z
Expand Down Expand Up @@ -286,13 +213,13 @@ def FindRFI(

# Process data
num_processed_blocks = np.zeros(num_antennas, dtype=int)
frequencies = np.fft.fftfreq(block_size, 1.0 / tbb_file.get_sample_frequency())
frequencies = np.fft.fftfreq(rfi_cleaning_trace_length, 1.0 / tbb_file.get_sample_frequency())
frequencies *= units.Hz
lower_frequency_index = np.searchsorted(
frequencies[: int(len(frequencies) / 2)], lower_frequency
frequencies[: int(len(frequencies) / 2)], lower_frequency_bound
)
upper_frequency_index = np.searchsorted(
frequencies[: int(len(frequencies) / 2)], upper_frequency
frequencies[: int(len(frequencies) / 2)], upper_frequency_bound
)

phase_mean = np.zeros(
Expand All @@ -316,7 +243,7 @@ def FindRFI(
):
continue
oneAnt_data[:] = tbb_file.get_data(
block_size * block, block_size, antenna_index=ant_i
rfi_cleaning_trace_length * block, rfi_cleaning_trace_length, antenna_index=ant_i
)

# Window the data
Expand Down Expand Up @@ -389,7 +316,7 @@ def FindRFI(

# Extend dirty channels by some size, in order to account for shoulders
extend_dirty_channels = np.zeros(N, dtype=bool)
half_flagwidth = int(block_size / 8192)
half_flagwidth = int(rfi_cleaning_trace_length / 8192)
for i in dirty_channels:
flag_min = i - half_flagwidth
flag_max = i + half_flagwidth
Expand All @@ -405,25 +332,35 @@ def FindRFI(
# plot and return data
frequencies = frequencies[lower_frequency_index:upper_frequency_index]

output_dict = {
"ave_spectrum_magnitude": spectrum_mean,
"ave_spectrum_phase": np.angle(phase_mean, deg=False),
"phase_variance": phase_stability,
"dirty_channels": dirty_channels + lower_frequency_index,
"blocksize": block_size,
}
ave_spectrum_magnitude = spectrum_mean

cleaned_spectrum = np.array(spectrum_mean)
cleaned_spectrum[:, dirty_channels] = 0.0
output_dict["cleaned_spectrum_magnitude"] = cleaned_spectrum
output_dict["cleaned_power"] = 2 * np.sum(cleaned_spectrum, axis=1)

output_dict["antenna_names"] = antenna_ids # where flagged antennas have been removed from (above)
output_dict["timestamp"] = tbb_file.get_timestamp()
output_dict["antennas_good"] = antenna_is_good
output_dict["frequency"] = frequencies
tbb_file.close_file()

# Calculate the required output variables
avg_power_spectrum = np.sum(ave_spectrum_magnitude, axis=0)
avg_antenna_power = ave_spectrum_magnitude

cleaned_power = 2 * np.sum(cleaned_spectrum, axis=1)
antenna_names = antenna_ids
dirty_channels += lower_frequency_index # = output["dirty_channels"][0]
dirty_channels = dirty_channels[0]
multiplied_channels = []
multiplied_blocks = target_trace_length // rfi_cleaning_trace_length
for ch in dirty_channels: # TODO: could be done more efficiently...?
this_channel_list = np.arange(multiplied_blocks * ch, multiplied_blocks * ch + multiplied_blocks, 1)
this_channel_list = list(this_channel_list)
multiplied_channels.extend(this_channel_list)

dirty_channels = np.sort(np.array(multiplied_channels))
dirty_channels_block_size = target_trace_length

return avg_power_spectrum, dirty_channels, dirty_channels_block_size, antenna_names, avg_antenna_power, cleaned_power


return output_dict
# TODO: add description of algorithm to notes section


# TODO: -- stationRFI filter works on station selection made earlier at event read-in -- make stationRFIFilter take keyword to only process certain stations?
Expand All @@ -436,7 +373,7 @@ class stationRFIFilter:
**Note**: currently the class uses hardcoded values for LOFAR, this needs to be improved later.
"""
def __init__(self):
self.logger = logging.getLogger('NuRadioReco.channelRFIFilter')
self.logger = logger # logging.getLogger('NuRadioReco.channelRFIFilter')

self.__rfi_trace_length = None

Expand Down

0 comments on commit e4fb8d0

Please sign in to comment.