Skip to content

Commit

Permalink
merge remote channels into STFT object
Browse files Browse the repository at this point in the history
- Now use a single xarray as input to RR tf estimation
- move check_time_axes_synched into xarray_helpers
- deprecate get_band_for_tf_estimate
  • Loading branch information
kkappler committed Mar 27, 2024
1 parent 46b8197 commit 579ed8a
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 110 deletions.
60 changes: 48 additions & 12 deletions aurora/pipelines/process_mth5.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def make_stft_objects(processing_config, i_dec_level, run_obj, run_xrds, units="
return stft_obj


def process_tf_decimation_level(
config, i_dec_level, local_stft_obj, remote_stft_obj, units="MT"
):
def process_tf_decimation_level(config, i_dec_level, merged_stft_object, units="MT"):
"""
Processing pipeline for a single decimation_level
Expand All @@ -110,10 +108,9 @@ def process_tf_decimation_level(
i_dec_level: int
decimation level_id
?could we pack this into the decimation level as an attr?
local_stft_obj: xarray.core.dataset.Dataset
The time series of Fourier coefficients from the local station
remote_stft_obj: xarray.core.dataset.Dataset or None
The time series of Fourier coefficients from the remote station
merged_stft_object: xarray.core.dataset.Dataset
The time series of Fourier coefficients from the local station, with optional
keys ["rx", "ry"] for remote reference channels
units: str
one of ["MT","SI"]
Expand All @@ -127,7 +124,7 @@ def process_tf_decimation_level(
dec_level_config = config.decimations[i_dec_level]
# segment_weights = coherence_weights(dec_level_config, local_stft_obj, remote_stft_obj)
transfer_function_obj = process_transfer_functions(
dec_level_config, local_stft_obj, remote_stft_obj, transfer_function_obj
dec_level_config, merged_stft_object, transfer_function_obj
)

return transfer_function_obj
Expand Down Expand Up @@ -167,7 +164,7 @@ def triage_issue_289(local_stfts, remote_stfts):
return local_stfts, remote_stfts


def merge_stfts(stfts, tfk):
def merge_stft_runs_in_time(stfts, tfk):
# Timing Error Workaround See Aurora Issue #289
local_stfts = stfts["local"]
remote_stfts = stfts["remote"]
Expand All @@ -182,6 +179,39 @@ def merge_stfts(stfts, tfk):
return local_merged_stft_obj, remote_merged_stft_obj


def merge_stft_remote_channels(local_stft_obj, remote_stft_obj, remote_channels=None):
    """
    Attach the remote reference channels to the local STFT object.

    Temporary function. Would like to push the merge with remote further upstream
    in the processing, and use the spectrogram class earlier as well. For now, this
    provides a way to merge the local and remote STFT objects before
    process_transfer_functions is called.

    Parameters
    ----------
    local_stft_obj
        Container of Fourier coefficients for the local station. It is modified
        in place: the keys "rx" and/or "ry" are assigned when a matching remote
        channel is requested.
    remote_stft_obj
        Container of Fourier coefficients for the remote station, or None when
        there is no remote reference (in which case nothing is merged).
    remote_channels
        Names of the reference channels to take from the remote station.
        None (the default) or an empty sequence means "no reference channels".

    Returns
    -------
    local_stft_obj
        The same object that was passed in, possibly with "rx" / "ry" added.
    """
    if remote_stft_obj is None:
        return local_stft_obj

    # Imported here because it is only needed when a remote station is present.
    from aurora.time_series.xarray_helpers import check_time_axes_synched

    check_time_axes_synched(local_stft_obj, remote_stft_obj)

    # Tolerate a missing/None channel list instead of failing on `in None`.
    requested = remote_channels or ()

    # This could be generalized, but first we need to make sure that it doesn't
    # need to support other nomenclatures (e1, e2, h1, h2, h3, etc.).
    # For each reference key the candidates are tried in order, so an electric
    # channel takes precedence over the magnetic one when both are requested.
    reference_sources = {
        "rx": ("ex", "hx"),
        "ry": ("ey", "hy"),
    }
    for reference_key, candidates in reference_sources.items():
        source = next((name for name in candidates if name in requested), None)
        if source is not None:
            local_stft_obj[reference_key] = remote_stft_obj[source]
    return local_stft_obj


def append_chunk_to_stfts(stfts, chunk, remote):
if remote:
stfts["remote"].append(chunk)
Expand Down Expand Up @@ -418,16 +448,22 @@ def process_mth5(

stfts = get_spectrogams(tfk, i_dec_level, units=units)

local_merged_stft_obj, remote_merged_stft_obj = merge_stfts(stfts, tfk)
local_merged_stft_obj, remote_merged_stft_obj = merge_stft_runs_in_time(
stfts, tfk
)

merged_stft_obj = merge_stft_remote_channels(
local_merged_stft_obj,
remote_merged_stft_obj,
remote_channels=dec_level_config.reference_channels,
)
# FC TF Interface here (see Note #3)
# Could downweight bad FCs here

ttfz_obj = process_tf_decimation_level(
tfk.config,
i_dec_level,
local_merged_stft_obj,
remote_merged_stft_obj,
merged_stft_obj,
)
ttfz_obj.apparent_resistivity(tfk.config.channel_nomenclature, units=units)
tf_dict[i_dec_level] = ttfz_obj
Expand Down
57 changes: 35 additions & 22 deletions aurora/pipelines/transfer_function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np

from aurora.time_series.frequency_band_helpers import adjust_band_for_coherence_sorting
from aurora.time_series.frequency_band_helpers import get_band_for_tf_estimate

# from aurora.time_series.frequency_band_helpers import get_band_for_tf_estimate
from aurora.time_series.spectrogram import Spectrogram
from aurora.time_series.xarray_helpers import handle_nan
from aurora.transfer_function.regression import get_regression_estimator
Expand Down Expand Up @@ -124,8 +125,7 @@ def apply_weights(X, Y, RR, W, segment=False, dropna=False):

def process_transfer_functions(
dec_level_config,
local_stft_obj,
remote_stft_obj,
merged_stft_object,
transfer_function_obj,
# segment_weights=["multiple_coherence",],#["simple_coherence",],#["multiple_coherence",],#jj84_coherence_weights",],
segment_weights=[],
Expand All @@ -137,8 +137,10 @@ def process_transfer_functions(
Parameters
----------
dec_level_config
local_stft_obj
remote_stft_obj
merged_stft_object: xarray.core.dataset.Dataset
The time series of Fourier coefficients from the local station, with optional
keys ["rx", "ry"] for remote reference channels
transfer_function_obj: aurora.transfer_function.TTFZ.TTFZ
The transfer function container ready to receive values in this method.
segment_weights : numpy array or list of strings
Expand Down Expand Up @@ -170,25 +172,25 @@ def process_transfer_functions(
-------
"""
# Experimental nomenclature change for RR case-- If adopted, add mapping rx:hx, ry:hy to processing config
local_stft_obj["rx"] = remote_stft_obj["hx"]
local_stft_obj["ry"] = remote_stft_obj["hy"]

# Also consider applying channel nomenlclature map to standard channels for regression, map back when done.

estimator_class = get_regression_estimator(dec_level_config.estimator.engine)
iter_control = set_up_iter_control(dec_level_config)
spectrogram = Spectrogram(dataset=merged_stft_object)
for band in transfer_function_obj.frequency_bands.bands():

# Uncomment for Testing -- this will make a 3xFC-wide spectrogram if there is only 1 FC
rule = "min3" # TODO: Put band-adjustment rule into processing config
spectrogram = Spectrogram(dataset=local_stft_obj)
adjusted_band = adjust_band_for_coherence_sorting(band, spectrogram, rule=rule)

band_spectrogram = spectrogram.extract_band(adjusted_band)
X, Y, RR = get_band_for_tf_estimate(
band, dec_level_config, local_stft_obj, remote_stft_obj
)
# adjusted_band = adjust_band_for_coherence_sorting(band, spectrogram, rule=rule)
# band_spectrogram = spectrogram.extract_band(adjusted_band)

band_spectrogram = spectrogram.extract_band(band)
X = band_spectrogram.dataset[dec_level_config.input_channels]
Y = band_spectrogram.dataset[dec_level_config.output_channels]
if dec_level_config.reference_channels:
RR = band_spectrogram.dataset[["rx", "ry"]]
else:
RR = None

# Apply segment weights first -- see Note #2

Expand All @@ -197,7 +199,13 @@ def process_transfer_functions(
coherence_weights_jj84,
)

Wjj84 = coherence_weights_jj84(band, local_stft_obj, remote_stft_obj)
rule = "min3" # TODO: Put band-adjustment rule into processing config
adjusted_band = adjust_band_for_coherence_sorting(
band, spectrogram, rule=rule
)
weight_band_spectrogram = spectrogram.extract_band(adjusted_band)

Wjj84 = coherence_weights_jj84(weight_band_spectrogram)
apply_weights(X, Y, RR, Wjj84, segment=True, dropna=False)
if "simple_coherence" in segment_weights:
# Note that these weights might be better applied within the loop over channel as the weights
Expand Down Expand Up @@ -226,21 +234,19 @@ def process_transfer_functions(
# cumulative_weights np.ones(n_obs)
# for k,v in weights.items():
# cumulative_weights *= v

rule = "min3" # TODO: Put band-adjustment rule into processing config
spectrogram = Spectrogram(dataset=local_stft_obj)
adjusted_band = adjust_band_for_coherence_sorting(
band, spectrogram, rule=rule
)
band_spectrogram = spectrogram.extract_band(adjusted_band)
weight_band_spectrogram = spectrogram.extract_band(adjusted_band)

# Optionally add rx, ry to spectrogram here
channel_pairs = (
("ex", "hy"),
("ey", "hx"),
)
simple_coherences = estimate_simple_coherence(
band_spectrogram, channel_pairs=channel_pairs
weight_band_spectrogram, channel_pairs=channel_pairs
)
# TODO: Put cutoffs in the procesing config:
cutoffs = {}
Expand Down Expand Up @@ -268,7 +274,14 @@ def process_transfer_functions(
multiple_coherence_weights,
)

W = multiple_coherence_weights(band, local_stft_obj, remote_stft_obj)
rule = "min3" # TODO: Put band-adjustment rule into processing config
adjusted_band = adjust_band_for_coherence_sorting(
band, spectrogram, rule=rule
)
weight_band_spectrogram = spectrogram.extract_band(adjusted_band)
W = multiple_coherence_weights(
weight_band_spectrogram
) # band, local_stft_obj, remote_stft_obj)
apply_weights(X, Y, RR, W, segment=True, dropna=False)

# if there are channel weights apply them here
Expand Down
77 changes: 1 addition & 76 deletions aurora/time_series/frequency_band_helpers.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,7 @@
import numpy as np
# import numpy as np
from loguru import logger


def get_band_for_tf_estimate(band, dec_level_config, local_stft_obj, remote_stft_obj):
    """
    Build the X, Y, RR spectrograms for the harmonics inside one frequency band.

    Parameters
    ----------
    band : mt_metadata.transfer_functions.processing.aurora.FrequencyBands
        Band description handed to extract_band to select the sub-array of the
        STFT objects; its center_period is also used for the log message.
    dec_level_config : mt_metadata.transfer_functions.processing.aurora.decimation_level.DecimationLevel
        Supplies input_channels, output_channels and reference_channels for the
        TF estimation problem setup.
    local_stft_obj : xarray.core.dataset.Dataset or None
        Time series of Fourier coefficients for the station whose TF is to be
        estimated.
    remote_stft_obj : xarray.core.dataset.Dataset or None
        Time series of Fourier coefficients for the remote reference station.
        Only read when dec_level_config.reference_channels is non-empty.

    Returns
    -------
    X, Y, RR : xarray.core.dataset.Dataset or None
        The band-limited input channels, output channels and reference
        channels. RR is None when no reference channels are configured.
    """
    logger.info(
        f"Processing band {band.center_period:.6f}s ({1./band.center_period:.6f}Hz)"
    )

    # Local station: inputs and outputs come from the same band-limited dataset.
    local_band = extract_band(band, local_stft_obj)
    X = local_band[dec_level_config.input_channels]
    Y = local_band[dec_level_config.output_channels]
    check_time_axes_synched(X, Y)

    # Single-station processing: there is nothing to pull from the remote.
    if not dec_level_config.reference_channels:
        return X, Y, None

    remote_band = extract_band(band, remote_stft_obj)
    RR = remote_band[dec_level_config.reference_channels]
    check_time_axes_synched(Y, RR)
    return X, Y, RR


def extract_band(frequency_band, fft_obj, channels=[], epsilon=1e-7):
"""
Stand alone method that operates on an xr.DataArray, and is wrapped with Spectrogram
Expand Down Expand Up @@ -85,34 +42,6 @@ def extract_band(frequency_band, fft_obj, channels=[], epsilon=1e-7):
return band


def check_time_axes_synched(X, Y):
    """
    Utility function for checking that time axes agree.

    It is critical that X, Y, RR have the same time axes here, because NaN
    handling could fail if they do not.

    The comparison is made on the raw coordinate values rather than with
    ``X.time == Y.time``. xarray aligns labelled operands on their shared index
    (inner join by default) before a binary operation, so the labelled comparison
    only sees the timestamps the two axes have in common and cannot detect a
    mismatch.

    Parameters
    ----------
    X : xarray
        Object exposing a ``time`` coordinate.
    Y : xarray
        Object exposing a ``time`` coordinate.

    Raises
    ------
    ValueError
        If the two time axes differ in length or in any value. ValueError is a
        subclass of Exception, so callers catching Exception are unaffected.
    """
    import numpy as np

    x_times = np.asarray(X.time)
    y_times = np.asarray(Y.time)
    if np.array_equal(x_times, y_times):
        return

    logger.warning("WARNING - NAN Handling could fail if X,Y dont share time axes")
    raise ValueError(
        "Time axes are not synchronized: "
        f"X has {x_times.size} samples, Y has {y_times.size} samples"
    )


def adjust_band_for_coherence_sorting(frequency_band, spectrogram, rule="min3"):
"""
Expand Down Expand Up @@ -247,7 +176,3 @@ def adjust_band_for_coherence_sorting(frequency_band, spectrogram, rule="min3"):
# return get_band_for_tf_estimate(
# band, dec_level_config, local_stft_obj, remote_stft_obj
# )


def cross_spectra(X, Y):
    """Return the product of the complex conjugate of X with Y."""
    conjugated = X.conj()
    return conjugated * Y
28 changes: 28 additions & 0 deletions aurora/time_series/xarray_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,31 @@ def handle_nan(X, Y, RR, drop_dim=""):
RR = RR.rename(data_var_rm_label_mapper)

return X, Y, RR


def check_time_axes_synched(X, Y):
    """
    Utility function for checking that time axes agree.

    It is critical that X, Y, RR have the same time axes here, because NaN
    handling could fail if they do not.

    The comparison is made on the raw coordinate values rather than with
    ``X.time == Y.time``. xarray aligns labelled operands on their shared index
    (inner join by default) before a binary operation, so the labelled comparison
    only sees the timestamps the two axes have in common and cannot detect a
    mismatch.

    Parameters
    ----------
    X : xarray
        Object exposing a ``time`` coordinate.
    Y : xarray
        Object exposing a ``time`` coordinate.

    Raises
    ------
    ValueError
        If the two time axes differ in length or in any value. ValueError is a
        subclass of Exception, so callers catching Exception are unaffected.
    """
    import numpy as np

    x_times = np.asarray(X.time)
    y_times = np.asarray(Y.time)
    if np.array_equal(x_times, y_times):
        return

    logger.warning("WARNING - NAN Handling could fail if X,Y dont share time axes")
    raise ValueError(
        "Time axes are not synchronized: "
        f"X has {x_times.size} samples, Y has {y_times.size} samples"
    )

0 comments on commit 579ed8a

Please sign in to comment.