Skip to content

Commit

Permalink
Towards issue #119
Browse files Browse the repository at this point in the history
Add some placeholders for coherence sorting

- frequency_band_helpers.py
  - add some placeholder methods
  - especially Spectrogram class -- a candidate to replace stft_obj in main flow

- kernel_dataset.py:
  - Factor update_survey_metadata out of initialize_dataframe_for_processing into its own method for better readability

- coherence_weights.py
  - experimental code for Jackknife coherence weights

- xarray_helpers.py
  - Deprecate unused cast_3d_stft_to_2d_observations method from xarray_helpers
  - move notes from this method into stack_fcs()
  • Loading branch information
kkappler committed Jan 23, 2024
1 parent c5a5c7d commit 7d7fc8a
Show file tree
Hide file tree
Showing 5 changed files with 506 additions and 138 deletions.
21 changes: 20 additions & 1 deletion aurora/pipelines/transfer_function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,26 @@ def drop_nans(X, Y, RR):


def stack_fcs(X, Y, RR):
"""Reshape 2D arrays of frequency and time to 1D"""
"""
Reshape 2D arrays of frequency and time to 1D
Context:
When the data for a frequency band are extracted from the Spectrogram, each channel
is a 2D array, one axis is time (the time of the window that was FFT-ed) and the
other axis is frequency. However if we make no distinction between the harmonics
(bins) within a band in regression, then all the FCs for each channel can be
put into a 1D array. This method performs that reshaping (ravelling) operation.
**It is not important how we unravel the FCs but it is important that
we use the same scheme for X and Y.
TODO: Make this take a list and return a list rather than X,Y,RR
TODO: Decorate this with @dataset_or_dataarray
if isinstance(X, xr.Dataset):
tmp = X.to_array("channel")
tmp = tmp.stack()
or similar
"""
X = X.stack(observation=("frequency", "time"))
Y = Y.stack(observation=("frequency", "time"))
if RR is not None:
Expand Down
109 changes: 108 additions & 1 deletion aurora/time_series/frequency_band_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def get_band_for_tf_estimate(band, dec_level_config, local_stft_obj, remote_stft_obj):
"""
Get data for TF estimation for a particular band.
Returns spectrograms X, Y, RR for harmonics within the given band
Parameters
----------
Expand Down Expand Up @@ -163,3 +163,110 @@ def check_time_axes_synched(X, Y):
logger.warning("WARNING - NAN Handling could fail if X,Y dont share time axes")
raise Exception
return


def get_band_for_coherence_sorting(
    frequency_band,
    dec_level_config,
    local_stft_obj,
    remote_stft_obj,
    widening_rule="min3",
):
    """
    Just like get_band_for_tf_estimate, but here we enforce some rules so that the band is not one FC wide
    - it is possible that this method will get merged with get_band_for_tf_estimate
    - this is a placeholder until the appropriate rules are sorted out.

    Parameters
    ----------
    frequency_band : mt_metadata.transfer_functions.processing.aurora.FrequencyBands
        object with lower_bound and upper_bound to tell stft object which
        subarray to return
    dec_level_config : mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel
        information about the input and output channels needed for TF
        estimation problem setup
    local_stft_obj : xarray.core.dataset.Dataset or None
        Time series of Fourier coefficients for the station whose TF is to be
        estimated
    remote_stft_obj : xarray.core.dataset.Dataset or None
        Time series of Fourier coefficients for the remote reference station
    widening_rule : str
        Rule used to widen a band holding only a single harmonic.
        "min3" pads the band by one frequency step (df) on each side.

    Returns
    -------
    X, Y, RR : xarray.core.dataset.Dataset or None
        data structures as local_stft_object and remote_stft_object, but
        restricted only to input_channels, output_channels,
        reference_channels and also the frequency axes are restricted to
        being within the frequency band given as an input argument.

    Raises
    ------
    NotImplementedError
        If the band must be widened and widening_rule is not recognized.
    """
    # Work on a copy so the caller's band object is not mutated by widening.
    band = frequency_band.copy()
    logger.info(
        f"Processing band {band.center_period:.6f}s ({1./band.center_period:.6f}Hz)"
    )
    stft = Spectrogram(local_stft_obj)
    if stft.num_harmonics_in_band(band) == 1:
        logger.warning("Cant evaluate coherence with only 1 harmonic")
        logger.info(f"Widening band according to {widening_rule} rule")
        if widening_rule == "min3":
            # Pad by one frequency step on each side of the band.
            band.frequency_min -= stft.df
            band.frequency_max += stft.df
        else:
            msg = f"Widening rule {widening_rule} not recognized"
            logger.error(msg)
            raise NotImplementedError(msg)
    # proceed as in the standard TF-estimation band extraction
    return get_band_for_tf_estimate(
        band, dec_level_config, local_stft_obj, remote_stft_obj
    )


def cross_spectra(X, Y):
    """
    Compute the cross spectrum of two spectral arrays: conj(X) * Y.

    Parameters
    ----------
    X, Y : array-like of complex values supporting .conj() (e.g. numpy
        arrays or xarray objects) with broadcast-compatible shapes.

    Returns
    -------
    Element-wise product of the complex conjugate of X with Y, in the same
    container type as the inputs.
    """
    return X.conj() * Y


class Spectrogram(object):
    """
    Class to contain methods for STFT objects (time series of Fourier
    coefficients) -- a candidate to replace raw stft_obj in the main flow.

    TODO: Add support for cross powers
    TODO: Add OLS Z-estimates
    TODO: Add Sims/Vozoff Z-estimates
    """

    def __init__(self, dataset):
        """
        Parameters
        ----------
        dataset : STFT object with a "frequency" coordinate
            # NOTE(review): presumably an xarray Dataset -- confirm against callers
        """
        self._dataset = dataset
        self._df = None  # lazy cache for the frequency step

    @property
    def dataset(self):
        """Return the underlying STFT dataset."""
        return self._dataset

    @property
    def df(self):
        """
        Return the frequency step (delta-f) of the spectrogram.

        Computed from the first two values of the frequency axis and cached;
        assumes a uniformly spaced frequency axis with at least two harmonics.
        """
        if self._df is None:
            frequency_axis = self.dataset.frequency
            self._df = frequency_axis.data[1] - frequency_axis.data[0]
        return self._df

    def num_harmonics_in_band(self, frequency_band, epsilon=1e-7):
        """
        Count the harmonics of the frequency axis falling within a band.

        Parameters
        ----------
        frequency_band : object with lower_bound and upper_bound attributes
            The band to count harmonics within.
        epsilon : float
            Tolerance so band-edge harmonics are not lost to round-off.

        Returns
        -------
        num_harmonics : int
            Number of frequencies in [lower_bound - epsilon, upper_bound + epsilon].
        """
        cond1 = self._dataset.frequency >= frequency_band.lower_bound - epsilon
        cond2 = self._dataset.frequency <= frequency_band.upper_bound + epsilon
        num_harmonics = (cond1 & cond2).data.sum()
        return num_harmonics

    def extract_band(self, frequency_band):
        """Return the sub-spectrogram restricted to frequency_band
        (delegates to the module-level extract_band function)."""
        return extract_band(frequency_band, self.dataset, epsilon=1e-7)

    def cross_powers(self, ch1, ch2, band=None):
        """Placeholder for cross-power computation -- not yet implemented."""
        pass
35 changes: 1 addition & 34 deletions aurora/time_series/xarray_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def handle_nan(X, Y, RR, drop_dim=""):
Initial use case is for Fourier coefficients, but could be more general.
Idea is to merge X,Y,RR together, and then call dropna. We have to be careful
with merging becuase there can be namespace clashes in the channel labels.
with merging because there can be namespace clashes in the channel labels.
Currently handling this by relabelling the remote reference channels from for
example "hx"--> "remote_hx", "hy"-->"remote_hy". If needed we could add "local" to
local the other channels in X, Y.
Expand Down Expand Up @@ -84,36 +84,3 @@ def handle_nan(X, Y, RR, drop_dim=""):
RR = RR.rename(data_var_rm_label_mapper)

return X, Y, RR


def cast_3d_stft_to_2d_observations(XY):
    """
    DEPRECATED / unused -- see stack_fcs in transfer_function_helpers.

    When the data for a frequency band are extracted from the STFT and
    passed to RegressionEstimator they have a typical STFT structure:
    One axis is time (the time of the window that was FFT-ed) and the
    other axis is frequency. However we make no distinction between the
    harmonics (or bins) within a band. We need to gather all the FCs for
    each channel into a 1D array.
    This method performs that reshaping (ravelling) operation.
    *It is not important how we unravel the FCs but it is important that
    we use the same scheme for X and Y.
    2021-08-25: Modified this method to use xarray's stack() method.

    Parameters
    ----------
    XY: either X or Y of the regression nomenclature. Should be an
        xarray.Dataset already splitted on channel; a DataArray is also
        accepted and stacked directly.

    Returns
    -------
    xarray.DataArray
        Array whose "frequency" and "time" dims are stacked into a single
        "observation" dim.
    """
    if isinstance(XY, xr.Dataset):
        XY = XY.to_array("channel")
    # Bug fix: the stack() call previously referenced a name bound only
    # inside the isinstance branch, raising UnboundLocalError for
    # DataArray input. Stack the (possibly converted) input directly.
    return XY.stack(observation=("frequency", "time"))
58 changes: 40 additions & 18 deletions aurora/transfer_function/kernel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,35 @@ def sample_rate(self):
sample_rate = self.df.sample_rate.unique()[0]
return sample_rate

def update_survey_metadata(self, i, row, run_ts):
    """
    Wrangle survey_metadata from a run time series into the kernel_dataset.
    This needs to be passed to TF before exporting data.
    Factored out of initialize_dataframe_for_processing for readability.

    Parameters
    ----------
    i: integer.
        This would be the index of row, if we were sure that the dataframe was cleanly indexed
    row: row of kernel_dataset dataframe corresponding to a survey-station-run.
    run_ts: run time series object carrying survey/station/run metadata.

    Returns
    -------
    None; mutates self.survey_metadata in place.

    Raises
    ------
    NotImplementedError
        If more than one survey accumulates in self.survey_metadata.
    """
    incoming = run_ts.survey_metadata
    survey_id = incoming.id
    if i == 0:
        # First row seeds the survey entry wholesale.
        self.survey_metadata[survey_id] = incoming
    elif i > 0:
        survey = self.survey_metadata[survey_id]
        if row.station_id in survey.stations.keys():
            # Known station: append this run's metadata to it.
            survey.stations[row.station_id].add_run(run_ts.run_metadata)
        else:
            # New station for this survey.
            survey.add_station(run_ts.station_metadata)
    # Multi-survey processing is not supported yet.
    if len(self.survey_metadata.keys()) > 1:
        raise NotImplementedError

def initialize_dataframe_for_processing(self, mth5_objs):
"""
Adds extra columns needed for processing, populates them with mth5 objects,
Expand All @@ -363,12 +392,16 @@ def initialize_dataframe_for_processing(self, mth5_objs):
Note #1: When assigning xarrays to dataframe cells, df dislikes xr.Dataset,
so we convert to xr.DataArray before packing df
Note #2: [OPTIMIZATION] By accesssing the run_ts and packing the "run_dataarray" column of the df with it, we
Note #2: [OPTIMIZATION] By accessing the run_ts and packing the "run_dataarray" column of the df, we
perform a non-lazy operation, and essentially forcing the entire decimation_level=0 dataset to be
loaded into memory. Seeking a lazy method to handle this maybe worthwhile. For example, using
a df.apply() approach to initialize only ione row at a time would allow us to gernerate the FCs one
a df.apply() approach to initialize only one row at a time would allow us to generate the FCs one
row at a time and never ingest more than one run of data at a time ...
Note #3: Uncommenting the continue statement here is desireable, will speed things up, but
is not yet tested. A nice test would be to have two stations, some runs having FCs built
and others not having FCs built. What goes wrong is in update_survey_metadata.
Need a way to get the survey metadata from a run, not a run_ts if possible
Parameters
----------
Expand All @@ -384,27 +417,16 @@ def initialize_dataframe_for_processing(self, mth5_objs):
self.df["run_reference"].at[i] = run_obj.hdf5_group.ref

if row.fc:
msg = f"row {row} already has fcs prescribed by processing confg "
msg += "-- skipping time series initialzation"
msg = f"row {row} already has fcs prescribed by processing config"
msg += "-- skipping time series initialisation"
logger.info(msg)
# continue
# see Note #3
# continue
# the line below is not lazy, See Note #2
run_ts = run_obj.to_runts(start=row.start, end=row.end)
self.df["run_dataarray"].at[i] = run_ts.dataset.to_array("channel")

# wrangle survey_metadata into kernel_dataset
survey_id = run_ts.survey_metadata.id
if i == 0:
self.survey_metadata[survey_id] = run_ts.survey_metadata
elif i > 0:
if row.station_id in self.survey_metadata[survey_id].stations.keys():
self.survey_metadata[survey_id].stations[row.station_id].add_run(
run_ts.run_metadata
)
else:
self.survey_metadata[survey_id].add_station(run_ts.station_metadata)
if len(self.survey_metadata.keys()) > 1:
raise NotImplementedError
self.update_survey_metadata(i, row, run_ts)

logger.info("Dataset dataframe initialized successfully")

Expand Down
Loading

0 comments on commit 7d7fc8a

Please sign in to comment.