Skip to content

Commit

Permalink
invert order of class definition for typehinting
Browse files Browse the repository at this point in the history
  • Loading branch information
kkappler committed Jan 19, 2025
1 parent f99adac commit 97e1278
Showing 1 changed file with 191 additions and 190 deletions.
381 changes: 191 additions & 190 deletions mth5/groups/fourier_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,196 +94,6 @@ def remove_fc_group(self, fc_name: str) -> None:
self._remove_group(fc_name)


class FCGroup(BaseGroup):
    """
    Holds a set of Fourier Coefficients based on a single set of configuration
    parameters. A processing run.
    .. note:: Must be calibrated FCs. Otherwise weird things will happen, can
    always rerun the FC estimation if the metadata changes.
    Metadata should include:
    - list of decimation levels
    - start time (earliest)
    - end time (latest)
    - method (fft, wavelet, ...)
    - list of channels (all inclusive)
    - list of acquisition runs (maybe)
    - starting sample rate
    """

    def __init__(self, group, decimation_level_metadata=None, **kwargs):
        # Thin wrapper: forwards the HDF5 group and its metadata to BaseGroup.

        super().__init__(
            group, group_metadata=decimation_level_metadata, **kwargs
        )

    @BaseGroup.metadata.getter
    def metadata(self) -> fc.Decimation:
        """Overwrite get metadata to include channel information in the runs"""

        self._metadata.channels = []
        # NOTE(review): `levels` is appended to without being reset (unlike
        # `channels` above), so repeated accesses of .metadata will
        # accumulate duplicate level entries — confirm intended.
        for dl in self.groups_list:
            dl_group = self.get_decimation_level(dl)
            self._metadata.levels.append(dl_group.metadata)
        self._metadata.hdf5_reference = self.hdf5_group.ref
        return self._metadata

    @property
    def decimation_level_summary(self):
        """
        Summary of the FCDecimation sub-groups in this FC group.
        :return: one row per decimation level with name, start/end times and
            HDF5 reference
        :rtype: pandas.DataFrame
        """

        ch_list = []
        for key, group in self.hdf5_group.items():
            try:
                ch_type = group.attrs["mth5_type"]
                if ch_type in ["FCDecimation"]:
                    ch_list.append(
                        (
                            group.attrs["decimation_level"],
                            # split("+") strips any UTC offset so numpy can
                            # parse the timestamp as naive datetime64
                            group.attrs["time_period.start"].split("+")[0],
                            group.attrs["time_period.end"].split("+")[0],
                            group.ref,
                        )
                    )
            except KeyError as error:
                self.logger.debug(f"Could not find key: {error}")
        # NOTE(review): first field holds the decimation level name but is
        # labelled "component" — presumably copied from a channel summary.
        ch_summary = np.array(
            ch_list,
            dtype=np.dtype(
                [
                    ("component", "U20"),
                    ("start", "datetime64[ns]"),
                    ("end", "datetime64[ns]"),
                    ("hdf5_reference", h5py.ref_dtype),
                ]
            ),
        )

        return pd.DataFrame(ch_summary)

    def add_decimation_level(
        self, decimation_level_name, decimation_level_metadata=None
    ):  # TODO: annotate -> FCDecimationGroup once class ordering permits
        """
        Add (or fetch) a decimation level group.
        :param decimation_level_name: name of the decimation level
        :type decimation_level_name: str
        :param decimation_level_metadata: metadata for the level, defaults to None
        :type decimation_level_metadata: TYPE, optional
        :return: the decimation-level group
        :rtype: FCDecimationGroup
        """

        return self._add_group(
            decimation_level_name,
            FCDecimationGroup,
            group_metadata=decimation_level_metadata,
            match="decimation_level",
        )

    def get_decimation_level(self, decimation_level_name: str) -> FCDecimationGroup:
        """
        Get a Decimation Level
        :param decimation_level_name: name of the decimation level
        :type decimation_level_name: str
        :return: the decimation-level group
        :rtype: FCDecimationGroup
        """
        return self._get_group(decimation_level_name, FCDecimationGroup)

    def remove_decimation_level(self, decimation_level_name: str) -> None:
        """
        Remove decimation level
        :param decimation_level_name: name of the decimation level
        :type decimation_level_name: str
        :return: None
        :rtype: None
        """

        self._remove_group(decimation_level_name)

    def update_metadata(self) -> None:
        """
        Update this group's time period from the decimation-level summary,
        then persist the metadata.
        :return: None
        :rtype: None
        """
        decimation_level_summary = self.decimation_level_summary.copy()
        if not decimation_level_summary.empty:
            self._metadata.time_period.start = (
                decimation_level_summary.start.min().isoformat()
            )
            self._metadata.time_period.end = (
                decimation_level_summary.end.max().isoformat()
            )
        self.write_metadata()

    def supports_aurora_processing_config(
        self, processing_config: 'aurora.config.metadata.processing.Processing', remote: bool
    ) -> bool:
        """
        An "all-or-nothing" check: Return True if every (valid) decimation needed to satisfy the processing_config
        is available in the FCGroup (self) otherwise return False (and we will build all FCs).
        Logic:
        1. Get a list of all fc groups in the FCGroup (self)
        2. Loop the processing_config decimations, checking if there is a corresponding, already built FCDecimation
        in the FCGroup.
        Parameters
        ----------
        processing_config: aurora.config.metadata.processing.Processing
        remote: bool
        Returns
        -------
        bool: True when every required decimation level has an archived FC.
        """
        # NOTE(review): .remove() below mutates this list; assumes groups_list
        # returns a fresh list each access — confirm.
        pre_existing_fc_decimation_ids_to_check = self.groups_list
        levels_present = np.full(processing_config.num_decimation_levels, False)

        for i, aurora_decimation_level in enumerate(processing_config.decimations):

            # Quit checking if dec_level wasn't there
            if i > 0:
                if not levels_present[i - 1]:
                    return False

            # iterate over existing decimations
            for fc_decimation_id in pre_existing_fc_decimation_ids_to_check:
                fc_dec_group = self.get_decimation_level(fc_decimation_id)
                fc_decimation = fc_dec_group.metadata
                levels_present[i] = aurora_decimation_level.is_consistent_with_archived_fc_parameters(
                    fc_decimation=fc_decimation,
                    remote=remote
                )
                if levels_present[i]:
                    pre_existing_fc_decimation_ids_to_check.remove(
                        fc_decimation_id
                    )  # no need to check this one again
                    break  # break inner for-loop over decimations

        return levels_present.all()


class FCDecimationGroup(BaseGroup):
"""
Holds a single decimation level
Expand Down Expand Up @@ -683,3 +493,194 @@ def add_feature(
**kwargs,
) -> None:
pass



class FCGroup(BaseGroup):
    """
    Holds a set of Fourier Coefficients based on a single set of configuration
    parameters. A processing run.
    .. note:: Must be calibrated FCs. Otherwise weird things will happen, can
    always rerun the FC estimation if the metadata changes.
    Metadata should include:
    - list of decimation levels
    - start time (earliest)
    - end time (latest)
    - method (fft, wavelet, ...)
    - list of channels (all inclusive)
    - list of acquisition runs (maybe)
    - starting sample rate
    """

    def __init__(self, group, decimation_level_metadata=None, **kwargs):
        # Thin wrapper: forwards the HDF5 group and its metadata to BaseGroup.
        super().__init__(
            group, group_metadata=decimation_level_metadata, **kwargs
        )

    @BaseGroup.metadata.getter
    def metadata(self) -> fc.Decimation:
        """Overwrite get metadata to include channel information in the runs.

        ``levels`` is cleared before being rebuilt so that repeated accesses
        of ``.metadata`` do not accumulate duplicate entries.
        """

        self._metadata.channels = []
        # BUG FIX: previously only appended, so every access of .metadata
        # duplicated all decimation-level entries.
        self._metadata.levels = []
        for dl in self.groups_list:
            dl_group = self.get_decimation_level(dl)
            self._metadata.levels.append(dl_group.metadata)
        self._metadata.hdf5_reference = self.hdf5_group.ref
        return self._metadata

    @property
    def decimation_level_summary(self):
        """
        Summary of the FCDecimation sub-groups in this FC group.
        :return: one row per decimation level with name, start/end times and
            HDF5 reference
        :rtype: pandas.DataFrame
        """

        ch_list = []
        for key, group in self.hdf5_group.items():
            try:
                ch_type = group.attrs["mth5_type"]
                if ch_type in ["FCDecimation"]:
                    ch_list.append(
                        (
                            group.attrs["decimation_level"],
                            # split("+") strips any UTC offset so numpy can
                            # parse the timestamp as naive datetime64
                            group.attrs["time_period.start"].split("+")[0],
                            group.attrs["time_period.end"].split("+")[0],
                            group.ref,
                        )
                    )
            except KeyError as error:
                self.logger.debug(f"Could not find key: {error}")
        # NOTE(review): first field holds the decimation level name but is
        # labelled "component"; kept for backward compatibility of callers
        # that select columns by name.
        ch_summary = np.array(
            ch_list,
            dtype=np.dtype(
                [
                    ("component", "U20"),
                    ("start", "datetime64[ns]"),
                    ("end", "datetime64[ns]"),
                    ("hdf5_reference", h5py.ref_dtype),
                ]
            ),
        )

        return pd.DataFrame(ch_summary)

    def add_decimation_level(
        self, decimation_level_name, decimation_level_metadata=None
    ) -> FCDecimationGroup:
        # Return annotation now resolves because FCDecimationGroup is defined
        # earlier in the module (the motivation for the class reordering).
        """
        Add (or fetch) a decimation level group.
        :param decimation_level_name: name of the decimation level
        :type decimation_level_name: str
        :param decimation_level_metadata: metadata for the level, defaults to None
        :type decimation_level_metadata: TYPE, optional
        :return: the decimation-level group
        :rtype: FCDecimationGroup
        """

        return self._add_group(
            decimation_level_name,
            FCDecimationGroup,
            group_metadata=decimation_level_metadata,
            match="decimation_level",
        )

    def get_decimation_level(self, decimation_level_name: str) -> FCDecimationGroup:
        """
        Get a Decimation Level
        :param decimation_level_name: name of the decimation level
        :type decimation_level_name: str
        :return: the decimation-level group
        :rtype: FCDecimationGroup
        """
        return self._get_group(decimation_level_name, FCDecimationGroup)

    def remove_decimation_level(self, decimation_level_name: str) -> None:
        """
        Remove decimation level
        :param decimation_level_name: name of the decimation level
        :type decimation_level_name: str
        :return: None
        :rtype: None
        """

        self._remove_group(decimation_level_name)

    def update_metadata(self) -> None:
        """
        Update this group's time period from the decimation-level summary,
        then persist the metadata.
        :return: None
        :rtype: None
        """
        decimation_level_summary = self.decimation_level_summary.copy()
        if not decimation_level_summary.empty:
            self._metadata.time_period.start = (
                decimation_level_summary.start.min().isoformat()
            )
            self._metadata.time_period.end = (
                decimation_level_summary.end.max().isoformat()
            )
        self.write_metadata()

    def supports_aurora_processing_config(
        self, processing_config: 'aurora.config.metadata.processing.Processing', remote: bool
    ) -> bool:
        """
        An "all-or-nothing" check: Return True if every (valid) decimation needed to satisfy the processing_config
        is available in the FCGroup (self) otherwise return False (and we will build all FCs).
        Logic:
        1. Get a list of all fc groups in the FCGroup (self)
        2. Loop the processing_config decimations, checking if there is a corresponding, already built FCDecimation
        in the FCGroup.
        Parameters
        ----------
        processing_config: aurora.config.metadata.processing.Processing
        remote: bool
        Returns
        -------
        bool: True when every required decimation level has an archived FC.
        """
        # Copy so the .remove() calls below cannot mutate a list that the
        # groups_list property might share with other callers.
        pre_existing_fc_decimation_ids_to_check = list(self.groups_list)
        levels_present = np.full(processing_config.num_decimation_levels, False)

        for i, aurora_decimation_level in enumerate(processing_config.decimations):

            # Levels must be satisfied consecutively from level 0; quit
            # checking as soon as the previous level was not found.
            if i > 0:
                if not levels_present[i - 1]:
                    return False

            # iterate over existing decimations
            for fc_decimation_id in pre_existing_fc_decimation_ids_to_check:
                fc_dec_group = self.get_decimation_level(fc_decimation_id)
                fc_decimation = fc_dec_group.metadata
                levels_present[i] = aurora_decimation_level.is_consistent_with_archived_fc_parameters(
                    fc_decimation=fc_decimation,
                    remote=remote
                )
                if levels_present[i]:
                    pre_existing_fc_decimation_ids_to_check.remove(
                        fc_decimation_id
                    )  # no need to check this one again
                    break  # break inner for-loop over decimations

        return levels_present.all()

0 comments on commit 97e1278

Please sign in to comment.