diff --git a/aurora/pipelines/fourier_coefficients.py b/aurora/pipelines/fourier_coefficients.py deleted file mode 100644 index 8a4ec355..00000000 --- a/aurora/pipelines/fourier_coefficients.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -Supporting codes for building the FC level of the mth5 - -Here are the parameters that are defined via the mt_metadata fourier coefficients structures: - -"bands", -"decimation.anti_alias_filter": "default", -"decimation.factor": 4.0, -"decimation.level": 2, -"decimation.method": "default", -"decimation.sample_rate": 0.0625, -"stft.per_window_detrend_type": "linear", -"stft.prewhitening_type": "first difference", -"stft.window.clock_zero_type": "ignore", -"stft.window.num_samples": 128, -"stft.window.overlap": 32, -"stft.window.type": "boxcar" - -Creating the decimations config requires a decision about decimation factors and the number of levels. -We have been getting this from the EMTF band setup file by default. It is desirable to continue supporting this, -however, note that the EMTF band setup is really about a time series operation, and not about making STFTs. - -For the record, here is the legacy decimation config from EMTF, a.k.a. decset.cfg: -``` -4 0 # of decimation level, & decimation offset -128 32. 1 0 0 7 4 32 1 -1.0 -128 32. 4 0 0 7 4 32 4 -.2154 .1911 .1307 .0705 -128 32. 4 0 0 7 4 32 4 -.2154 .1911 .1307 .0705 -128 32. 4 0 0 7 4 32 4 -.2154 .1911 .1307 .0705 -``` - -This essentially corresponds to a "Decimations Group" which is a list of decimations. -Related to the generation of FCs is the ARMA prewhitening (Issue #60) which was controlled in -EMTF with pwset.cfg -4 5 # of decimation level, # of channels -3 3 3 3 3 -3 3 3 3 3 -3 3 3 3 3 -3 3 3 3 3 - -Note 1: Assumes application of cascading decimation, and that the -decimated data will be accessed from the previous decimation level. - -Note 2: We can encounter cases where some runs can be decimated and others can not. -We need a way to handle this. For example, a short run may not yield any data from a -later decimation level. An attempt to handle this has been made in TF Kernel by -adding a is_valid_dataset column, associated with each run-decimation level pair. - -Note 3: This point in the loop marks the interface between _generation_ of the FCs and - their _usage_. 
In future the code above this comment would be pushed into - create_fourier_coefficients() and the code below this would access those FCs and - execute compute_transfer_function() - - -""" - -# ============================================================================= -# Imports -# ============================================================================= - -import mth5.mth5 -import pathlib -import xarray as xr - -from aurora.pipelines.time_series_helpers import calibrate_stft_obj -from aurora.pipelines.time_series_helpers import prototype_decimate -from aurora.pipelines.time_series_helpers import run_ts_to_stft_scipy -from loguru import logger -from mth5.mth5 import MTH5 -from mth5.utils.helpers import path_or_mth5_object -from mth5.groups.fourier_coefficients import FCDecimationGroup -from mt_metadata.timeseries.time_period import TimePeriod -from mt_metadata.transfer_functions.processing.fourier_coefficients import ( - Decimation as FCDecimation, -) -from typing import List, Optional, Union - -# ============================================================================= -GROUPBY_COLUMNS = ["survey", "station", "sample_rate"] - - -def fc_decimations_creator( - initial_sample_rate: float, - decimation_factors: Optional[list] = None, - max_levels: Optional[int] = 6, - time_period: Optional[TimePeriod] = None, -) -> List[FCDecimation]: - """ - TODO: move this to mt_metadata / replace with mt_metadata method once moved. - - Creates mt_metadata FCDecimation objects that parameterize Fourier coefficient decimation levels. - - Note 1: This does not yet work through the assignment of which bands to keep. Refer to - mt_metadata.transfer_functions.processing.Processing.assign_bands() to see how this was done in the past. - - Parameters - ---------- - initial_sample_rate: float - Sample rate of the "level0" data -- usually the sample rate during field acquisition. - decimation_factors: Optional[list] - The decimation factors that will be applied at each FC decimation level - max_levels: Optional[int] - The maximum number of decimation levels to allow - time_period: Optional[TimePeriod] - Provides the start and end times - - Returns - ------- - fc_decimations: list - Each element of the list is an object of type - mt_metadata.transfer_functions.processing.fourier_coefficients.Decimation, - (a.k.a. FCDecimation). - - The order of the list corresponds to the order of the cascading decimation - - No decimation levels are omitted. - - This could be changed in future by using a dict instead of a list, - - e.g.
decimation_factors = dict(zip(np.arange(max_levels), decimation_factors)) - """ - if not decimation_factors: - # msg = "No decimation factors given, set default values to EMTF default values [1, 4, 4, 4, ..., 4]" - # logger.info(msg) - default_decimation_factor = 4 - decimation_factors = max_levels * [default_decimation_factor] - decimation_factors[0] = 1 - - # See Note 1 - fc_decimations = [] - for i_dec_level, decimation_factor in enumerate(decimation_factors): - fc_dec = FCDecimation() - fc_dec.time_series_decimation.level = i_dec_level - fc_dec.id = f"{i_dec_level}" - fc_dec.decimation.factor = decimation_factor - if i_dec_level == 0: - current_sample_rate = 1.0 * initial_sample_rate - else: - current_sample_rate /= decimation_factor - fc_dec.decimation.sample_rate = current_sample_rate - - if time_period: - if isinstance(time_period, TimePeriod): - fc_dec.time_period = time_period - else: - msg = ( - f"Not sure how to assign time_period with type {type(time_period)}" - ) - logger.info(msg) - raise NotImplementedError(msg) - - fc_decimations.append(fc_dec) - - return fc_decimations - - -@path_or_mth5_object -def add_fcs_to_mth5(m: MTH5, fc_decimations: Optional[Union[str, list]] = None) -> None: - """ - Add Fourier Coefficient Levels to an existing MTH5. - - **Notes:** - - - This module computes the FCs differently than the legacy aurora pipeline. It uses scipy.signal.spectrogram. - There is a test in Aurora to confirm that they are equivalent if we are not using fancy pre-whitening. - - - Nomenclature: "usssr_grouper" is the output of a group-by on unique {survey, station, sample_rate} tuples. - - Parameters - ---------- - m: MTH5 object - The mth5 file, open in append mode. - fc_decimations: Optional[Union[str, list]] - This specifies the scheme to use for decimating the time series when building the FC layer. - None: Just use default (something like four decimation levels, decimated by 4 each time say.) - String: Controlled Vocabulary, values are a work in progress, that will allow custom definition of - the fc_decimations for some common cases. For example, say you have stored already decimated time - series, then you want simply the zeroth decimation for each run, because the decimated time series live - under another run container, and that will get its own FCs. This is experimental. - List: (**UNTESTED**) -- This means that the user thought about the decimations that they want to create and is - passing them explicitly. -- probably will need to be a dictionary actually, since this - would get redefined at each sample rate.
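Editor's aside: since the `List` option above expects explicitly constructed decimations, here is a minimal usage sketch of `fc_decimations_creator` (deleted here, and now provided by `mth5.timeseries.spectre.helpers` per the import changes later in this diff). The printed rates assume the default EMTF-style factors [1, 4, 4, ...] and a 1 Hz level-0 sample rate:

```python
# Minimal sketch: the default factors are [1, 4, 4, ...], so each level after
# the first divides the sample rate by 4. Values assume a 1 Hz level-0 rate.
fc_decimations = fc_decimations_creator(initial_sample_rate=1.0, max_levels=4)
for fc_dec in fc_decimations:
    print(fc_dec.time_series_decimation.level, fc_dec.decimation.sample_rate)
# 0 1.0
# 1 0.25
# 2 0.0625
# 3 0.015625
```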
- - """ - # Group the channel summary by survey, station, sample_rate - channel_summary_df = m.channel_summary.to_dataframe() - usssr_grouper = channel_summary_df.groupby(GROUPBY_COLUMNS) - logger.debug(f"Detected {len(usssr_grouper)} unique station-sample_rate instances") - - # loop over groups - for (survey, station, sample_rate), usssr_group in usssr_grouper: - msg = f"\n\n\nsurvey: {survey}, station: {station}, sample_rate {sample_rate}" - logger.info(msg) - station_obj = m.get_station(station, survey) - run_summary = station_obj.run_summary - - # Get the FC decimation schemes if not provided - if not fc_decimations: - msg = "FC Decimations not supplied, creating defaults on the fly" - logger.info(f"{msg}") - fc_decimations = fc_decimations_creator( - initial_sample_rate=sample_rate, time_period=None - ) - elif isinstance(fc_decimations, str): - if fc_decimations == "degenerate": - fc_decimations = get_degenerate_fc_decimation(sample_rate) - - # TODO: Make this a function that can be done using df.apply() - for i_run_row, run_row in run_summary.iterrows(): - logger.info( - f"survey: {survey}, station: {station}, sample_rate {sample_rate}, i_run_row {i_run_row}" - ) - # Access Run - run_obj = m.from_reference(run_row.hdf5_reference) - - # Set the time period: - # TODO: Should this be over-writing time period if it is already there? - for fc_decimation in fc_decimations: - fc_decimation.time_period = run_obj.metadata.time_period - - # Access the data to Fourier transform - runts = run_obj.to_runts( - start=fc_decimation.time_period.start, - end=fc_decimation.time_period.end, - ) - run_xrds = runts.dataset - - # access container for FCs - fc_group = station_obj.fourier_coefficients_group.add_fc_group( - run_obj.metadata.id - ) - - # If timing corrections were needed they could go here, right before STFT - - for i_dec_level, fc_decimation in enumerate(fc_decimations): - try: - assert i_dec_level == fc_decimation.time_series_decimation.level - except: - msg = "decimation level has unexpected value" - logger.warning(msg) - - if ( - i_dec_level != 0 - ): # TODO: take this number from fc_decimation.time_series_decimation.level - # Apply decimation - ts_decimation = fc_decimation.time_series_decimation - run_xrds = prototype_decimate( - ts_decimation, run_xrds - ) # TODO: replace this with mth5 decimation - - _add_spectrogram_to_mth5( - fc_decimation=fc_decimation, - run_obj=run_obj, - run_xrds=run_xrds, - fc_group=fc_group, - ) - - return - - -def _add_spectrogram_to_mth5( - fc_decimation: FCDecimation, - run_obj: mth5.groups.RunGroup, - run_xrds: xr.Dataset, - fc_group: mth5.groups.FCGroup, -) -> None: - """ - - This function has been factored out of add_fcs_to_mth5. - This is the most atomic level of adding FCs and will be useful as standalone method. - - Parameters - ---------- - fc_decimation : FCDecimation - Metadata about how the decimation level is to be processed - - run_xrds : xarray.core.dataset.Dataset - Time series to be converted to a spectrogram and stored in MTH5. 
- - Returns - ------- - None - - """ - - # check if this decimation level yields a valid spectrogram - if not fc_decimation.is_valid_for_time_series_length(run_xrds.time.shape[0]): - logger.info( - f"Decimation Level {fc_decimation.time_series_decimation.level} invalid, TS of {run_xrds.time.shape[0]} samples too short" - ) - return - - stft_obj = run_ts_to_stft_scipy(fc_decimation, run_xrds) - stft_obj = calibrate_stft_obj(stft_obj, run_obj) - - # Pack FCs into h5 and update metadata - fc_decimation_group: FCDecimationGroup = fc_group.add_decimation_level( - f"{fc_decimation.time_series_decimation.level}", - decimation_level_metadata=fc_decimation, - ) - fc_decimation_group.from_xarray( - stft_obj, fc_decimation_group.metadata.decimation.sample_rate - ) - fc_decimation_group.update_metadata() - fc_group.update_metadata() - - -def get_degenerate_fc_decimation(sample_rate: float) -> list: - """ - - Makes a default fc_decimation list. WIP - This "degenerate" config will only operate on the first decimation level. - This is useful for testing. It could also be used in future on an MTH5 stored time series in decimation - levels already as separate runs. - - Parameters - ---------- - sample_rate: float - The sample rate associated with the time-series to convert to Spectrogram - - Returns - ------- - output: list - List has only one element which is of type FCDecimation, a.k.a. - mt_metadata.transfer_functions.processing.fourier_coefficients.Decimation. - """ - output = fc_decimations_creator( - sample_rate, - decimation_factors=[ - 1, - ], - max_levels=1, - ) - return output - - -# TODO: Delete after mth5 issue #271 is closed and merged. -@path_or_mth5_object -def read_back_fcs(m: Union[MTH5, pathlib.Path, str], mode: str = "r") -> None: - """ - - This is a helper function for tests. It was used as a sanity check while debugging the FC files, and - also is a good example for how to access the data at each level for each channel. - - The Time axis of the FC array will change from level to level, but the frequency axis will stay the same shape - (for now -- storing all fcs by default) - - Parameters - ---------- - m: Union[MTH5, pathlib.Path, str] - Either a path to an mth5, or an MTH5 object that the FCs will be read back from. - - - """ - channel_summary_df = m.channel_summary.to_dataframe() - logger.debug(channel_summary_df) - usssr_grouper = channel_summary_df.groupby(GROUPBY_COLUMNS) - for (survey, station, sample_rate), usssr_group in usssr_grouper: - logger.info(f"survey: {survey}, station: {station}, sample_rate {sample_rate}") - station_obj = m.get_station(station, survey) - fc_groups = station_obj.fourier_coefficients_group.groups_list - logger.info(f"FC Groups: {fc_groups}") - for run_id in fc_groups: - fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_id) - dec_level_ids = fc_group.groups_list - for dec_level_id in dec_level_ids: - dec_level = fc_group.get_decimation_level(dec_level_id) - xrds = dec_level.to_xarray(["hx", "hy"]) - msg = f"dec_level {dec_level_id}" - msg = f"{msg} \n Time axis shape {xrds.time.data.shape}" - msg = f"{msg} \n Freq axis shape {xrds.frequency.data.shape}" - logger.debug(msg) - - return diff --git a/aurora/pipelines/process_mth5.py b/aurora/pipelines/process_mth5.py index 49f41244..5c7339dc 100644 --- a/aurora/pipelines/process_mth5.py +++ b/aurora/pipelines/process_mth5.py @@ -6,7 +6,6 @@ can be repurposed for other TF estimation schemes.
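Editor's aside: `read_back_fcs` above documents the FC access pattern; a condensed sketch, assuming `m` is an open MTH5 whose stations already hold FCs (station and survey names are placeholders):

```python
# Walk the FC hierarchy: station -> fc_group (one per run) -> decimation level -> xarray.
station_obj = m.get_station("test1", "EMTF Synthetic")  # placeholder names
for run_id in station_obj.fourier_coefficients_group.groups_list:
    fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_id)
    for dec_level_id in fc_group.groups_list:
        dec_level = fc_group.get_decimation_level(dec_level_id)
        xrds = dec_level.to_xarray(["hx", "hy"])
        # the time axis shrinks with each level; the frequency axis shape stays fixed
        print(run_id, dec_level_id, xrds.time.data.shape, xrds.frequency.data.shape)
```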
The "legacy" version corresponds to aurora default processing. - Notes on process_mth5_legacy: Note 1: process_mth5 assumes application of cascading decimation, and that the decimated data will be accessed from the previous decimation level. This should be @@ -22,7 +21,7 @@ Note 3: This point in the loop marks the interface between _generation_ of the FCs and their _usage_. In future the code above this comment would be pushed into - create_fourier_coefficients() and the code below this would access those FCs and + the creation of the spectrograms and the code below this would access those FCs and execute compute_transfer_function(). This would also be an appropriate place to place a feature extraction layer, and compute weights for the FCs. @@ -93,6 +92,49 @@ def make_stft_objects(processing_config, i_dec_level, run_obj, run_xrds, units=" ------- stft_obj: xarray.core.dataset.Dataset Time series of calibrated Fourier coefficients per each channel in the run + + Development Notes: + Here are the parameters that are defined via the mt_metadata fourier coefficients structures: + + "bands", + "decimation.anti_alias_filter": "default", + "decimation.factor": 4.0, + "decimation.level": 2, + "decimation.method": "default", + "decimation.sample_rate": 0.0625, + "stft.per_window_detrend_type": "linear", + "stft.prewhitening_type": "first difference", + "stft.window.clock_zero_type": "ignore", + "stft.window.num_samples": 128, + "stft.window.overlap": 32, + "stft.window.type": "boxcar" + + Creating the decimations config requires a decision about decimation factors and the number of levels. + We have been getting this from the EMTF band setup file by default. It is desirable to continue supporting this, + however, note that the EMTF band setup is really about a time series operation, and not about making STFTs. + + For the record, here is the legacy decimation config from EMTF, a.k.a. decset.cfg: + ``` + 4 0 # of decimation level, & decimation offset + 128 32. 1 0 0 7 4 32 1 + 1.0 + 128 32. 4 0 0 7 4 32 4 + .2154 .1911 .1307 .0705 + 128 32. 4 0 0 7 4 32 4 + .2154 .1911 .1307 .0705 + 128 32. 4 0 0 7 4 32 4 + .2154 .1911 .1307 .0705 + ``` + + This essentially corresponds to a "Decimations Group" which is a list of decimations. + Related to the generation of FCs is the ARMA prewhitening (Issue #60) which was controlled in + EMTF with pwset.cfg + 4 5 # of decimation level, # of channels + 3 3 3 3 3 + 3 3 3 3 3 + 3 3 3 3 3 + 3 3 3 3 3 + """ stft_config = processing_config.get_decimation_level(i_dec_level) stft_obj = run_ts_to_stft(stft_config, run_xrds) @@ -278,7 +320,7 @@ def load_stft_obj_from_mth5( """ Load stft_obj from mth5 (instead of compute) - Note #1: See note #1 in time_series.frequency_band_helpers.extract_band + Note #1: See note #1 in mth5.timeseries.spectre.spectrogram.py in extract_band function. 
Parameters ---------- diff --git a/aurora/pipelines/time_series_helpers.py b/aurora/pipelines/time_series_helpers.py index fd466b5b..03716ca3 100644 --- a/aurora/pipelines/time_series_helpers.py +++ b/aurora/pipelines/time_series_helpers.py @@ -19,180 +19,11 @@ Decimation as FCDecimation, ) from mth5.groups import RunGroup +from mth5.timeseries.spectre.prewhitening import apply_prewhitening +from mth5.timeseries.spectre.prewhitening import apply_recoloring from typing import Literal, Optional, Union -def apply_prewhitening( - decimation_obj: Union[AuroraDecimationLevel, FCDecimation], - run_xrds_input: xr.Dataset, -) -> xr.Dataset: - """ - Applies pre-whitening to time series to avoid spectral leakage when FFT is applied. - - TODO: If "first difference", consider clipping first and last sample from the - differentiated time series. - - Parameters - ---------- - decimation_obj : Union[AuroraDecimationLevel, FCDecimation] - Information about how the decimation level is to be processed - - run_xrds_input : xarray.core.dataset.Dataset - Time series to be pre-whitened - - Returns - ------- - run_xrds : xarray.core.dataset.Dataset - pre-whitened time series - - """ - # TODO: remove this try/except once mt_metadata issue 238 PR is merged - try: - pw_type = decimation_obj.prewhitening_type - except AttributeError: - pw_type = decimation_obj.stft.prewhitening_type - - if not pw_type: - msg = "No prewhitening specified - skipping this step" - logger.info(msg) - return run_xrds_input - - if pw_type == "first difference": - run_xrds = run_xrds_input.differentiate("time") - else: - msg = f"{pw_type} pre-whitening not implemented" - logger.exception(msg) - raise NotImplementedError(msg) - return run_xrds - - -def apply_recoloring( - decimation_obj: Union[AuroraDecimationLevel, FCDecimation], - stft_obj: xr.Dataset, -) -> xr.Dataset: - """ - Inverts the pre-whitening operation in frequency domain. 
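Editor's aside: a numpy sketch of the first-difference prewhitening and the frequency-domain inverse ("recoloring") implemented by these two deleted functions. This is illustrative, not the library code itself:

```python
import numpy as np

fs = 1.0                                  # sample rate, example value
x = np.random.randn(1024)                 # stand-in time series
xw = np.gradient(x, 1.0 / fs)             # "first difference" ~ run_xrds.differentiate("time")

spec = np.fft.rfft(xw)
freqs = np.fft.rfftfreq(xw.size, d=1.0 / fs)
jw = 1.0j * 2 * np.pi * freqs             # recoloring divides the spectrum by jw ...
spec = np.divide(spec, jw, out=np.zeros_like(spec), where=jw != 0)  # ... muting DC, as below
```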
- - Parameters - ---------- - decimation_obj : mt_metadata.transfer_functions.processing.fourier_coefficients.decimation.Decimation - Information about how the decimation level is to be processed - stft_obj : xarray.core.dataset.Dataset - Time series of Fourier coefficients to be recoloured - - - Returns - ------- - stft_obj : xarray.core.dataset.Dataset - Recolored time series of Fourier coefficients - """ - # TODO: remove this try/except once mt_metadata issue 238 PR is merged - try: - pw_type = decimation_obj.prewhitening_type - except AttributeError: - pw_type = decimation_obj.stft.prewhitening_type - - # No recoloring needed if prewhitening not applied, or recoloring set to False - if not pw_type: - return stft_obj - # TODO Delete after tests (20241220) -- this check has been moved above the call to this function - # if not decimation_obj.recoloring: - # return stft_obj - - if pw_type == "first difference": - # first difference prewhitening correction is to divide by jw - freqs = stft_obj.frequency.data # was freqs = decimation_obj.fft_frequencies - jw = 1.0j * 2 * np.pi * freqs - stft_obj /= jw - - # suppress nan and inf to mute later warnings - if jw[0] == 0.0: - cond = stft_obj.frequency != 0.0 - stft_obj = stft_obj.where(cond, complex(0.0)) - # elif decimation_obj.prewhitening_type == "ARMA": - # from statsmodels.tsa.arima.model import ARIMA - # AR = 3 # add this to processing config - # MA = 4 # add this to processing config - - else: - msg = f"{pw_type} recoloring not yet implemented" - logger.error(msg) - raise NotImplementedError(msg) - - return stft_obj - - -def run_ts_to_stft_scipy( - decimation_obj: Union[AuroraDecimationLevel, FCDecimation], - run_xrds_orig: xr.Dataset, -) -> xr.Dataset: - """ - TODO: Replace with mth5 run_ts_to_stft_scipy method - Converts a runts object into a time series of Fourier coefficients. - This method uses scipy.signal.spectrogram. - - - Parameters - ---------- - decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel - Information about how the decimation level is to be processed - Note: This works with FCDecimation and AuroraDecimationLevel because test_fourier_coefficients - and test_stft_methods_agree both use them. - Note: Both of these objects are basically spectrogram metadata with provenance for decimation levels.
- run_xrds_orig : xarray.core.dataset.Dataset - Time series to be processed - - Returns - ------- - stft_obj : xarray.core.dataset.Dataset - Time series of Fourier coefficients - """ - run_xrds = apply_prewhitening(decimation_obj, run_xrds_orig) - windowing_scheme = window_scheme_from_decimation( - decimation_obj - ) # TODO: deprecate in favor of stft.window.taper - - stft_obj = xr.Dataset() - for channel_id in run_xrds.data_vars: - ff, tt, specgm = ssig.spectrogram( - run_xrds[channel_id].data, - fs=decimation_obj.decimation.sample_rate, - window=windowing_scheme.taper, - nperseg=decimation_obj.stft.window.num_samples, - noverlap=decimation_obj.stft.window.overlap, - detrend="linear", - scaling="density", - mode="complex", - ) - - # drop Nyquist - ff = ff[:-1] - specgm = specgm[:-1, :] - specgm *= np.sqrt(2) # compensate energy for keeping only half the spectrum - - # make time_axis - tt = tt - tt[0] - tt *= decimation_obj.decimation.sample_rate - time_axis = run_xrds.time.data[tt.astype(int)] - - xrd = xr.DataArray( - specgm.T, - dims=["time", "frequency"], - coords={"frequency": ff, "time": time_axis}, - ) - stft_obj.update({channel_id: xrd}) - - # TODO : remove try/except after mt_metadata issue 238 addressed - try: - to_recolor_or_not_to_recolor = decimation_obj.recoloring - except AttributeError: - to_recolor_or_not_to_recolor = decimation_obj.stft.recoloring - if to_recolor_or_not_to_recolor: - stft_obj = apply_recoloring(decimation_obj, stft_obj) - - return stft_obj - - def truncate_to_clock_zero( decimation_obj: Union[AuroraDecimationLevel, FCDecimation], run_xrds: RunGroup, @@ -204,7 +35,7 @@ def truncate_to_clock_zero( Parameters ---------- - decimation_obj: mt_metadata.transfer_functions.processing.aurora.DecimationLevel + decimation_obj: Union[AuroraDecimationLevel, FCDecimation] Information about how the decimation level is to be processed run_xrds : xarray.core.dataset.Dataset normally extracted from mth5.RunTS @@ -277,7 +108,7 @@ def run_ts_to_stft( Parameters ---------- - decimation_obj : mt_metadata.transfer_functions.processing.aurora.DecimationLevel + decimation_obj : AuroraDecimationLevel Information about how the decimation level is to be processed run_ts : xarray.core.dataset.Dataset normally extracted from mth5.RunTS @@ -292,7 +123,7 @@ def run_ts_to_stft( # need to remove any nans before windowing, or else if there is a single # nan then the whole channel becomes nan.
run_xrds = nan_to_mean(run_xrds_orig) - run_xrds = apply_prewhitening(decimation_obj, run_xrds) + run_xrds = apply_prewhitening(decimation_obj.stft.prewhitening_type, run_xrds) run_xrds = truncate_to_clock_zero(decimation_obj, run_xrds) windowing_scheme = window_scheme_from_decimation(decimation_obj) windowed_obj = windowing_scheme.apply_sliding_window( @@ -313,7 +144,7 @@ def run_ts_to_stft( ) if decimation_obj.stft.recoloring: - stft_obj = apply_recoloring(decimation_obj, stft_obj) + stft_obj = apply_recoloring(decimation_obj.stft.prewhitening_type, stft_obj) return stft_obj diff --git a/aurora/time_series/frequency_band_helpers.py b/aurora/time_series/frequency_band_helpers.py index a39d4e23..6c0f085b 100644 --- a/aurora/time_series/frequency_band_helpers.py +++ b/aurora/time_series/frequency_band_helpers.py @@ -6,17 +6,24 @@ from mt_metadata.transfer_functions.processing.aurora import ( DecimationLevel as AuroraDecimationLevel, ) +from mt_metadata.transfer_functions.processing.aurora import Band +from mth5.timeseries.spectre.spectrogram import extract_band +from typing import Optional, Tuple +import xarray as xr def get_band_for_tf_estimate( - band, dec_level_config: AuroraDecimationLevel, local_stft_obj, remote_stft_obj -): + band: Band, + dec_level_config: AuroraDecimationLevel, + local_stft_obj: xr.Dataset, + remote_stft_obj: Optional[xr.Dataset], +) -> Tuple[xr.Dataset, xr.Dataset, Optional[xr.Dataset]]: """ Returns spectrograms X, Y, RR for harmonics within the given band Parameters ---------- - band : mt_metadata.transfer_functions.processing.aurora.FrequencyBands + band : mt_metadata.transfer_functions.processing.aurora.Band object with lower_bound and upper_bound to tell stft object which subarray to return config : AuroraDecimationLevel @@ -53,49 +60,6 @@ def get_band_for_tf_estimate( return X, Y, RR -def extract_band(frequency_band, fft_obj, channels=[], epsilon=1e-7): - """ - Extracts a frequency band from xr.DataArray representing a spectrogram. - - Stand alone version of the method that is used by WIP Spectrogram class. - - Development Notes: - #1: 20230902 - TODO: Decide if base dataset object should be a xr.DataArray (not xr.Dataset) - - drop=True does not play nice with h5py and Dataset, results in a type error. - File "stringsource", line 2, in h5py.h5r.Reference.__reduce_cython__ - TypeError: no default __reduce__ due to non-trivial __cinit__ - However, it works OK with DataArray, so maybe use data array in general - - Parameters - ---------- - frequency_band: mt_metadata.transfer_functions.processing.aurora.band.Band - Specifies interval corresponding to a frequency band - fft_obj: xarray.core.dataset.Dataset - To be replaced with an fft_obj() class in future - epsilon: float - Use this when you are worried about missing a frequency due to - round off error. This is in general not needed if we use a df/2 pad - around true harmonics. - - Returns - ------- - band: xr.DataArray - The frequencies within the band passed into this function - """ - cond1 = fft_obj.frequency >= frequency_band.lower_bound - epsilon - cond2 = fft_obj.frequency <= frequency_band.upper_bound + epsilon - try: - band = fft_obj.where(cond1 & cond2, drop=True) - except TypeError: # see Note #1 - tmp = fft_obj.to_array() - band = tmp.where(cond1 & cond2, drop=True) - band = band.to_dataset("variable") - if channels: - band = band[channels] - return band - - def check_time_axes_synched(X, Y): """ Utility function for checking that time axes agree. 
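Editor's aside: the deleted `extract_band` now lives in `mth5.timeseries.spectre.spectrogram` (see the new import above). Its core is an epsilon-padded boolean mask on the frequency axis; a self-contained sketch:

```python
import numpy as np
import xarray as xr

freqs = np.linspace(0.0, 0.5, 65)[:-1]        # example frequency axis, Nyquist dropped
da = xr.DataArray(
    np.random.randn(10, freqs.size) + 0j,
    dims=["time", "frequency"],
    coords={"time": np.arange(10), "frequency": freqs},
)
lower_bound, upper_bound, epsilon = 0.1, 0.2, 1e-7  # example band edges, in Hz
cond = (da.frequency >= lower_bound - epsilon) & (da.frequency <= upper_bound + epsilon)
band = da.where(cond, drop=True)  # DataArray sidesteps the h5py/Dataset issue in Note #1
```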
diff --git a/aurora/time_series/spectrogram.py b/aurora/time_series/spectrogram.py deleted file mode 100644 index ad3684d3..00000000 --- a/aurora/time_series/spectrogram.py +++ /dev/null @@ -1,160 +0,0 @@ -""" - WORK IN PROGRESS (WIP): This module contains a class that represents a spectrogram, - i.e. A 2D time series of Fourier coefficients with axes time and frequency. - -""" -from aurora.time_series.frequency_band_helpers import extract_band -from typing import Optional -import xarray - - -class Spectrogram(object): - """ - Class to contain methods for STFT objects. - TODO: Add support for cross powers - TODO: Add OLS Z-estimates - TODO: Add Sims/Vozoff Z-estimates - - """ - - def __init__(self, dataset=None): - """Constructor""" - self._dataset = dataset - self._frequency_increment = None - - def _lowest_frequency(self): - pass - - def _higest_frequency(self): - pass - - def __str__(self) -> str: - """Returns a Description of frequency coverage""" - intro = "Spectrogram:" - frequency_coverage = ( - f"{self.dataset.dims['frequency']} harmonics, {self.frequency_increment}Hz spaced \n" - f" from {self.dataset.frequency.data[0]} to {self.dataset.frequency.data[-1]} Hz." - ) - time_coverage = f"\n{self.dataset.dims['time']} Time observations" - time_coverage = f"{time_coverage} \nStart: {self.dataset.time.data[0]}" - time_coverage = f"{time_coverage} \nEnd: {self.dataset.time.data[-1]}" - - channel_coverage = list(self.dataset.data_vars.keys()) - channel_coverage = "\n".join(channel_coverage) - channel_coverage = f"\nChannels present: \n{channel_coverage}" - return ( - intro - + "\n" - + frequency_coverage - + "\n" - + time_coverage - + "\n" - + channel_coverage - ) - - def __repr__(self): - return self.__str__() - - @property - def dataset(self): - """returns the underlying xarray data""" - return self._dataset - - @property - def time_axis(self): - """returns the time axis of the underlying xarray""" - return self.dataset.time - - @property - def frequency_increment(self): - """ - returns the "delta f" of the frequency axis - - assumes uniformly sampled in frequency domain - """ - if self._frequency_increment is None: - frequency_axis = self.dataset.frequency - self._frequency_increment = frequency_axis.data[1] - frequency_axis.data[0] - return self._frequency_increment - - def num_harmonics_in_band(self, frequency_band, epsilon=1e-7): - """ - - Returns the number of harmonics within the frequency band in the underlying dataset - - Parameters - ---------- - band - stft_obj - - Returns - ------- - - """ - cond1 = self._dataset.frequency >= frequency_band.lower_bound - epsilon - cond2 = self._dataset.frequency <= frequency_band.upper_bound + epsilon - num_harmonics = (cond1 & cond2).data.sum() - return num_harmonics - - def extract_band(self, frequency_band, channels=[]): - """ - Returns another instance of Spectrogram, with the frequency axis reduced to the input band. - - TODO: Consider returning a copy of the data... 
- - Parameters - ---------- - frequency_band - channels - - Returns - ------- - spectrogram: aurora.time_series.spectrogram.Spectrogram - Returns a Spectrogram object with only the extracted band for a dataset - - """ - extracted_band_dataset = extract_band( - frequency_band, - self.dataset, - channels=channels, - epsilon=self.frequency_increment / 2.0, - ) - spectrogram = Spectrogram(dataset=extracted_band_dataset) - return spectrogram - - # TODO: Add cross power method - # def cross_powers(self, ch1, ch2, band=None): - # pass - - def flatten(self, chunk_by: Optional[str] = "time") -> xarray.Dataset: - """ - - Returns the flattened xarray (time-chunked by default). - - Parameters - ---------- - chunk_by: str - Controlled vocabulary ["time", "frequency"]. Reshaping the 2D spectrogram can be done two ways - (basically "row-major", or column-major). In xarray, but we either keep frequency constant and - iterate over time, or keep time constant and iterate over frequency (in the inner loop). - - - Returns - ------- - xarray.Dataset : The dataset from the band spectrogram, stacked. - - Development Notes: - The flattening used in tf calculation by default is opposite to here - dataset.stack(observation=("frequency", "time")) - However, for feature extraction, it may make sense to swap the order: - xrds = band_spectrogram.dataset.stack(observation=("time", "frequency")) - This is like chunking into time windows and allows individual features to be computed on each time window -- if desired. - Still need to split the time series though--Splitting to time would be a reshape by (last_freq_index-first_freq_index). - Using pure xarray this may not matter but if we drop down into numpy it could be useful. - - - """ - if chunk_by == "time": - observation = ("time", "frequency") - elif chunk_by == "frequency": - observation = ("frequency", "time") - return self.dataset.stack(observation=observation) diff --git a/aurora/time_series/xarray_helpers.py b/aurora/time_series/xarray_helpers.py index db30493f..f497ab4c 100644 --- a/aurora/time_series/xarray_helpers.py +++ b/aurora/time_series/xarray_helpers.py @@ -2,13 +2,17 @@ Placeholder module for methods manipulating xarray time series """ -import numpy as np import xarray as xr from loguru import logger -from typing import Optional, Union +from typing import Optional -def handle_nan(X, Y, RR, drop_dim=""): +def handle_nan( + X: xr.Dataset, + Y: Optional[xr.Dataset], + RR: Optional[xr.Dataset], + drop_dim: Optional[str] = "", +) -> tuple: """ Drops Nan from multiple channel series'. @@ -87,118 +91,3 @@ def handle_nan(X, Y, RR, drop_dim=""): RR = RR.rename(data_var_rm_label_mapper) return X, Y, RR - - -def covariance_xr( - X: xr.DataArray, aweights: Optional[Union[np.ndarray, None]] = None -) -> xr.DataArray: - """ - Compute the covariance matrix with numpy.cov. - - Parameters - ---------- - X: xarray.core.dataarray.DataArray - Multivariate time series as an xarray - aweights: array_like, optional - Doc taken from numpy cov follows: - 1-D array of observation vector weights. These relative weights are - typically large for observations considered "important" and smaller for - observations considered less "important". If ``ddof=0`` the array of - weights can be used to assign probabilities to observation vectors. - - Returns - ------- - S: xarray.DataArray - The covariance matrix of the data in xarray form. 
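Editor's aside: a usage sketch for `covariance_xr` as documented above (its implementation follows below); it mirrors the unit test deleted later in this diff. The input is a DataArray with a `"variable"` dimension, since `np.cov` treats rows as channels:

```python
import numpy as np
import xarray as xr

n_obs = 100
xrds = xr.Dataset(
    {
        "hx": ("time", np.random.randn(n_obs)),
        "hy": ("time", np.random.randn(n_obs)),
    },
    coords={"time": np.arange(n_obs)},
)
X = xrds.to_array()        # dims ("variable", "time"); rows are channels for np.cov
S = covariance_xr(X)       # 2x2 covariance labeled by channel_1 / channel_2
assert (S.data == S.data.transpose().conj()).all()
```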
- """ - - channels = list(X.coords["variable"].values) - - S = xr.DataArray( - np.cov(X, aweights=aweights), - dims=["channel_1", "channel_2"], - coords={"channel_1": channels, "channel_2": channels}, - ) - return S - - -def initialize_xrda_1d( - channels: list, - dtype=Optional[type], - value: Optional[Union[complex, float, bool]] = 0, -) -> xr.DataArray: - """ - - Returns a 1D xr.DataArray with variable "channel", having values channels named by the input list. - - Parameters - ---------- - channels: list - The channels in the multivariate array - dtype: type - The datatype to initialize the array. - Common cases are complex, float, and bool - value: Union[complex, float, bool] - The default value to assign the array - - Returns - ------- - xrda: xarray.core.dataarray.DataArray - An xarray container for the channels, initialized to zeros. - """ - k = len(channels) - logger.debug(f"Initializing xarray with values {value}") - xrda = xr.DataArray( - np.zeros(k, dtype=dtype), - dims=[ - "variable", - ], - coords={ - "variable": channels, - }, - ) - if value != 0: - data = value * np.ones(k, dtype=dtype) - xrda.data = data - return xrda - - -def initialize_xrda_2d( - channels, dtype=complex, value: Optional[Union[complex, float, bool]] = 0, dims=None -): - - """ - TODO: consider merging with initialize_xrda_1d - TODO: consider changing nomenclature from dims=["channel_1", "channel_2"], - to dims=["variable_1", "variable_2"], to be consistent with initialize_xrda_1d - - Parameters - ---------- - channels: list - The channels in the multivariate array - dtype: type - The datatype to initialize the array. - Common cases are complex, float, and bool - value: Union[complex, float, bool] - The default value to assign the array - - Returns - ------- - xrda: xarray.core.dataarray.DataArray - An xarray container for the channel variances etc., initialized to zeros. - """ - if dims is None: - dims = [channels, channels] - - K = len(channels) - logger.debug(f"Initializing 2D xarray to {value}") - xrda = xr.DataArray( - np.zeros((K, K), dtype=dtype), - dims=["channel_1", "channel_2"], - coords={"channel_1": dims[0], "channel_2": dims[1]}, - ) - if value != 0: - data = value * np.ones(xrda.shape, dtype=dtype) - xrda.data = data - - return xrda diff --git a/aurora/transfer_function/base.py b/aurora/transfer_function/base.py index 3f1ffc73..f26ac2e7 100644 --- a/aurora/transfer_function/base.py +++ b/aurora/transfer_function/base.py @@ -50,10 +50,6 @@ def __init__( """ Constructor. - Development Notes: - change 2021-07-23 to require a frequency_bands object. We may want - to just pass the band_edges. - Parameters ---------- _emtf_header : legacy header information used by Egbert's matlab class. Header contains @@ -61,8 +57,8 @@ def __init__( decimation_level_id: int Identifies the relevant decimation level. Used for accessing the appropriate info in self.processing config. - frequency_bands: aurora.time_series.frequency_band.FrequencyBands - frequency bands object + frequency_bands: FrequencyBands + frequency bands object defining the tf estimation bands. 
""" self._emtf_tf_header = None self.decimation_level_id = decimation_level_id diff --git a/docs/tutorials/synthetic_data_processing.ipynb b/docs/tutorials/synthetic_data_processing.ipynb index b505a313..d62fbc47 100644 --- a/docs/tutorials/synthetic_data_processing.ipynb +++ b/docs/tutorials/synthetic_data_processing.ipynb @@ -1908,9 +1908,9 @@ "metadata": {}, "outputs": [], "source": [ - "from aurora.pipelines.fourier_coefficients import add_fcs_to_mth5\n", - "from aurora.pipelines.fourier_coefficients import fc_decimations_creator\n", - "from aurora.pipelines.fourier_coefficients import read_back_fcs" + "from mth5.timeseries.spectre.helpers import add_fcs_to_mth5\n", + "from mth5.timeseries.spectre.helpers import fc_decimations_creator\n", + "from mth5.timeseries.spectre.helpers import read_back_fcs" ] }, { diff --git a/tests/config/test_config_creator.py b/tests/config/test_config_creator.py index 047d3cad..f7e2f4f4 100644 --- a/tests/config/test_config_creator.py +++ b/tests/config/test_config_creator.py @@ -1,4 +1,5 @@ # import logging +import pandas as pd import unittest from aurora.config.config_creator import ConfigCreator @@ -68,8 +69,13 @@ def test_frequency_bands(self): delta_f = dec_level_0.frequency_sample_interval lower_edges = (dec_level_0.lower_bounds * delta_f) - delta_f / 2.0 upper_edges = (dec_level_0.upper_bounds * delta_f) + delta_f / 2.0 - band_edges_b = np.vstack((lower_edges, upper_edges)).T - assert (band_edges_b - band_edges_a == 0).all() + band_edges_b = pd.DataFrame( + data={ + "lower_bound": lower_edges, + "upper_bound": upper_edges, + } + ) + assert (band_edges_b - band_edges_a == 0).all().all() def main(): diff --git a/tests/synthetic/test_fourier_coefficients.py b/tests/synthetic/test_fourier_coefficients.py index 3725984e..2981faaf 100644 --- a/tests/synthetic/test_fourier_coefficients.py +++ b/tests/synthetic/test_fourier_coefficients.py @@ -1,26 +1,21 @@ import unittest from aurora.config.config_creator import ConfigCreator -from aurora.pipelines.fourier_coefficients import add_fcs_to_mth5 -from aurora.pipelines.fourier_coefficients import fc_decimations_creator - -# from aurora.pipelines.fourier_coefficients import read_back_fcs from aurora.pipelines.process_mth5 import process_mth5 from aurora.test_utils.synthetic.make_processing_configs import ( create_test_run_config, ) from aurora.test_utils.synthetic.paths import SyntheticTestPaths +from loguru import logger from mth5.data.make_mth5_from_asc import create_test1_h5 from mth5.data.make_mth5_from_asc import create_test2_h5 from mth5.data.make_mth5_from_asc import create_test3_h5 from mth5.data.make_mth5_from_asc import create_test12rr_h5 -from mth5.timeseries.spectre.helpers import read_back_fcs - -# from mtpy-v2 -from mtpy.processing import RunSummary, KernelDataset - -from loguru import logger from mth5.helpers import close_open_files +from mth5.timeseries.spectre.helpers import add_fcs_to_mth5 +from mth5.timeseries.spectre.helpers import fc_decimations_creator +from mth5.timeseries.spectre.helpers import read_back_fcs +from mtpy.processing import RunSummary, KernelDataset # from mtpy-v2 synthetic_test_paths = SyntheticTestPaths() synthetic_test_paths.mkdirs() diff --git a/tests/synthetic/test_stft_methods_agree.py b/tests/synthetic/test_stft_methods_agree.py index 3878882a..16320096 100644 --- a/tests/synthetic/test_stft_methods_agree.py +++ b/tests/synthetic/test_stft_methods_agree.py @@ -7,18 +7,15 @@ from aurora.pipelines.time_series_helpers import prototype_decimate from 
aurora.pipelines.time_series_helpers import run_ts_to_stft -from aurora.pipelines.time_series_helpers import run_ts_to_stft_scipy from aurora.test_utils.synthetic.make_processing_configs import ( create_test_run_config, ) - -# from mtpy-v2 -from mtpy.processing import RunSummary, KernelDataset - from loguru import logger from mth5.data.make_mth5_from_asc import create_test1_h5 from mth5.mth5 import MTH5 from mth5.helpers import close_open_files +from mth5.timeseries.spectre.stft import run_ts_to_stft_scipy +from mtpy.processing import RunSummary, KernelDataset # from mtpy-v2 def test_stft_methods_agree(): diff --git a/tests/time_series/test_spectrogram.py b/tests/time_series/test_spectrogram.py deleted file mode 100644 index 22658981..00000000 --- a/tests/time_series/test_spectrogram.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -""" -""" - -import unittest - -from aurora.time_series.spectrogram import Spectrogram - - -class TestSpectrogram(unittest.TestCase): - """ - Test Spectrogram class - """ - - @classmethod - def setUpClass(self): - pass - - def setUp(self): - pass - - def test_initialize(self): - spectrogram = Spectrogram() - assert isinstance(spectrogram, Spectrogram) - - def test_slice_band(self): - """ - Place holder - TODO: Once FCs are added to an mth5, load a spectrogram and extract a Band - """ - pass - - -if __name__ == "__main__": - # tmp = TestSpectrogram() - # tmp.test_initialize() - unittest.main() diff --git a/tests/time_series/test_xarray_helpers.py b/tests/time_series/test_xarray_helpers.py index 247e77da..9f57df15 100644 --- a/tests/time_series/test_xarray_helpers.py +++ b/tests/time_series/test_xarray_helpers.py @@ -4,74 +4,83 @@ """ import numpy as np -import unittest - import xarray as xr +import pytest + +from aurora.time_series.xarray_helpers import handle_nan + + +def test_handle_nan_basic(): + """Test basic functionality of handle_nan with NaN values.""" + # Create sample data with NaN values + times = np.array([0, 1, 2, 3, 4]) + data_x = np.array([1.0, np.nan, 3.0, 4.0, 5.0]) + data_y = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + + X = xr.Dataset({"hx": ("time", data_x)}, coords={"time": times}) + Y = xr.Dataset({"ex": ("time", data_y)}, coords={"time": times}) + + # Test with X and Y only + X_clean, Y_clean, _ = handle_nan(X, Y, None, drop_dim="time") + + # Check that NaN values were dropped + assert len(X_clean.time) == 3 + assert len(Y_clean.time) == 3 + assert not np.any(np.isnan(X_clean.hx.values)) + assert not np.any(np.isnan(Y_clean.ex.values)) + + +def test_handle_nan_with_remote_reference(): + """Test handle_nan with remote reference data.""" + # Create sample data + times = np.array([0, 1, 2, 3]) + data_x = np.array([1.0, np.nan, 3.0, 4.0]) + data_y = np.array([1.0, 2.0, 3.0, 4.0]) + data_rr = np.array([1.0, 2.0, np.nan, 4.0]) + + X = xr.Dataset({"hx": ("time", data_x)}, coords={"time": times}) + Y = xr.Dataset({"ex": ("time", data_y)}, coords={"time": times}) + RR = xr.Dataset({"hx": ("time", data_rr)}, coords={"time": times}) + + # Test with all datasets + X_clean, Y_clean, RR_clean = handle_nan(X, Y, RR, drop_dim="time") + + # Check that NaN values were dropped + assert len(X_clean.time) == 2 + assert len(Y_clean.time) == 2 + assert len(RR_clean.time) == 2 + assert not np.any(np.isnan(X_clean.hx.values)) + assert not np.any(np.isnan(Y_clean.ex.values)) + assert not np.any(np.isnan(RR_clean.hx.values)) + + # Check that the values are correct + expected_times = np.array([0, 3]) + assert np.allclose(X_clean.time.values, expected_times) + 
assert np.allclose(Y_clean.time.values, expected_times) + assert np.allclose(RR_clean.time.values, expected_times) + assert np.allclose(X_clean.hx.values, np.array([1.0, 4.0])) + assert np.allclose(Y_clean.ex.values, np.array([1.0, 4.0])) + assert np.allclose(RR_clean.hx.values, np.array([1.0, 4.0])) + + +def test_handle_nan_time_mismatch(): + """Test handle_nan with time coordinate mismatches.""" + # Create sample data with slightly different timestamps + times_x = np.array([0, 1, 2, 3]) + times_rr = times_x + 0.1 # Small offset + data_x = np.array([1.0, 2.0, 3.0, 4.0]) + data_rr = np.array([1.0, 2.0, 3.0, 4.0]) + + X = xr.Dataset({"hx": ("time", data_x)}, coords={"time": times_x}) + RR = xr.Dataset({"hx": ("time", data_rr)}, coords={"time": times_rr}) + + # Test handling of time mismatch + X_clean, _, RR_clean = handle_nan(X, None, RR, drop_dim="time") + + # Check that data was preserved despite time mismatch + assert len(X_clean.time) == 4 + assert "hx" in RR_clean.data_vars + assert np.allclose(RR_clean.hx.values, data_rr) -from aurora.time_series.xarray_helpers import covariance_xr -from aurora.time_series.xarray_helpers import initialize_xrda_1d -from aurora.time_series.xarray_helpers import initialize_xrda_2d - - -class TestXarrayHelpers(unittest.TestCase): - """ - Test methods in xarray helpers - - may get broken into separate tests if this module grows - """ - - @classmethod - def setUpClass(self): - self.standard_channel_names = ["ex", "ey", "hx", "hy", "hz"] - - def setUp(self): - pass - - def test_initialize_xrda_1d(self): - dtype = float - value = -1 - tmp = initialize_xrda_1d(self.standard_channel_names, dtype=dtype, value=value) - self.assertTrue((tmp.data == value).all()) - - def test_initialize_xrda_2d(self): - dtype = float - value = -1 - tmp = initialize_xrda_2d(self.standard_channel_names, dtype=dtype, value=value) - self.assertTrue((tmp.data == value).all()) - - def test_covariance_xr(self): - np.random.seed(0) - n_observations = 100 - xrds = xr.Dataset( - { - "hx": ( - [ - "time", - ], - np.abs(np.random.randn(n_observations)), - ), - "hy": ( - [ - "time", - ], - np.abs(np.random.randn(n_observations)), - ), - }, - coords={ - "time": np.arange(n_observations), - }, - ) - - X = xrds.to_array() - cov = covariance_xr(X) - self.assertTrue((cov.data == cov.data.transpose().conj()).all()) - - def test_sometehing_else(self): - """ - Place holder - - """ - pass - - -if __name__ == "__main__": - unittest.main() + # Check that the time values match X's time values + assert np.allclose(RR_clean.time.values, X_clean.time.values) diff --git a/tests/transfer_function/test_cross_power.py b/tests/transfer_function/test_cross_power.py index b312e5e7..6c708f6f 100644 --- a/tests/transfer_function/test_cross_power.py +++ b/tests/transfer_function/test_cross_power.py @@ -1,4 +1,4 @@ -from aurora.time_series.xarray_helpers import initialize_xrda_2d +from mth5.timeseries.xarray_helpers import initialize_xrda_2d_cov from aurora.transfer_function.cross_power import tf_from_cross_powers from aurora.transfer_function.cross_power import _channel_names from aurora.transfer_function.cross_power import ( @@ -32,7 +32,7 @@ def setUpClass(self): station_1_channels = [f"{self.station_ids[0]}_{x}" for x in components] station_2_channels = [f"{self.station_ids[1]}_{x}" for x in components] channels = station_1_channels + station_2_channels - sdm = initialize_xrda_2d( + sdm = initialize_xrda_2d_cov( channels=channels, dtype=complex, )
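Editor's aside on the rename above: a hedged sketch of the replacement helper's use, assuming it keeps the keyword interface shown in the updated test:

```python
from mth5.timeseries.xarray_helpers import initialize_xrda_2d_cov

# Station-prefixed channel names, as built in test_cross_power.py's setUpClass.
channels = ["test1_ex", "test1_ey", "test2_hx", "test2_hy"]
sdm = initialize_xrda_2d_cov(channels=channels, dtype=complex)  # spectral density matrix container
```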