Skip to content

Commit

Permalink
Merge pull request #277 from kujaku11/cursor_001
Browse files Browse the repository at this point in the history
Cursor 001
  • Loading branch information
kkappler authored Jan 13, 2025
2 parents 8e2c85c + 3b71ba6 commit dfe98a0
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 103 deletions.
180 changes: 169 additions & 11 deletions mth5/timeseries/spectre/spectrogram.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""
WORK IN PROGRESS (WIP): This module contains a class that represents a spectrogram,
i.e. A 2D time series of Fourier coefficients with axes time and frequency.
This module contains a class that represents a spectrogram,
i.e. a 2D time series of Fourier coefficients with axes time and frequency.
"""
# Standard library imports
from typing import List, Optional, Tuple, Union

from mt_metadata.transfer_functions.processing.aurora.band import Band
from typing import Optional, Union

# Third-party imports
import numpy as np
import pandas as pd
import xarray as xr

# Local imports
from mt_metadata.transfer_functions.processing.aurora.band import Band
from mt_metadata.transfer_functions.processing.aurora.frequency_bands import FrequencyBands
from mth5.timeseries.xarray_helpers import covariance_xr, initialize_xrda_2d


class Spectrogram(object):
"""
Expand All @@ -26,6 +32,7 @@ def __init__(
"""Constructor"""
self._dataset = dataset
self._frequency_increment = None
self._frequency_band = None

def _lowest_frequency(self): # -> float:
pass # return self.dataset.frequency.min
Expand Down Expand Up @@ -75,6 +82,22 @@ def time_axis(self):
"""returns the time axis of the underlying xarray"""
return self.dataset.time

@property
def frequency_axis(self):
    """Frequency coordinate of the wrapped xarray dataset."""
    underlying = self.dataset
    return underlying.frequency

@property
def frequency_band(self) -> Band:
    """
    Band spanning the spectrogram's frequency axis.

    The band is built on first access from the minimum and maximum of
    ``frequency_axis`` and cached in ``_frequency_band``; later accesses
    return the cached object.  The axis is treated as one continuous band
    (min to max), so gaps in the axis are not represented.
    """
    cached = self._frequency_band
    if cached is not None:
        return cached

    # .item() converts the 0-d xarray reductions to plain Python scalars.
    lowest = self.frequency_axis.min().item()
    highest = self.frequency_axis.max().item()
    self._frequency_band = Band(frequency_min=lowest, frequency_max=highest)
    return self._frequency_band

@property
def frequency_increment(self):
"""
Expand Down Expand Up @@ -109,8 +132,6 @@ def extract_band(self, frequency_band, channels=[]):
"""
Returns another instance of Spectrogram, with the frequency axis reduced to the input band.
TODO: Consider returning a copy of the data...
Parameters
----------
frequency_band
Expand All @@ -128,12 +149,149 @@ def extract_band(self, frequency_band, channels=[]):
channels=channels,
epsilon=self.frequency_increment / 2.0,
)
# Drop NaN values along the frequency dimension
# extracted_band_dataset = extracted_band_dataset.dropna(dim='frequency', how='any')
spectrogram = Spectrogram(dataset=extracted_band_dataset)
return spectrogram

# TODO: Add cross power method
# def cross_powers(self, ch1, ch2, band=None):
# pass
def cross_power_label(self, ch1: str, ch2: str, join_char: str = "_"):
    """
    Build the label used for the cross power of two channels.

    Parameters
    ----------
    ch1 : str
        Name of the first channel.
    ch2 : str
        Name of the second channel.
    join_char : str
        Separator placed between the two channel names.

    Returns
    -------
    str
        ``ch1`` and ``ch2`` joined by ``join_char``, e.g. ``"ex_hy"``.
    """
    return "{}{}{}".format(ch1, join_char, ch2)

def _validate_frequency_bands(
    self,
    frequency_bands: FrequencyBands,
    strict: bool = True,
):
    """
    Keep only the bands that are relevant to this spectrogram.

    Bands that fail the check are dropped from the result.
    NOTE(review): no warning is emitted here when a band is dropped.

    :param frequency_bands: A collection of bands
    :type frequency_bands: FrequencyBands
    :param strict: If True, a band must be contained in this spectrogram's
        ``frequency_band`` to be kept; if False, any band overlapping it is kept.
    :type strict: bool
    :return: A new FrequencyBands built from the lower/upper bounds of the kept bands.
    :rtype: FrequencyBands
    """
    kept = []
    for candidate in frequency_bands.bands():
        own_band = self.frequency_band
        if strict:
            is_relevant = own_band.contains(candidate)
        else:
            is_relevant = own_band.overlaps(candidate)
        if is_relevant:
            kept.append(candidate)

    bounds = {
        "lower_bound": [band.lower_bound for band in kept],
        "upper_bound": [band.upper_bound for band in kept],
    }
    return FrequencyBands(pd.DataFrame(data=bounds))

def cross_powers(self, frequency_bands, channel_pairs=None):
    """
    Compute band-averaged cross powers between channel pairs.

    For every validated band and every channel pair ``(ch1, ch2)`` the
    cross power is ``mean over frequency of ch1 * conj(ch2)`` within the band.

    NOTE: the frequency coordinate of the result is built from *all* bands
    in ``frequency_bands``, but only bands that pass
    ``_validate_frequency_bands`` are computed.  Entries for dropped bands
    keep their initial value of 0.
    TODO: Add explicit handling for bands not contained in the spectrogram.

    Parameters
    ----------
    frequency_bands : FrequencyBands
        The frequency bands to compute cross powers for.
    channel_pairs : iterable of tuples, optional
        Channel pairs to compute cross powers for.
        If None, all pairs of data variables (including each channel
        with itself) are used.

    Returns
    -------
    xr.DataArray
        Complex array with dimensions (frequency, time, variable).
        The ``variable`` coordinate holds the channel-pair labels
        (e.g. 'ex_hy'); ``frequency`` holds the band centers and
        ``time`` is the spectrogram's time axis.
    """
    from itertools import combinations_with_replacement

    valid_frequency_bands = self._validate_frequency_bands(frequency_bands)

    if channel_pairs is None:
        channels = list(self.dataset.data_vars.keys())
        channel_pairs = combinations_with_replacement(channels, 2)
    # Materialize once: the pairs are traversed for labelling and again for
    # every band, so a one-shot iterable would otherwise be exhausted early.
    channel_pairs = list(channel_pairs)

    # Labels are computed once and reused for every band.
    var_names = [self.cross_power_label(ch1, ch2) for ch1, ch2 in channel_pairs]
    labelled_pairs = list(zip(var_names, channel_pairs))

    xpower_array = initialize_xrda_2d(
        var_names,
        coords={
            "frequency": frequency_bands.band_centers(),
            "time": self.dataset.time.values,
        },
        dtype=complex,
    )

    for band in valid_frequency_bands.bands():
        band_data = self.extract_band(band).dataset

        for label, (ch1, ch2) in labelled_pairs:
            # Always computed as ch1 * conj(ch2), averaged across the band.
            xpower = (band_data[ch1] * band_data[ch2].conj()).mean(dim="frequency")
            xpower_array.loc[
                dict(
                    frequency=band.center_frequency,
                    variable=label,
                    time=slice(None),
                )
            ] = xpower

    return xpower_array

def covariance_matrix(
    self,
    band_data: Optional['Spectrogram'] = None,
    method: str = "numpy_cov"
) -> xr.DataArray:
    """
    Compute the full covariance matrix for spectrogram data.

    TODO: Add tests for this work-in-progress method.

    For complex-valued data, the result is a Hermitian matrix where:
    - diagonal elements are real-valued variances
    - off-diagonal element [i,j] is E[ch_i * conj(ch_j)]
    - off-diagonal element [j,i] is the complex conjugate of [i,j]

    Parameters
    ----------
    band_data : Spectrogram, optional
        If provided, compute covariance for this data.
        If None, use the full spectrogram (``self``).
    method : str
        Computation method. Currently only supports 'numpy_cov'.

    Returns
    -------
    xr.DataArray
        Hermitian covariance matrix with proper channel labeling.
        For channels i,j: matrix[i,j] = E[ch_i * conj(ch_j)]

    Raises
    ------
    ValueError
        If ``method`` is not a supported computation method.
    """
    # Reject unsupported methods before doing any work on the data.
    if method != "numpy_cov":
        raise ValueError(f"Unknown method: {method}")

    # Explicit None check: relying on truthiness (``band_data or self``)
    # would silently fall back to ``self`` for any falsy Spectrogram.
    data = self if band_data is None else band_data
    flat_data = data.flatten(chunk_by="time")

    # Convert the flattened Dataset to a DataArray for covariance_xr.
    stacked = flat_data.to_array(dim="variable")
    return covariance_xr(stacked)

def _get_all_channel_pairs(self) -> List[Tuple[str, str]]:
    """
    Get all unique channel pairs (upper triangle, diagonal excluded).

    Returns
    -------
    List[Tuple[str, str]]
        Every pair ``(ch_i, ch_j)`` with ``i < j``, in the order the
        channels appear in the dataset's data variables.  Empty when the
        dataset has fewer than two channels.
    """
    from itertools import combinations

    channels = list(self.dataset.data_vars.keys())
    # combinations() yields index-ordered pairs, matching a nested i < j loop.
    return list(combinations(channels, 2))

def flatten(self, chunk_by: Optional[str] = "time") -> xr.Dataset:
"""
Expand Down
81 changes: 59 additions & 22 deletions mth5/timeseries/xarray_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,45 +78,82 @@ def initialize_xrda_1d(
return xrda


def initialize_xrds_2d(
    variables: list,
    coords: dict,
    dtype: Optional[type] = complex,
    value: Optional[Union[complex, float, bool]] = 0,
) -> xr.Dataset:
    """
    Returns an xr.Dataset with the given variables and coordinates.

    Every variable gets its own array, filled with ``value`` and spanning
    all dimensions in ``coords`` (two dimensions in the intended use).

    Parameters
    ----------
    variables: list
        List of variable names to create in the dataset
    coords: dict
        Dictionary of coordinates for the dataset dimensions.
        The key order determines the dimension order.
    dtype: type, optional
        The datatype to initialize the arrays.
        Common cases are complex, float, and bool
    value: Union[complex, float, bool], optional
        The default value to assign the arrays

    Returns
    -------
    xrds: xr.Dataset
        An xarray Dataset with dimensions from coords
    """
    # Dimension names and sizes both come from coords.
    dims = list(coords.keys())
    shape = tuple(len(v) for v in coords.values())

    xrds = xr.Dataset(coords=coords)

    for var in variables:
        # A fresh array per variable so no two variables share a buffer.
        if value == 0:
            data = np.zeros(shape, dtype=dtype)
        else:
            data = value * np.ones(shape, dtype=dtype)

        xrds[var] = xr.DataArray(data, dims=dims, coords=coords)

    return xrds


def initialize_xrda_2d(
    variables: list,
    coords: dict,
    dtype: Optional[type] = complex,
    value: Optional[Union[complex, float, bool]] = 0,
) -> xr.DataArray:
    """Initialize an xarray DataArray with dimensions from coords plus 'variable'.

    Despite the ``2d`` in the name, the result has one more dimension than
    ``coords``: the trailing 'variable' dimension indexes ``variables``.

    Parameters
    ----------
    variables : list
        List of variable names for the additional dimension.
    coords : dict
        Dictionary of coordinates for the dataset dimensions.
    dtype : type, optional
        Data type for the array, by default complex.
    value : complex, float or bool, optional
        Value to initialize the array with, by default 0.

    Returns
    -------
    xr.DataArray
        A DataArray with the dimensions from coords (in key order)
        followed by 'variable'.
    """
    # Build the per-variable Dataset first, then stack it into one array.
    ds = initialize_xrds_2d(variables, coords, dtype, value)

    # to_array puts 'variable' first; move it behind the coords dimensions.
    dims = list(coords.keys())
    da = ds.to_array(dim='variable').transpose(*dims, 'variable')

    return da


def initialize_xrda_2d_cov(
Expand Down
Loading

0 comments on commit dfe98a0

Please sign in to comment.