
Merge pull request #312 from simpeg/fix_issue_31
Fix issue 36
kkappler authored Jan 11, 2024
2 parents 3f09400 + d8bc61d commit 618f267
Showing 37 changed files with 9,193 additions and 4,337 deletions.
24 changes: 12 additions & 12 deletions .github/workflows/tests.yml
@@ -34,8 +34,8 @@ jobs:
python --version
conda install -c conda-forge pytest pytest-cov certifi">=2017.4.17" pandoc
pip install -r requirements-dev.txt
pip install git+https://github.com/kujaku11/mt_metadata.git
pip install git+https://github.com/kujaku11/mth5.git
pip install git+https://github.com/kujaku11/mt_metadata.git@main
pip install git+https://github.com/kujaku11/mth5.git@master
- name: Install Our Package
run: |
@@ -49,16 +49,16 @@ jobs:
python -m ipykernel install --user --name aurora-test
# Install any other dependencies you need
- name: Execute Jupyter Notebooks
run: |
jupyter nbconvert --to notebook --execute docs/examples/dataset_definition.ipynb
jupyter nbconvert --to notebook --execute docs/examples/make_cas04_single_station_h5.ipynb
jupyter nbconvert --to notebook --execute docs/examples/operate_aurora.ipynb
jupyter nbconvert --to notebook --execute tests/test_run_on_commit.ipynb
jupyter nbconvert --to notebook --execute tutorials/pole_zero_fitting/lemi_pole_zero_fitting_example.ipynb
jupyter nbconvert --to notebook --execute tutorials/processing_configuration.ipynb
jupyter nbconvert --to notebook --execute tutorials/synthetic_data_processing.ipynb
# Replace "notebook.ipynb" with your notebook's filename
# - name: Execute Jupyter Notebooks
# run: |
# jupyter nbconvert --to notebook --execute docs/examples/dataset_definition.ipynb
# jupyter nbconvert --to notebook --execute docs/examples/make_cas04_single_station_h5.ipynb
# jupyter nbconvert --to notebook --execute docs/examples/operate_aurora.ipynb
# jupyter nbconvert --to notebook --execute tests/test_run_on_commit.ipynb
# jupyter nbconvert --to notebook --execute tutorials/pole_zero_fitting/lemi_pole_zero_fitting_example.ipynb
# jupyter nbconvert --to notebook --execute tutorials/processing_configuration.ipynb
# jupyter nbconvert --to notebook --execute tutorials/synthetic_data_processing.ipynb
# # Replace "notebook.ipynb" with your notebook's filename

# - name: Commit changes (if any)
# run: |
9 changes: 9 additions & 0 deletions .gitignore
@@ -1,11 +1,20 @@
.idea
data/figures
data/parkfield/aurora_results/
data/parkfield/*h5
data/synthetic/aurora_results/
data/synthetic/config/
tests/io/from_matlab.zss
*ignore*
*fix_issue*
aurora/sandbox/config
aurora/sandbox/data
tests/cas04/aurora_results/*
tests/cas04/CAS04*xml
tests/cas04/*ipynb
tests/cas04/config/*
tests/cas04/data/*
tests/io/*zrr
tests/parkfield/*.png
tests/parkfield/aurora_results/*
tests/parkfield/config/*backup.json
77 changes: 29 additions & 48 deletions aurora/pipelines/fourier_coefficients.py
@@ -55,37 +55,25 @@
execute compute_transfer_function()
Questions:
1. Shouldn't there be an experiment column in the channel_summary dataframe for a v0.2.0 file?
GROUPBY_COLUMNS = ["survey", "station", "sample_rate"]
If I use ["experiment", "survey", "station", "sample_rate"] instead (for a v0.2.0 file), I encounter a KeyError.
2. How to assign default values to Decimation.time_period?
Usually we will want to convert the entire run, so these should be assigned
during processing when we know the run extents. Thus the
"""
# =============================================================================
# Imports
# =============================================================================

import mt_metadata.timeseries.time_period

from aurora.pipelines.time_series_helpers import calibrate_stft_obj
from aurora.pipelines.time_series_helpers import prototype_decimate
from aurora.pipelines.time_series_helpers import run_ts_to_stft_scipy
from mth5.mth5 import MTH5
import mt_metadata.timeseries.time_period
from loguru import logger
from mth5.utils.helpers import path_or_mth5_object
from mt_metadata.transfer_functions.processing.fourier_coefficients import (
Decimation as FCDecimation,
)
from loguru import logger



# =============================================================================
FILE_VERSION = "you need to set this, and ideally cycle over 0.1.0, 0.2.0"
DEFAULT_TIME = "1980-01-01T00:00:00+00:00"
GROUPBY_COLUMNS = ["survey", "station", "sample_rate"] # See Question 1
GROUPBY_COLUMNS = ["survey", "station", "sample_rate"]


def decimation_and_stft_config_creator(
@@ -95,7 +83,7 @@ def decimation_and_stft_config_creator(
Based on the number of samples in the run, we can compute the maximum number of valid decimation levels.
This would re-use code in processing summary ... or we could just decimate until we can't anymore?
You can provide soemthing like: decimation_info = {0: 1.0, 1: 4.0, 2: 4.0, 3: 4.0}
You can provide something like: decimation_info = {0: 1.0, 1: 4.0, 2: 4.0, 3: 4.0}
Note 1: This does not yet work through the assignment of which bands to keep. Refer to
mt_metadata.transfer_functions.processing.Processing.assign_bands() to see how this was done in the past
@@ -133,38 +121,42 @@ def decimation_and_stft_config_creator(
dd.sample_rate_decimation = current_sample_rate

if time_period:
if isinstance(mt_metadata.timeseries.time_period.TimePeriod, time_period):
if isinstance(time_period, mt_metadata.timeseries.time_period.TimePeriod):
dd.time_period = time_period
else:
logger.info(f"Not sure how to assign time_period with {time_period}")
raise NotImplementedError
msg = (
f"Not sure how to assign time_period with type {type(time_period)}"
)
logger.info(msg)
raise NotImplementedError(msg)

decimation_and_stft_config.append(dd)

return decimation_and_stft_config
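A hedged usage sketch for the config creator: the full signature is collapsed out of this hunk, so the initial_sample_rate, decimation_factors, and time_period keyword names below are assumptions, but the decimation_info-style dict comes straight from the docstring above.

from mt_metadata.timeseries.time_period import TimePeriod
from aurora.pipelines.fourier_coefficients import decimation_and_stft_config_creator

# One entry per decimation level: level 0 keeps the raw sample rate, each
# deeper level decimates by a further factor of 4.
decimation_info = {0: 1.0, 1: 4.0, 2: 4.0, 3: 4.0}

# Optionally restrict the FCs to a fixed window (see Question 2 above about
# defaulting to the run extents).
time_period = TimePeriod()
time_period.start = "1980-01-01T00:00:00+00:00"
time_period.end = "1980-01-02T00:00:00+00:00"

# Returns a list of FCDecimation objects, one per decimation level.
fc_decimations = decimation_and_stft_config_creator(
    initial_sample_rate=1.0,
    decimation_factors=decimation_info,
    time_period=time_period,
)
for fc_dec in fc_decimations:
    print(fc_dec.sample_rate_decimation)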


def add_fcs_to_mth5(mth5_path, decimation_and_stft_configs=None):
@path_or_mth5_object
def add_fcs_to_mth5(m, decimation_and_stft_configs=None):
"""
usssr_grouper: output of a groupby on unique {survey, station, sample_rate} tuples
Args:
mth5_path: str or pathlib.Path
Where the mth5 file is locatid
m: str or pathlib.Path, or MTH5 object
Where the mth5 file is located
decimation_and_stft_configs:
Returns:
"""
m = MTH5()
m.open_mth5(mth5_path)
channel_summary_df = m.channel_summary.to_dataframe()

usssr_grouper = channel_summary_df.groupby(GROUPBY_COLUMNS)
logger.debug(f"DETECTED {len(usssr_grouper)} unique station-sample_rate instances")
logger.debug(f"Detected {len(usssr_grouper)} unique station-sample_rate instances")

for (survey, station, sample_rate), usssr_group in usssr_grouper:
logger.info(f"\n\n\nsurvey: {survey}, station: {station}, sample_rate {sample_rate}")
logger.info(
f"\n\n\nsurvey: {survey}, station: {station}, sample_rate {sample_rate}"
)
station_obj = m.get_station(station, survey)
run_summary = station_obj.run_summary

@@ -210,7 +202,7 @@ def add_fcs_to_mth5(mth5_path, decimation_and_stft_configs=None):
if i_dec_level != 0:
# Apply decimation
run_xrds = prototype_decimate(decimation_stft_obj, run_xrds)
logger.info(f"type decimation_stft_obj = {type(decimation_stft_obj)}")
logger.debug(f"type decimation_stft_obj = {type(decimation_stft_obj)}")
if not decimation_stft_obj.is_valid_for_time_series_length(
run_xrds.time.shape[0]
):
Expand All @@ -222,17 +214,16 @@ def add_fcs_to_mth5(mth5_path, decimation_and_stft_configs=None):
stft_obj = run_ts_to_stft_scipy(decimation_stft_obj, run_xrds)
stft_obj = calibrate_stft_obj(stft_obj, run_obj)

# print("Pack FCs into h5 and update metadata")
# Pack FCs into h5 and update metadata
decimation_level = fc_group.add_decimation_level(f"{i_dec_level}")
decimation_level.from_xarray(stft_obj)
decimation_level.update_metadata()
fc_group.update_metadata()

m.close_mth5()
return
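With the path_or_mth5_object decorator applied, the function now accepts either a file path or an already-open MTH5 object. A minimal sketch, assuming the decorator opens and closes the file with a suitable mode when handed a path (the file name is hypothetical):

from mth5.mth5 import MTH5
from aurora.pipelines.fourier_coefficients import add_fcs_to_mth5, read_back_fcs

h5_path = "example_station.h5"  # hypothetical mth5 file containing runs

# Option 1: pass a path and let the decorator handle open/close.
add_fcs_to_mth5(h5_path)

# Option 2: pass an open MTH5 object and manage its lifecycle yourself.
m = MTH5()
m.open_mth5(h5_path, mode="a")
add_fcs_to_mth5(m)
m.close_mth5()

# Sanity check: scan the stored Fourier coefficients back out.
read_back_fcs(h5_path)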


def read_back_fcs(mth5_path):
@path_or_mth5_object
def read_back_fcs(m):
"""
This is mostly a helper function for tests. It was used as a sanity check while debugging the FC files, and
also is a good example for how to access the data at each level for each channel.
@@ -241,14 +232,12 @@ def read_back_fcs(mth5_path):
(for now -- storing all fcs by default)
Args:
mth5_path: str or pathlib.Path
m: pathlib.Path, str or an MTH5 object
The path to an h5 file that we will scan the fcs from
Returns:
"""
m = MTH5()
m.open_mth5(mth5_path)
channel_summary_df = m.channel_summary.to_dataframe()
logger.debug(channel_summary_df)
usssr_grouper = channel_summary_df.groupby(GROUPBY_COLUMNS)
@@ -262,18 +251,10 @@ def read_back_fcs(mth5_path):
dec_level_ids = fc_group.groups_list
for dec_level_id in dec_level_ids:
dec_level = fc_group.get_decimation_level(dec_level_id)
logger.info(
f"dec_level {dec_level_id}"
) # channel_summary {dec_level.channel_summary}")
xrds = dec_level.to_xarray(["hx", "hy"])
logger.info(f"Time axis shape {xrds.time.data.shape}")
logger.info(f"Freq axis shape {xrds.frequency.data.shape}")
return True

msg = f"dec_level {dec_level_id}"
msg = f"{msg} \n Time axis shape {xrds.time.data.shape}"
msg = f"{msg} \n Freq axis shape {xrds.frequency.data.shape}"
logger.debug(msg)

def main():
pass


if __name__ == "__main__":
main()
return True
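For manual spot checks, the same access pattern read_back_fcs loops over can be used directly. A sketch under the assumption that the station group exposes its FC groups through a fourier_coefficients_group attribute with a get_fc_group accessor; the file, station, survey, and group names are hypothetical:

from mth5.mth5 import MTH5

m = MTH5()
m.open_mth5("example_station.h5")  # hypothetical file
station_obj = m.get_station("CAS04", "CONUS_South")  # hypothetical station/survey ids
fc_group = station_obj.fourier_coefficients_group.get_fc_group("001")  # assumed accessor
for dec_level_id in fc_group.groups_list:
    dec_level = fc_group.get_decimation_level(dec_level_id)
    xrds = dec_level.to_xarray(["hx", "hy"])
    print(dec_level_id, xrds.time.data.shape, xrds.frequency.data.shape)
m.close_mth5()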
31 changes: 15 additions & 16 deletions aurora/pipelines/process_mth5.py
@@ -31,6 +31,7 @@
from aurora.pipelines.time_series_helpers import run_ts_to_stft
from aurora.pipelines.transfer_function_helpers import process_transfer_functions
from aurora.pipelines.transfer_function_kernel import TransferFunctionKernel
from aurora.sandbox.triage_metadata import triage_run_id
from aurora.transfer_function.transfer_function_collection import (
TransferFunctionCollection,
)
Expand All @@ -39,9 +40,8 @@

# =============================================================================

def make_stft_objects(
processing_config, i_dec_level, run_obj, run_xrds, units, station_id
):

def make_stft_objects(processing_config, i_dec_level, run_obj, run_xrds, units="MT"):
"""
Operates on a "per-run" basis
@@ -73,11 +73,11 @@ def make_stft_objects(
stft_config = processing_config.get_decimation_level(i_dec_level)
stft_obj = run_ts_to_stft(stft_config, run_xrds)
run_id = run_obj.metadata.id
if station_id == processing_config.stations.local.id:
if run_obj.station_metadata.id == processing_config.stations.local.id:
scale_factors = processing_config.stations.local.run_dict[
run_id
].channel_scale_factors
elif station_id == processing_config.stations.remote[0].id:
elif run_obj.station_metadata.id == processing_config.stations.remote[0].id:
scale_factors = (
processing_config.stations.remote[0].run_dict[run_id].channel_scale_factors
)
@@ -122,6 +122,7 @@ def process_tf_decimation_level(
Processing pipeline for a single decimation_level
TODO: Add a check that the processing config sample rates agree with the data
TODO: Add units to local_stft_obj, remote_stft_obj
sampling rates otherwise raise Exception
This method can be single station or remote based on the process cfg
Expand All @@ -146,9 +147,9 @@ def process_tf_decimation_level(
"""
frequency_bands = config.decimations[i_dec_level].frequency_bands_obj()
transfer_function_obj = TTFZ(i_dec_level, frequency_bands, processing_config=config)

dec_level_config = config.decimations[i_dec_level]
transfer_function_obj = process_transfer_functions(
config, i_dec_level, local_stft_obj, remote_stft_obj, transfer_function_obj
dec_level_config, local_stft_obj, remote_stft_obj, transfer_function_obj
)

return transfer_function_obj
@@ -314,7 +315,7 @@ def process_mth5(
tfk.show_processing_summary()
tfk.validate()
# See Note #1
if config.decimations[0].save_fcs:
if tfk.config.decimations[0].save_fcs:
mth5_mode = "a"
else:
mth5_mode = "r"
@@ -358,18 +359,16 @@ def process_mth5(
continue

run_xrds = row["run_dataarray"].to_dataset("channel")
run_obj = row.mth5_obj.from_reference(row.run_reference)

# Musgraves workaround for old MT data
try:
assert row.run_id == run_obj.metadata.id
except AssertionError:
logger.warning("WARNING Run ID in dataset_df does not match run_obj")
logger.warning("WARNING Forcing run metadata to match dataset_df")
run_obj.metadata.id = row.run_id
triage_run_id(row.run_id, run_obj)

stft_obj = make_stft_objects(
tfk.config, i_dec_level, run_obj, run_xrds, units, row.station_id
tfk.config,
i_dec_level,
run_obj,
run_xrds,
units,
)
# Pack FCs into h5
save_fourier_coefficients(
24 changes: 18 additions & 6 deletions aurora/pipelines/time_series_helpers.py
@@ -291,16 +291,29 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None
Time series of calibrated Fourier coefficients
"""
for channel_id in stft_obj.keys():
mth5_channel = run_obj.get_channel(channel_id)
channel_filter = mth5_channel.channel_response_filter
if not channel_filter.filters_list:

channel = run_obj.get_channel(channel_id)
channel_response = channel.channel_response
if not channel_response.filters_list:
msg = f"Channel {channel_id} with empty filters list detected"
logger.warning(msg)
if channel_id == "hy":
msg = "Channel hy has no filters, try using filters from hx"
logger.warning("Channel HY has no filters, try using filters from HX")
channel_filter = run_obj.get_channel("hx").channel_response_filter
calibration_response = channel_filter.complex_response(stft_obj.frequency.data)
channel_response = run_obj.get_channel("hx").channel_response

indices_to_flip = channel_response.get_indices_of_filters_to_remove(
include_decimation=False, include_delay=False
)
indices_to_flip = [
i for i in indices_to_flip if channel.metadata.filter.applied[i]
]
filters_to_remove = [channel_response.filters_list[i] for i in indices_to_flip]
if not filters_to_remove:
logger.warning("No filters to remove")
calibration_response = channel_response.complex_response(
stft_obj.frequency.data, filters_list=filters_to_remove
)
if channel_scale_factors:
try:
channel_scale_factor = channel_scale_factors[channel_id]
Expand All @@ -309,7 +322,6 @@ def calibrate_stft_obj(stft_obj, run_obj, units="MT", channel_scale_factors=None
calibration_response /= channel_scale_factor
if units == "SI":
logger.warning("Warning: SI Units are not robustly supported issue #36")

stft_obj[channel_id].data /= calibration_response
return stft_obj
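Conceptually, the calibration above divides each channel's spectrogram by the compound complex response of the filters that were actually applied to the data. A purely illustrative sketch, not the aurora/mth5 API (the function and attribute names here are invented):

import numpy as np

def apply_calibration(spectrogram, frequencies, filters):
    """Divide a complex spectrogram of shape (n_freq, n_windows) by the product
    of the complex responses of the filters whose 'applied' flag is True."""
    response = np.ones_like(frequencies, dtype=complex)
    for flt in filters:
        if getattr(flt, "applied", True):  # skip filters never applied to the data
            response *= flt.complex_response(frequencies)
    return spectrogram / response[:, np.newaxis]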
