added prototype_decimation_4 to use xr.sps_signal accessors
kujaku11 committed May 29, 2024
1 parent e5c7ee8 commit 103c554
Showing 5 changed files with 160 additions and 54 deletions.
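
At a glance: the commit adds a resample_poly-based helper, prototype_decimate_4, to time_series_helpers.py and calls it from the decimation loop in transfer_function_kernel.py; much of the remaining diff is line-wrapping reformatting. Below is a condensed sketch of the new call path, simplified from the diffs that follow. The decimate_runs wrapper and its arguments are purely illustrative; only prototype_decimate_4 and the run-dataset handling come from the actual change.

from aurora.pipelines.time_series_helpers import prototype_decimate_4


def decimate_runs(dataset_df, config, i_dec_level):
    """Sketch of the loop body in TransferFunctionKernel.update_dataset_df (simplified)."""
    decimation = config.decimations[i_dec_level].decimation
    for i, row in dataset_df.iterrows():
        run_xrds = row["run_dataarray"].to_dataset("channel")  # run as an xr.Dataset
        decimated_xrds = prototype_decimate_4(decimation, run_xrds)  # was prototype_decimate
        dataset_df["run_dataarray"].at[i] = decimated_xrds.to_array("channel")
    return dataset_df
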
62 changes: 51 additions & 11 deletions aurora/pipelines/time_series_helpers.py
@@ -61,7 +61,9 @@ def apply_prewhitening(decimation_obj, run_xrds_input):
run_xrds = run_xrds_input.differentiate("time")

else:
msg = f"{decimation_obj.prewhitening_type} pre-whitening not implemented"
msg = (
f"{decimation_obj.prewhitening_type} pre-whitening not implemented"
)
logger.exception(msg)
raise NotImplementedError(msg)
return run_xrds
@@ -194,7 +196,9 @@ def truncate_to_clock_zero(decimation_obj, run_xrds):
pass # time series start is already clock zero
else:
windowing_scheme = window_scheme_from_decimation(decimation_obj)
number_of_steps = delta_t_seconds / windowing_scheme.duration_advance
number_of_steps = (
delta_t_seconds / windowing_scheme.duration_advance
)
n_partial_steps = number_of_steps - np.floor(number_of_steps)
n_clip = n_partial_steps * windowing_scheme.num_samples_advance
n_clip = int(np.round(n_clip))
@@ -222,8 +226,10 @@ def nan_to_mean(xrds):
for ch in xrds.keys():
null_values_present = xrds[ch].isnull().any()
if null_values_present:
nan_count = np.count_nonzero(np.isnan(xrds[ch]))
logger.info(
"Null values detected in xrds -- this is not expected and should be examined"
f"{nan_count} Null values detected in xrds channel {ch}. "
"Check if this is unexpected."
)
value = np.nan_to_num(np.nanmean(xrds[ch].data))
xrds[ch] = xrds[ch].fillna(value)
@@ -259,7 +265,9 @@ def run_ts_to_stft(decimation_obj, run_xrds_orig):
if not np.prod(windowed_obj.to_array().data.shape):
raise ValueError

windowed_obj = WindowedTimeSeries.detrend(data=windowed_obj, detrend_type="linear")
windowed_obj = WindowedTimeSeries.detrend(
data=windowed_obj, detrend_type="linear"
)
tapered_obj = windowed_obj * windowing_scheme.taper
stft_obj = windowing_scheme.apply_fft(
tapered_obj, detrend_type=decimation_obj.extra_pre_fft_detrend_type
@@ -269,7 +277,9 @@ def run_ts_to_stft(decimation_obj, run_xrds_orig):
return stft_obj


def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None):
def calibrate_stft_obj(
stft_obj, run_obj, units="MT", channel_scale_factors=None
):
"""
Parameters
@@ -291,15 +301,16 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None
Time series of calibrated Fourier coefficients
"""
for channel_id in stft_obj.keys():

channel = run_obj.get_channel(channel_id)
channel_response = channel.channel_response
if not channel_response.filters_list:
msg = f"Channel {channel_id} with empty filters list detected"
logger.warning(msg)
if channel_id == "hy":
msg = "Channel hy has no filters, try using filters from hx"
logger.warning("Channel HY has no filters, try using filters from HX")
logger.warning(
"Channel HY has no filters, try using filters from HX"
)
channel_response = run_obj.get_channel("hx").channel_response

indices_to_flip = channel_response.get_indices_of_filters_to_remove(
@@ -308,7 +319,9 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None
indices_to_flip = [
i for i in indices_to_flip if channel.metadata.filter.applied[i]
]
filters_to_remove = [channel_response.filters_list[i] for i in indices_to_flip]
filters_to_remove = [
channel_response.filters_list[i] for i in indices_to_flip
]
if not filters_to_remove:
logger.warning("No filters to remove")
calibration_response = channel_response.complex_response(
@@ -321,7 +334,9 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None
channel_scale_factor = 1.0
calibration_response /= channel_scale_factor
if units == "SI":
logger.warning("Warning: SI Units are not robustly supported issue #36")
logger.warning(
"Warning: SI Units are not robustly supported issue #36"
)
stft_obj[channel_id].data /= calibration_response
return stft_obj

@@ -353,7 +368,9 @@ def prototype_decimate(config, run_xrds):
num_channels = len(channel_labels)
new_data = np.full((num_observations, num_channels), np.nan)
for i_ch, ch_label in enumerate(channel_labels):
new_data[:, i_ch] = ssig.decimate(run_xrds[ch_label], int(config.factor))
new_data[:, i_ch] = ssig.decimate(
run_xrds[ch_label], int(config.factor)
)

xr_da = xr.DataArray(
new_data,
@@ -387,7 +404,9 @@ def prototype_decimate_2(config, run_xrds):
xr_ds: xr.Dataset
Decimated version of the input run_xrds
"""
new_xr_ds = run_xrds.coarsen(time=int(config.factor), boundary="trim").mean()
new_xr_ds = run_xrds.coarsen(
time=int(config.factor), boundary="trim"
).mean()
attr_dict = run_xrds.attrs
attr_dict["sample_rate"] = config.sample_rate
new_xr_ds.attrs = attr_dict
@@ -422,3 +441,24 @@ def prototype_decimate_3(config, run_xrds):
attr_dict["sample_rate"] = config.sample_rate
new_xr_ds.attrs = attr_dict
return new_xr_ds


def prototype_decimate_4(config, run_xrds):
"""
use scipy filters resample_poly
:param config: DESCRIPTION
:type config: TYPE
:param run_xrds: DESCRIPTION
:type run_xrds: TYPE
:return: DESCRIPTION
:rtype: TYPE
"""
new_ds = run_xrds.fillna(0)
new_ds = new_ds.sps_filters.resample_poly(
config.sample_rate, pad_type="mean"
)

new_ds.attrs["sample_rate"] = config.sample_rate
return new_ds
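
For readers who have not used the sps_filters accessor (per the commit message, one of the xr.sps_signal accessors, presumably registered on xarray objects elsewhere), here is a rough standalone sketch of the equivalent operation using scipy.signal.resample_poly directly. The rational up/down approximation, the time-coordinate rebuild, and the attribute handling are illustrative assumptions, not the accessor's actual implementation; the fillna(0) mirrors prototype_decimate_4 above, since NaNs would otherwise propagate through the filter.

from fractions import Fraction

import pandas as pd
import scipy.signal as ssig
import xarray as xr


def resample_poly_decimate(run_xrds: xr.Dataset, new_sample_rate: float) -> xr.Dataset:
    """Rough stand-in for run_xrds.sps_filters.resample_poly(new_sample_rate, pad_type="mean")."""
    old_sample_rate = float(run_xrds.attrs["sample_rate"])  # assumes attrs carry the original rate
    ratio = Fraction(new_sample_rate / old_sample_rate).limit_denominator(1000)
    up, down = ratio.numerator, ratio.denominator

    new_ds = run_xrds.fillna(0)  # polyphase filtering cannot tolerate NaN gaps
    data_vars = {}
    for ch in new_ds.data_vars:
        # padtype="mean" pads the edges with the channel mean, reducing filter edge transients
        data_vars[ch] = ("time", ssig.resample_poly(new_ds[ch].data, up, down, padtype="mean"))

    # Rebuild the time coordinate at the new sample rate from the original start time
    n_out = next(iter(data_vars.values()))[1].shape[0]
    new_time = pd.date_range(
        start=pd.Timestamp(run_xrds.time.values[0]),
        periods=n_out,
        freq=pd.Timedelta(seconds=1.0 / new_sample_rate),
    )
    out = xr.Dataset(data_vars, coords={"time": new_time}, attrs=dict(run_xrds.attrs))
    out.attrs["sample_rate"] = new_sample_rate
    return out

Compared with prototype_decimate's scipy.signal.decimate (a low-pass filter followed by integer downsampling), resample_poly runs a polyphase FIR filter and can also handle non-integer rate changes through the up/down ratio, which is presumably why the accessor takes a target sample rate rather than a decimation factor.
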
1 change: 0 additions & 1 deletion aurora/pipelines/transfer_function_helpers.py
@@ -261,7 +261,6 @@ def process_transfer_functions(
# if RR is not None:
# W = effective_degrees_of_freedom_weights(X_, RR_, edf_obj=None)
# X_, Y_, RR_ = apply_weights(X_, Y_, RR_, W, segment=False)

regression_estimator = estimator_class(
X=X_, Y=Y_, Z=RR_, iter_control=iter_control
)
105 changes: 77 additions & 28 deletions aurora/pipelines/transfer_function_kernel.py
@@ -3,7 +3,10 @@
import psutil

from aurora.pipelines.helpers import initialize_config
from aurora.pipelines.time_series_helpers import prototype_decimate
from aurora.pipelines.time_series_helpers import (
prototype_decimate,
prototype_decimate_4,
)
from mth5.utils.exceptions import MTH5Error
from mth5.utils.helpers import initialize_mth5
from mth5.utils.helpers import path_or_mth5_object
@@ -132,21 +135,23 @@ def update_dataset_df(self, i_dec_level):
continue
run_xrds = row["run_dataarray"].to_dataset("channel")
decimation = self.config.decimations[i_dec_level].decimation
decimated_xrds = prototype_decimate(decimation, run_xrds)
self.dataset_df["run_dataarray"].at[i] = decimated_xrds.to_array(
decimated_xrds = prototype_decimate_4(decimation, run_xrds)
self.dataset_df["run_dataarray"].at[
i
] = decimated_xrds.to_array(
"channel"
) # See Note 1 above

msg = (
f"Dataset Dataframe Updated for decimation level {i_dec_level} Successfully"
)
msg = f"Dataset Dataframe Updated for decimation level {i_dec_level} Successfully"
logger.info(msg)
return

def apply_clock_zero(self, dec_level_config):
"""get clock-zero from data if needed"""
if dec_level_config.window.clock_zero_type == "data start":
dec_level_config.window.clock_zero = str(self.dataset_df.start.min())
dec_level_config.window.clock_zero = str(
self.dataset_df.start.min()
)
return dec_level_config

@property
@@ -212,7 +217,12 @@ def check_if_fcs_already_exist(self):
remote = run_sub_df.remote.iloc[0]
mth5_path = run_sub_df.mth5_path.iloc[0]
fcs_present = mth5_has_fcs(
mth5_path, survey_id, station_id, run_id, remote, self.processing_config
mth5_path,
survey_id,
station_id,
run_id,
remote,
self.processing_config,
)
self.dataset_df.loc[dataset_df_indices, "fc"] = fcs_present

@@ -245,7 +255,9 @@ def show_processing_summary(
columns_to_show = self.processing_summary.columns
columns_to_show = [x for x in columns_to_show if x not in omit_columns]
logger.info("Processing Summary Dataframe:")
logger.info(f"\n{self.processing_summary[columns_to_show].to_string()}")
logger.info(
f"\n{self.processing_summary[columns_to_show].to_string()}"
)

def make_processing_summary(self):
"""
@@ -265,11 +277,15 @@ def make_processing_summary(self):
decimation_info = self.config.decimation_info()
for i_dec, dec_factor in decimation_info.items():
tmp[i_dec] = dec_factor
tmp = tmp.melt(id_vars=id_vars, value_name="dec_factor", var_name="dec_level")
tmp = tmp.melt(
id_vars=id_vars, value_name="dec_factor", var_name="dec_level"
)
sortby = ["survey", "station_id", "run_id", "start", "dec_level"]
tmp.sort_values(by=sortby, inplace=True)
tmp.reset_index(drop=True, inplace=True)
tmp.drop("sample_rate", axis=1, inplace=True) # not valid for decimated data
tmp.drop(
"sample_rate", axis=1, inplace=True
) # not valid for decimated data

# Add window info
group_by = [
@@ -305,7 +321,9 @@ def make_processing_summary(self):
num_samples_window=row.num_samples_window,
num_samples_overlap=row.num_samples_overlap,
)
num_windows[i] = ws.available_number_of_windows(row.num_samples)
num_windows[i] = ws.available_number_of_windows(
row.num_samples
)
df["num_stft_windows"] = num_windows
groups.append(df)

@@ -347,7 +365,8 @@ def validate_decimation_scheme_and_dataset_compatability(
for x in self.processing_config.decimations
}
min_stft_window_list = [
min_stft_window_info[x] for x in self.processing_summary.dec_level
min_stft_window_info[x]
for x in self.processing_summary.dec_level
]
min_num_stft_windows = pd.Series(min_stft_window_list)

@@ -371,7 +390,9 @@ def validate_processing(self):
self.config.drop_reference_channels()
for decimation in self.config.decimations:
if decimation.estimator.engine == "RME_RR":
logger.info("No RR station specified, switching RME_RR to RME")
logger.info(
"No RR station specified, switching RME_RR to RME"
)
decimation.estimator.engine = "RME"

# Make sure that a local station is defined
@@ -399,7 +420,9 @@ def valid_decimations(self):
valid_levels = tmp.dec_level.unique()

dec_levels = [x for x in self.config.decimations]
dec_levels = [x for x in dec_levels if x.decimation.level in valid_levels]
dec_levels = [
x for x in dec_levels if x.decimation.level in valid_levels
]
msg = f"After validation there are {len(dec_levels)} valid decimation levels"
logger.info(msg)
return dec_levels
@@ -412,7 +435,9 @@ def validate_save_fc_settings(self):
# if dec_level_config.save_fcs:
dec_level_config.save_fcs = False
if self.config.stations.remote:
save_any_fcs = np.array([x.save_fcs for x in self.config.decimations]).any()
save_any_fcs = np.array(
[x.save_fcs for x in self.config.decimations]
).any()
if save_any_fcs:
msg = "\n Saving FCs for remote reference processing is not supported"
msg = f"{msg} \n - To save FCs, process as single station, then you can use the FCs for RR processing"
@@ -521,17 +546,27 @@ def make_decimation_dict_for_tf(tf_collection, processing_config):
-------
"""
from mt_metadata.transfer_functions.io.zfiles.zmm import PERIOD_FORMAT
from mt_metadata.transfer_functions.io.zfiles.zmm import (
PERIOD_FORMAT,
)

decimation_dict = {}

for i_dec, dec_level_cfg in enumerate(processing_config.decimations):
for i_dec, dec_level_cfg in enumerate(
processing_config.decimations
):
for i_band, band in enumerate(dec_level_cfg.bands):
period_key = f"{band.center_period:{PERIOD_FORMAT}}"
period_value = {}
period_value["level"] = i_dec + 1 # +1 to match EMTF standard
period_value["bands"] = tuple(band.harmonic_indices[np.r_[0, -1]])
period_value["sample_rate"] = dec_level_cfg.sample_rate_decimation
period_value["level"] = (
i_dec + 1
) # +1 to match EMTF standard
period_value["bands"] = tuple(
band.harmonic_indices[np.r_[0, -1]]
)
period_value[
"sample_rate"
] = dec_level_cfg.sample_rate_decimation
try:
period_value["npts"] = tf_collection.tf_dict[
i_dec
@@ -561,20 +596,30 @@ def make_decimation_dict_for_tf(tf_collection, processing_config):
tf_cls.transfer_function = tmp

isp = merged_tf_dict["cov_ss_inv"]
renamer_dict = {"input_channel_1": "input", "input_channel_2": "output"}
renamer_dict = {
"input_channel_1": "input",
"input_channel_2": "output",
}
isp = isp.rename(renamer_dict)
tf_cls.inverse_signal_power = isp

res_cov = merged_tf_dict["cov_nn"]
renamer_dict = {"output_channel_1": "input", "output_channel_2": "output"}
renamer_dict = {
"output_channel_1": "input",
"output_channel_2": "output",
}
res_cov = res_cov.rename(renamer_dict)
tf_cls.residual_covariance = res_cov

# Set key as first element of dict; not currently supporting mixed surveys in TF
tf_cls.survey_metadata = self.dataset.local_survey_metadata
tf_cls.station_metadata.transfer_function.processing_type = self.processing_type
tf_cls.station_metadata.transfer_function.processing_config = self.processing_config

tf_cls.station_metadata.transfer_function.processing_type = (
self.processing_type
)
tf_cls.station_metadata.transfer_function.processing_config = (
self.processing_config
)

return tf_cls

def memory_warning(self):
@@ -599,7 +644,9 @@ def memory_warning(self):
num_samples = self.dataset_df.duration * self.dataset_df.sample_rate
total_samples = num_samples.sum()
total_bytes = total_samples * bytes_per_sample
logger.info(f"Total Bytes of Raw Data: {total_bytes / (1024 ** 3):.3f} GB")
logger.info(
f"Total Bytes of Raw Data: {total_bytes / (1024 ** 3):.3f} GB"
)

ram_fraction = 1.0 * total_bytes / total_memory
logger.info(f"Raw Data will use: {100 * ram_fraction:.3f} % of memory")
@@ -655,7 +702,9 @@ def mth5_has_fcs(m, survey_id, station_id, run_id, remote, processing_config):
return False

if len(fc_group.groups_list) < processing_config.num_decimation_levels:
msg = f"Not enough FC Groups found for {row_ssr_str} -- will build them"
msg = (
f"Not enough FC Groups found for {row_ssr_str} -- will build them"
)
return False

# Can check time periods here if desired, but unique (survey, station, run) should make this unneeded