Update transfer_function_kernel.py

kujaku11 committed Jan 17, 2025
1 parent 7e27599 commit bf85ce9
Showing 1 changed file with 52 additions and 23 deletions.
75 changes: 52 additions & 23 deletions aurora/pipelines/transfer_function_kernel.py
@@ -127,17 +127,15 @@ def update_dataset_df(self, i_dec_level):
             if not self.is_valid_dataset(row, i_dec_level):
                 continue
             if row.fc:
-                row_ssr_str = (
-                    f"survey: {row.survey}, station: {row.station}, run: {row.run}"
-                )
+                row_ssr_str = f"survey: {row.survey}, station: {row.station}, run: {row.run}"
                 msg = f"FC already exists for {row_ssr_str} -- skipping decimation"
                 logger.info(msg)
                 continue
             run_xrds = row["run_dataarray"].to_dataset("channel")
             decimation = self.config.decimations[i_dec_level].decimation
             decimated_xrds = prototype_decimate(decimation, run_xrds)
-            self.dataset_df["run_dataarray"].at[i] = decimated_xrds.to_array(
-                "channel"
+            self.dataset_df["run_dataarray"].at[i] = (
+                decimated_xrds.to_array("channel")
             )  # See Note 1 above
 
         logger.info(
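For readers unfamiliar with the xarray pattern in this hunk: each run is stored as a DataArray with a "channel" dimension, unpacked to a Dataset for decimation, then repacked. A minimal, self-contained sketch with synthetic data — the stride slice is only a stand-in for aurora's prototype_decimate (which applies proper filtering), not a reproduction of it:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for one run, shaped like dataset_df["run_dataarray"] entries.
run_xrda = xr.DataArray(
    np.random.default_rng(0).normal(size=(2, 64)),
    dims=("channel", "time"),
    coords={"channel": ["ex", "hy"], "time": np.arange(64)},
)

run_xrds = run_xrda.to_dataset("channel")  # one Dataset variable per channel
# Crude stride-based stand-in for prototype_decimate (no anti-alias filtering).
decimated_xrds = run_xrds.isel(time=slice(None, None, 4))
packed = decimated_xrds.to_array("channel")  # repack for storage in the dataframe
```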
@@ -159,7 +157,9 @@ def apply_clock_zero(self, dec_level_config):
             The modified DecimationLevel with clock-zero information set.
         """
         if dec_level_config.window.clock_zero_type == "data start":
-            dec_level_config.window.clock_zero = str(self.dataset_df.start.min())
+            dec_level_config.window.clock_zero = str(
+                self.dataset_df.start.min()
+            )
         return dec_level_config
 
     @property
@@ -169,6 +169,7 @@ def all_fcs_already_exist(self) -> bool:
         self.check_if_fcs_already_exist()
 
         # these should all be booleans now
+        print(self.kernel_dataset.df["fc"])
         assert not self.kernel_dataset.df["fc"].isna().any()
 
         return self.kernel_dataset.df.fc.all()
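The assert here guards a real pandas gotcha: reductions default to skipna=True, so a NaN in the "fc" column (a run that was never actually checked) would be silently ignored by .all(). A small illustration with hypothetical values:

```python
import numpy as np
import pandas as pd

fc = pd.Series([True, True, np.nan])  # one run never actually checked
print(fc.all())         # True -- the NaN is skipped, masking the missing check
print(fc.isna().any())  # True -- which is why the assert comes first
```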
@@ -243,11 +244,12 @@ def check_if_fcs_already_exist(self):
                 msg += "Skip time series processing is OK"
             else:
                 msg = f"Some, but not all fc_levels already exist = {self.dataset_df['fc']}"
+            logger.info(msg)
+            return True
         else:
             msg = "FC levels not present"
-            logger.info(msg)
-
-        return
+        logger.info(msg)
+        return False
 
     def show_processing_summary(
         self,
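Read as a whole, this hunk gives check_if_fcs_already_exist an explicit boolean contract instead of an implicit None return. A simplified paraphrase of the resulting flow (not the verbatim method; print() stands in for logger.info()):

```python
import pandas as pd

def check_if_fcs_already_exist(fc: pd.Series) -> bool:
    # fc: boolean flag per run, True where FC levels already exist.
    if fc.all():
        print("All fc_levels already exist -- skip time series processing is OK")
        return True
    if fc.any():
        print(f"Some, but not all fc_levels already exist = {fc}")
        return True
    print("FC levels not present")
    return False

print(check_if_fcs_already_exist(pd.Series([True, False])))   # True
print(check_if_fcs_already_exist(pd.Series([False, False])))  # False
```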
Expand Down Expand Up @@ -296,11 +298,15 @@ def make_processing_summary(self):
decimation_info = self.config.decimation_info()
for i_dec, dec_factor in decimation_info.items():
tmp[i_dec] = dec_factor
tmp = tmp.melt(id_vars=id_vars, value_name="dec_factor", var_name="dec_level")
tmp = tmp.melt(
id_vars=id_vars, value_name="dec_factor", var_name="dec_level"
)
sortby = ["survey", "station", "run", "start", "dec_level"]
tmp.sort_values(by=sortby, inplace=True)
tmp.reset_index(drop=True, inplace=True)
tmp.drop("sample_rate", axis=1, inplace=True) # not valid for decimated data
tmp.drop(
"sample_rate", axis=1, inplace=True
) # not valid for decimated data

# Add window info
group_by = [
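The melt call above reshapes the per-decimation-level columns into long format, one row per (run, dec_level). A toy example of that reshape with hypothetical values:

```python
import pandas as pd

tmp = pd.DataFrame(
    {
        "survey": ["s1"],
        "station": ["mt01"],
        "run": ["001"],
        "start": ["2020-01-01"],
        0: [1],  # decimation level 0, factor 1
        1: [4],  # decimation level 1, factor 4
    }
)
id_vars = ["survey", "station", "run", "start"]
long = tmp.melt(id_vars=id_vars, value_name="dec_factor", var_name="dec_level")
# long now has one row per (run, dec_level), with dec_level and dec_factor columns
print(long)
```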
@@ -317,7 +323,9 @@
                 cond = (df.dec_level.diff()[1:] == 1).all()
                 assert cond  # dec levels increment by 1
             except AssertionError:
-                msg = f"Skipping {group} because decimation levels are messy."
+                msg = (
+                    f"Skipping {group} because decimation levels are messy."
+                )
                 logger.info(msg)
                 continue
             assert df.dec_factor.iloc[0] == 1
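The asserted condition is a compact pandas idiom for "decimation levels increment by exactly one" (diff() makes the first element NaN, so it is sliced off before comparing). For instance, with hypothetical level series:

```python
import pandas as pd

dec_level = pd.Series([0, 1, 2, 3])
print((dec_level.diff()[1:] == 1).all())  # True: consecutive levels

gappy = pd.Series([0, 1, 3])
print((gappy.diff()[1:] == 1).all())      # False: a level is missing
```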
@@ -382,7 +390,8 @@ def validate_decimation_scheme_and_dataset_compatability(
                 for x in self.processing_config.decimations
             }
             min_stft_window_list = [
-                min_stft_window_info[x] for x in self.processing_summary.dec_level
+                min_stft_window_info[x]
+                for x in self.processing_summary.dec_level
             ]
             min_num_stft_windows = pd.Series(min_stft_window_list)
 
@@ -408,7 +417,9 @@ def validate_processing(self):
             self.config.drop_reference_channels()
             for decimation in self.config.decimations:
                 if decimation.estimator.engine == "RME_RR":
-                    logger.info("No RR station specified, switching RME_RR to RME")
+                    logger.info(
+                        "No RR station specified, switching RME_RR to RME"
+                    )
                     decimation.estimator.engine = "RME"
 
         # Make sure that a local station is defined
@@ -440,7 +451,9 @@ def valid_decimations(self):
         valid_levels = tmp.dec_level.unique()
 
         dec_levels = [x for x in self.config.decimations]
-        dec_levels = [x for x in dec_levels if x.decimation.level in valid_levels]
+        dec_levels = [
+            x for x in dec_levels if x.decimation.level in valid_levels
+        ]
         msg = f"After validation there are {len(dec_levels)} valid decimation levels"
         logger.info(msg)
         return dec_levels
@@ -458,7 +471,9 @@ def validate_save_fc_settings(self):
                 # if dec_level_config.save_fcs:
                 dec_level_config.save_fcs = False
         if self.config.stations.remote:
-            save_any_fcs = np.array([x.save_fcs for x in self.config.decimations]).any()
+            save_any_fcs = np.array(
+                [x.save_fcs for x in self.config.decimations]
+            ).any()
             if save_any_fcs:
                 msg = "\n Saving FCs for remote reference processing is not supported"
                 msg = f"{msg} \n - To save FCs, process as single station, then you can use the FCs for RR processing"
@@ -574,13 +589,21 @@ def make_decimation_dict_for_tf(tf_collection, processing_config):
 
         decimation_dict = {}
 
-        for i_dec, dec_level_cfg in enumerate(processing_config.decimations):
+        for i_dec, dec_level_cfg in enumerate(
+            processing_config.decimations
+        ):
             for i_band, band in enumerate(dec_level_cfg.bands):
                 period_key = f"{band.center_period:{PERIOD_FORMAT}}"
                 period_value = {}
-                period_value["level"] = i_dec + 1  # +1 to match EMTF standard
-                period_value["bands"] = tuple(band.harmonic_indices[np.r_[0, -1]])
-                period_value["sample_rate"] = dec_level_cfg.sample_rate_decimation
+                period_value["level"] = (
+                    i_dec + 1
+                )  # +1 to match EMTF standard
+                period_value["bands"] = tuple(
+                    band.harmonic_indices[np.r_[0, -1]]
+                )
+                period_value["sample_rate"] = (
+                    dec_level_cfg.sample_rate_decimation
+                )
                 try:
                     period_value["npts"] = tf_collection.tf_dict[
                         i_dec
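The loop above assembles a nested dict keyed by the formatted band center period. Under the assumption that PERIOD_FORMAT is a fixed format spec (its real definition lives elsewhere in the aurora source), one entry would look roughly like this, with illustrative values:

```python
PERIOD_FORMAT = ".5E"  # assumption -- stands in for aurora's actual format spec
center_period = 0.0078125

decimation_dict = {
    f"{center_period:{PERIOD_FORMAT}}": {
        "level": 1,            # i_dec + 1, to match the EMTF standard
        "bands": (24, 31),     # first and last harmonic index of the band
        "sample_rate": 256.0,  # dec_level_cfg.sample_rate_decimation
        "npts": 1024,          # pulled from tf_collection when available
    }
}
print(decimation_dict)
```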
@@ -624,7 +647,9 @@ def make_decimation_dict_for_tf(tf_collection, processing_config):
 
         # Set key as first el't of dict, not currently supporting mixed surveys in TF
         tf_cls.survey_metadata = self.dataset.local_survey_metadata
-        tf_cls.station_metadata.transfer_function.processing_type = self.processing_type
+        tf_cls.station_metadata.transfer_function.processing_type = (
+            self.processing_type
+        )
         # tf_cls.station_metadata.transfer_function.processing_config = (
         #     self.processing_config
         # )
@@ -655,7 +680,9 @@ def memory_check(self) -> None:
         num_samples = self.dataset_df.duration * self.dataset_df.sample_rate
         total_samples = num_samples.sum()
         total_bytes = total_samples * bytes_per_sample
-        logger.info(f"Total Bytes of Raw Data: {total_bytes / (1024 ** 3):.3f} GB")
+        logger.info(
+            f"Total Bytes of Raw Data: {total_bytes / (1024 ** 3):.3f} GB"
+        )
 
         ram_fraction = 1.0 * total_bytes / total_memory
         logger.info(f"Raw Data will use: {100 * ram_fraction:.3f} % of memory")
@@ -676,7 +703,9 @@
 
 
 @path_or_mth5_object
-def mth5_has_fcs(m, survey_id, station_id, run_id, remote, processing_config, **kwargs):
+def mth5_has_fcs(
+    m, survey_id, station_id, run_id, remote, processing_config, **kwargs
+):
     """
     Checks if all needed fc-levels for survey-station-run are present under processing_config