Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for TKEO operation #3668

Open
jesusdpa1 opened this issue Feb 4, 2025 · 0 comments
Open

Adding support for TKEO operation #3668

jesusdpa1 opened this issue Feb 4, 2025 · 0 comments

Comments

@jesusdpa1
Copy link

Hi, I have written an implementation of the Teager-Kaiser energy operator (TKEO), but I am not sure where in the codebase it should be added so I can open a pull request.

from enum import Enum, auto

import numpy as np
from spikeinterface.core import get_chunk_with_margin
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.preprocessing.basepreprocessor import (
    BasePreprocessor,
    BasePreprocessorSegment,
)


class TKEOMethod(Enum):
    """Selectable Teager-Kaiser energy operator (TKEO) variants.

    Members carry plain integer values in declaration order (1, 2, 3), the same
    values ``enum.auto()`` assigns.
    """

    # Variant attributed to Li et al. (2007)
    LI_2007 = 1
    # Variant attributed to Deburchgrave et al. (2008)
    DEBURCHGRAVE_2008 = 2
    # Classic Teager-Kaiser operator
    ORIGINAL = 3


class TKEORecording(BasePreprocessor):
    """
    Lazily apply a Teager-Kaiser energy operator (TKEO) to every channel of a recording.

    Parameters
    ----------
    recording : RecordingExtractor
        The recording to transform.
    margin_ms : float, default: 5.0
        Extra context (in ms) fetched on each side of every chunk so the operator has
        neighbouring samples at chunk boundaries. At least 2 samples are always used.
    dtype : dtype-like or None, default: None
        Output dtype. ``None`` uses the recording's dtype; unsigned integer dtypes are
        mapped to the signed integer dtype of the same size (e.g. uint16 -> int16).
        NOTE: TKEO values are products of several samples, so a narrow integer dtype
        can be too small to hold them; a float dtype is usually the safer choice.
    tkeo_method : TKEOMethod or str, default: TKEOMethod.DEBURCHGRAVE_2008
        Operator variant. A string is matched case-insensitively against the member
        names of ``TKEOMethod`` (e.g. ``"original"``).
    add_reflect_padding : bool, default: False
        Forwarded to ``get_chunk_with_margin``.

    Raises
    ------
    ValueError
        If ``tkeo_method`` does not name a ``TKEOMethod`` member.
    """

    # The operator variants leave up to 2 samples at each end of a chunk without a
    # value (they need neighbours up to 2 samples away), so always fetch at least
    # this many margin samples on each side.
    _min_margin_samples = 2

    def __init__(
        self,
        recording,
        margin_ms=5.0,
        dtype=None,
        tkeo_method=TKEOMethod.DEBURCHGRAVE_2008,
        add_reflect_padding=False,
    ):
        dtype = self._fix_dtype(recording, dtype)
        # Validate eagerly so a bad method fails here rather than on the first read.
        tkeo_method = self._resolve_tkeo_method(tkeo_method)

        BasePreprocessor.__init__(self, recording, dtype=dtype)
        self.annotate(is_tkeo=True)

        # TKEO output is not an affine function of the raw traces, so the raw-data
        # offset no longer applies and is reset to 0.
        # NOTE(review): the channel gain is left untouched; confirm it is still meaningful.
        if "offset_to_uV" in self.get_property_keys():
            self.set_channel_offsets(0)

        margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0)
        margin = max(margin, self._min_margin_samples)
        for parent_segment in recording._recording_segments:
            self.add_recording_segment(
                TKEORecordingSegment(
                    parent_segment,
                    margin,
                    dtype,
                    tkeo_method=tkeo_method,
                    add_reflect_padding=add_reflect_padding,
                )
            )

        # Store the method by name: an Enum object is not JSON-serializable, and
        # __init__ accepts the name back when the extractor is re-created from kwargs.
        self._kwargs = dict(
            recording=recording,
            margin_ms=margin_ms,
            dtype=dtype.str,
            tkeo_method=tkeo_method.name,
            add_reflect_padding=add_reflect_padding,
        )

    @staticmethod
    def _fix_dtype(recording, dtype):
        """Return the output dtype: the recording's dtype if None, unsigned ints mapped to signed."""
        if dtype is None:
            dtype = recording.get_dtype()
        dtype = np.dtype(dtype)

        # if uint --> force int
        if dtype.kind == "u":
            dtype = np.dtype(dtype.str.replace("u", "i"))

        return dtype

    @staticmethod
    def _resolve_tkeo_method(tkeo_method):
        """Return ``tkeo_method`` as a ``TKEOMethod`` member, accepting member names as strings."""
        if isinstance(tkeo_method, TKEOMethod):
            return tkeo_method
        if isinstance(tkeo_method, str):
            try:
                return TKEOMethod[tkeo_method.upper()]
            except KeyError:
                pass
        valid = ", ".join(member.name for member in TKEOMethod)
        raise ValueError(f"Unknown TKEO method: {tkeo_method!r}. Valid methods: {valid}")


class TKEORecordingSegment(BasePreprocessorSegment):
    """
    Recording segment that applies a Teager-Kaiser energy operator (TKEO) on the fly.

    Traces are fetched with ``margin`` extra samples on each side (via
    ``get_chunk_with_margin``) so the operator has neighbouring samples at chunk
    boundaries; those margins are removed before the traces are returned.

    Parameters
    ----------
    parent_recording_segment : BaseRecordingSegment
        Segment providing the raw traces.
    margin : int
        Number of extra samples requested on each side of every chunk.
    dtype : np.dtype
        Output dtype of ``get_traces``.
    tkeo_method : TKEOMethod, default: TKEOMethod.DEBURCHGRAVE_2008
        Operator variant applied by ``apply_tkeo``.
    add_reflect_padding : bool, default: False
        Forwarded to ``get_chunk_with_margin``.
    """

    def __init__(
        self,
        parent_recording_segment,
        margin,
        dtype,
        tkeo_method=TKEOMethod.DEBURCHGRAVE_2008,
        add_reflect_padding=False,
    ):
        BasePreprocessorSegment.__init__(self, parent_recording_segment)
        self.margin = margin
        self.add_reflect_padding = add_reflect_padding
        self.dtype = dtype
        self.tkeo_method = tkeo_method

    def get_traces(self, start_frame, end_frame, channel_indices):
        """
        Return TKEO-transformed traces for the requested frame range.

        For integer output dtypes, values are rounded and clipped to the dtype's range
        rather than wrapping around. TKEO values are products of up to three samples
        and easily exceed, e.g., the int16 range.
        """
        traces_chunk, left_margin, right_margin = get_chunk_with_margin(
            self.parent_recording_segment,
            start_frame,
            end_frame,
            channel_indices,
            self.margin,
            add_reflect_padding=self.add_reflect_padding,
        )

        tkeo_traces = self.apply_tkeo(traces_chunk)

        # Drop the margins; an explicit stop index also covers right_margin == 0.
        stop = tkeo_traces.shape[0] - right_margin
        tkeo_traces = tkeo_traces[left_margin:stop, :]

        if np.issubdtype(self.dtype, np.integer):
            info = np.iinfo(self.dtype)
            tkeo_traces = np.clip(np.round(tkeo_traces), info.min, info.max)

        return tkeo_traces.astype(self.dtype)

    def apply_tkeo(self, traces):
        """
        Apply the selected TKEO variant along the time axis (axis 0).

        The result always has the same shape as ``traces``, so it stays aligned with
        the input samples. Samples for which the operator lacks neighbours are 0.
        Non-float input is converted to float64 first, so integer traces can neither
        overflow when multiplied nor wrap around (unsigned) when subtracted. Float
        input keeps its dtype.

        Parameters
        ----------
        traces : np.ndarray
            Input traces with time on axis 0.

        Returns
        -------
        np.ndarray
            Absolute value of the operator output, same shape as ``traces``.

        Raises
        ------
        ValueError
            If ``self.tkeo_method`` is not a known ``TKEOMethod``.
        """
        work_dtype = traces.dtype if traces.dtype.kind == "f" else np.float64
        x = np.asarray(traces, dtype=work_dtype)
        result = np.zeros_like(x)

        if self.tkeo_method == TKEOMethod.LI_2007:
            # Value for sample i uses samples i..i+2. The last 2 samples have no
            # right-hand neighbours and stay 0, keeping the output length equal to the input.
            # NOTE(review): this evaluates x[i] * (x[i+1]**2 - x[i] * x[i+2]), a third-order
            # expression; the classic discrete TKEO is usually x[n]**2 - x[n-1] * x[n+1]
            # (the ORIGINAL branch). Confirm the intended formula for this variant.
            result[:-2] = x[:-2] * (x[1:-1] ** 2 - x[:-2] * x[2:])

        elif self.tkeo_method == TKEOMethod.DEBURCHGRAVE_2008:
            # Same expression as LI_2007, written for samples 2..n-3. The first and last
            # 2 samples stay 0.
            # NOTE(review): confirm this matches the Deburchgrave et al. (2008) operator.
            result[2:-2] = x[2:-2] * (x[3:-1] ** 2 - x[2:-2] * x[4:])

        elif self.tkeo_method == TKEOMethod.ORIGINAL:
            # Classic operator x[n]**2 - x[n-1] * x[n+1]; the first and last samples stay 0.
            result[1:-1] = x[1:-1] ** 2 - x[:-2] * x[2:]

        else:
            raise ValueError(f"Unknown TKEO method: {self.tkeo_method}")

        return np.abs(result)


# Function-style entry point generated from TKEORecording (presumably it forwards its
# arguments to the class constructor, as other define_function_from_class aliases do).
tkeo_transform = define_function_from_class(source_class=TKEORecording, name="tkeo_transform")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant