Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
- in preparation for the DOI release,
- updated branches of mth5, mt_metadata to main
- prettied up some logging messages
  • Loading branch information
kkappler committed Jan 11, 2024
1 parent 9fe5610 commit d8bc61d
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 12 deletions.
9 changes: 7 additions & 2 deletions aurora/pipelines/transfer_function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def process_transfer_functions(
transfer_function_obj,
segment_weights=None,
channel_weights=None,
# use_multiple_coherence_weights=False,
):
"""
This method based on TTFestBand.m
Expand Down Expand Up @@ -127,10 +128,12 @@ def process_transfer_functions(
-------
"""
# PUT COHERENCE SORTING HERE IF WIDE BAND?
estimator_class = get_estimator_class(dec_level_config.estimator.engine)
iter_control = set_up_iter_control(dec_level_config)
for band in transfer_function_obj.frequency_bands.bands():
iter_control = set_up_iter_control(dec_level_config)
# if use_multiple_coherence_weights:
# from aurora.transfer_function.weights.coherence_weights import compute_multiple_coherence_weights
# Wmc = compute_multiple_coherence_weights(band, local_stft_obj, remote_stft_obj)
X, Y, RR = get_band_for_tf_estimate(
band, dec_level_config, local_stft_obj, remote_stft_obj
)
Expand All @@ -143,6 +146,8 @@ def process_transfer_functions(
RR = RR.stack(observation=("frequency", "time"))

W = effective_degrees_of_freedom_weights(X, RR, edf_obj=None)
# if use_multiple_coherence_weights:
# W *= Wmc
W[W == 0] = np.nan # use this to drop values in the handle_nan
# apply weights
X *= W
Expand Down
10 changes: 4 additions & 6 deletions aurora/pipelines/transfer_function_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,11 @@ def make_processing_summary(self):
for group, df in grouper:
try:
try:
assert (
df.dec_level.diff()[1:] == 1
).all() # dec levels increment by 1
cond = (df.dec_level.diff()[1:] == 1).all()
assert cond # dec levels increment by 1
except AssertionError:
logger.info(
f"Skipping {group} because decimation levels are messy."
)
msg = f"Skipping {group} because decimation levels are messy."
logger.info(msg)

Check warning on line 323 in aurora/pipelines/transfer_function_kernel.py

View check run for this annotation

Codecov / codecov/patch

aurora/pipelines/transfer_function_kernel.py#L322-L323

Added lines #L322 - L323 were not covered by tests
continue
assert df.dec_factor.iloc[0] == 1
assert df.dec_level.iloc[0] == 0
Expand Down
4 changes: 3 additions & 1 deletion aurora/time_series/frequency_band_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def get_band_for_tf_estimate(band, dec_level_config, local_stft_obj, remote_stft
reference_channels and also the frequency axes are restricted to
being within the frequency band given as an input argument.
"""
logger.info(f"Processing band {band.center_period:.6f}s")
logger.info(
f"Processing band {band.center_period:.6f}s ({1./band.center_period:.6f}Hz)"
)
band_dataset = extract_band(band, local_stft_obj)
X = band_dataset[dec_level_config.input_channels]
Y = band_dataset[dec_level_config.output_channels]
Expand Down
7 changes: 4 additions & 3 deletions aurora/time_series/window_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
"""
import numpy as np
from numpy.lib.stride_tricks import as_strided
from numba import jit
import time

from loguru import logger
from numba import jit
from numpy.lib.stride_tricks import as_strided


# Window-to-timeseries relationshp
# Window-to-timeseries relationship
def available_number_of_windows_in_array(n_samples_array, n_samples_window, n_advance):
"""
Expand Down
10 changes: 10 additions & 0 deletions tests/synthetic/test_compare_aurora_vs_archived_emtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ def test_pipeline(merged=True):

tfk_dataset = KernelDataset()
tfk_dataset.from_run_summary(run_summary, "test2", "test1")
# Uncomment to sanity check the problem is linear
# scale_factors = {
# "ex": 20.0,
# "ey": 20.0,
# "hx": 20.0,
# "hy": 20.0,
# "hz": 20.0,
# }
# tfk_dataset.df["channel_scale_factors"].at[0] = scale_factors
# tfk_dataset.df["channel_scale_factors"].at[1] = scale_factors
run_test2r1(tfk_dataset)


Expand Down

0 comments on commit d8bc61d

Please sign in to comment.