From 62e4fa98b0f4fc65b6e05b1c86b191c31ff237d0 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 3 Dec 2024 14:37:00 +0100 Subject: [PATCH 001/118] add file_base module and class --- src/OSmOSE/file_base.py | 63 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/OSmOSE/file_base.py diff --git a/src/OSmOSE/file_base.py b/src/OSmOSE/file_base.py new file mode 100644 index 00000000..1dfc8a91 --- /dev/null +++ b/src/OSmOSE/file_base.py @@ -0,0 +1,63 @@ +"""FileBase: Base class for the File objects (e.g. AudioFile), which associated timestamps with file-written data.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from os import PathLike + + import numpy as np + +from pathlib import Path + +from pandas import Timestamp + +from OSmOSE.utils.timestamp_utils import strptime_from_text + + +class FileBase: + """Base class for the File objects (e.g. AudioFile), which associated timestamps with file-written data.""" + + def __init__(self, path: PathLike | str, begin: Timestamp | None = None, end: Timestamp | None = None, strptime_format: str | None = None) -> None: + """Initialize a File object with a path and a begin timestamp. + + The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided stroptime_format. + + Parameters + ---------- + path: PathLike | str + Full path to the file. + begin: pandas.Timestamp | None + Timestamp corresponding to the first data point in the file. + If it is not provided, strptime_format is mandatory. + If both begin and strptime_format are provided, begin will overrule the timestamp embedded in the filename. + end: pandas.Timestamp | None + (Optional) Timestamp after the last data point in the file. + strptime_format: str | None + The strptime format used in the text. + It should use valid strftime codes (https://strftime.org/). + Example: '%y%m%d_%H:%M:%S'. + + """ + self.path = Path(path) + + if begin is None and strptime_format is None: + raise ValueError("Either begin or strptime_format must be specified") + + self.begin = begin if begin is not None else strptime_from_text(text = self.path.name, datetime_template=strptime_format) + + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: + """Return the data that is between start and stop from the file. + + Parameters + ---------- + start: pandas.Timestamp + Timestamp corresponding to the first data point to read. + stop: pandas.Timestamp + Timestamp after the last data point to read. + + Returns + ------- + The data between start and stop. + + """ From cab77b1c7de4b7a3a3c9531177851d9e95cb4d98 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 3 Dec 2024 14:43:45 +0100 Subject: [PATCH 002/118] add AudioFile class --- src/OSmOSE/audio_file.py | 61 ++++++++++++++++++++++++++++++++++++++++ src/OSmOSE/file_base.py | 8 +++--- 2 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 src/OSmOSE/audio_file.py diff --git a/src/OSmOSE/audio_file.py b/src/OSmOSE/audio_file.py new file mode 100644 index 00000000..bcefb61b --- /dev/null +++ b/src/OSmOSE/audio_file.py @@ -0,0 +1,61 @@ +""".""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from os import PathLike + + import numpy as np +import soundfile as sf +from pandas import Timedelta, Timestamp + +from OSmOSE.file_base import FileBase + + +class AudioFile(FileBase): + """Audio file associated with timestamps.""" + + def __init__(self, path: PathLike | str, begin: Timestamp | None = None, strptime_format: str | None = None) -> None: + """Initialize an AudioFile object with a path and a begin timestamp. + + The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided strptime_format. + + Parameters + ---------- + path: PathLike | str + Full path to the file. + begin: pandas.Timestamp | None + Timestamp corresponding to the first data point in the file. + If it is not provided, strptime_format is mandatory. + If both begin and strptime_format are provided, begin will overrule the timestamp embedded in the filename. + strptime_format: str | None + The strptime format used in the text. + It should use valid strftime codes (https://strftime.org/). + Example: '%y%m%d_%H:%M:%S'. + + """ + super().__init__(path, begin, strptime_format) + self.metadata = sf.info(path) + self.end = self.begin + Timedelta(seconds = self.metadata.duration) + + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: + """Return the audio data between start and stop from the file. + + Parameters + ---------- + start: pandas.Timestamp + Timestamp corresponding to the first data point to read. + stop: pandas.Timestamp + Timestamp after the last data point to read. + + Returns + ------- + numpy.ndarray: + The audio data between start and stop. + + """ + sample_rate = self.metadata.samplerate + start_sample = round((start-self.begin).total_seconds() * sample_rate) + stop_sample = round((stop-self.begin).total_seconds() * sample_rate) + return sf.read(self.path, start=start_sample, stop=stop_sample)[0] diff --git a/src/OSmOSE/file_base.py b/src/OSmOSE/file_base.py index 1dfc8a91..75824b6c 100644 --- a/src/OSmOSE/file_base.py +++ b/src/OSmOSE/file_base.py @@ -7,11 +7,10 @@ from os import PathLike import numpy as np + from pandas import Timestamp from pathlib import Path -from pandas import Timestamp - from OSmOSE.utils.timestamp_utils import strptime_from_text @@ -19,9 +18,9 @@ class FileBase: """Base class for the File objects (e.g. AudioFile), which associated timestamps with file-written data.""" def __init__(self, path: PathLike | str, begin: Timestamp | None = None, end: Timestamp | None = None, strptime_format: str | None = None) -> None: - """Initialize a File object with a path and a begin timestamp. + """Initialize a File object with a path and timestamps. - The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided stroptime_format. + The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided strptime_format. Parameters ---------- @@ -45,6 +44,7 @@ def __init__(self, path: PathLike | str, begin: Timestamp | None = None, end: Ti raise ValueError("Either begin or strptime_format must be specified") self.begin = begin if begin is not None else strptime_from_text(text = self.path.name, datetime_template=strptime_format) + self.end = end if end is not None else self.begin def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the data that is between start and stop from the file. From f7778236851df7047cb64b80148d6d7154ff4e7b Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 3 Dec 2024 17:12:17 +0100 Subject: [PATCH 003/118] add tests for AudioFile class --- src/OSmOSE/audio_file.py | 16 ++++-- src/OSmOSE/config.py | 1 + src/OSmOSE/file_base.py | 17 ++++++- tests/conftest.py | 33 +++++++++++- tests/test_audio.py | 105 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 164 insertions(+), 8 deletions(-) create mode 100644 tests/test_audio.py diff --git a/src/OSmOSE/audio_file.py b/src/OSmOSE/audio_file.py index bcefb61b..d4b4dfd9 100644 --- a/src/OSmOSE/audio_file.py +++ b/src/OSmOSE/audio_file.py @@ -1,4 +1,5 @@ """.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -16,7 +17,12 @@ class AudioFile(FileBase): """Audio file associated with timestamps.""" - def __init__(self, path: PathLike | str, begin: Timestamp | None = None, strptime_format: str | None = None) -> None: + def __init__( + self, + path: PathLike | str, + begin: Timestamp | None = None, + strptime_format: str | None = None, + ) -> None: """Initialize an AudioFile object with a path and a begin timestamp. The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided strptime_format. @@ -35,9 +41,9 @@ def __init__(self, path: PathLike | str, begin: Timestamp | None = None, strptim Example: '%y%m%d_%H:%M:%S'. """ - super().__init__(path, begin, strptime_format) + super().__init__(path=path, begin=begin, strptime_format=strptime_format) self.metadata = sf.info(path) - self.end = self.begin + Timedelta(seconds = self.metadata.duration) + self.end = self.begin + Timedelta(seconds=self.metadata.duration) def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the audio data between start and stop from the file. @@ -56,6 +62,6 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """ sample_rate = self.metadata.samplerate - start_sample = round((start-self.begin).total_seconds() * sample_rate) - stop_sample = round((stop-self.begin).total_seconds() * sample_rate) + start_sample = round((start - self.begin).total_seconds() * sample_rate) + stop_sample = round((stop - self.begin).total_seconds() * sample_rate) return sf.read(self.path, start=start_sample, stop=stop_sample)[0] diff --git a/src/OSmOSE/config.py b/src/OSmOSE/config.py index a5a739a7..431c3a59 100755 --- a/src/OSmOSE/config.py +++ b/src/OSmOSE/config.py @@ -34,6 +34,7 @@ OSMOSE_PATH = namedtuple("path_list", __global_path_dict.keys())(**__global_path_dict) TIMESTAMP_FORMAT_AUDIO_FILE = "%Y-%m-%dT%H:%M:%S.%f%z" +TIMESTAMP_FORMAT_TEST_FILES = "%y%m%d%H%M%S" FPDEFAULT = 0o664 # Default file permissions DPDEFAULT = stat.S_ISGID | 0o775 # Default directory permissions diff --git a/src/OSmOSE/file_base.py b/src/OSmOSE/file_base.py index 75824b6c..b97cbf98 100644 --- a/src/OSmOSE/file_base.py +++ b/src/OSmOSE/file_base.py @@ -1,4 +1,5 @@ """FileBase: Base class for the File objects (e.g. AudioFile), which associated timestamps with file-written data.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -17,7 +18,13 @@ class FileBase: """Base class for the File objects (e.g. AudioFile), which associated timestamps with file-written data.""" - def __init__(self, path: PathLike | str, begin: Timestamp | None = None, end: Timestamp | None = None, strptime_format: str | None = None) -> None: + def __init__( + self, + path: PathLike | str, + begin: Timestamp | None = None, + end: Timestamp | None = None, + strptime_format: str | None = None, + ) -> None: """Initialize a File object with a path and timestamps. The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided strptime_format. @@ -43,7 +50,13 @@ def __init__(self, path: PathLike | str, begin: Timestamp | None = None, end: Ti if begin is None and strptime_format is None: raise ValueError("Either begin or strptime_format must be specified") - self.begin = begin if begin is not None else strptime_from_text(text = self.path.name, datetime_template=strptime_format) + self.begin = ( + begin + if begin is not None + else strptime_from_text( + text=self.path.name, datetime_template=strptime_format + ) + ) self.end = end if end is not None else self.begin def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: diff --git a/tests/conftest.py b/tests/conftest.py index 53c15a6d..9b26ba97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import shutil @@ -6,11 +8,40 @@ from unittest.mock import MagicMock import numpy as np +import pandas as pd import pytest import soundfile as sf from scipy.signal import chirp -from OSmOSE.config import OSMOSE_PATH +from OSmOSE.config import OSMOSE_PATH, TIMESTAMP_FORMAT_TEST_FILES + + +@pytest.fixture +def audio_files( + tmp_path: Path, request: pytest.fixtures.Subrequest, +) -> tuple[list[Path], pytest.fixtures.Subrequest]: + v_begin = 0.0 + v_end = 1.0 + nb_files = request.param.get("nb_files", 1) + sample_rate = request.param.get("sample_rate", 48_000) + duration = request.param.get("duration", 1.) + date_begin = request.param.get("date_begin", pd.Timestamp("2000-01-01 00:00:00")) + inter_file_duration = request.param.get("inter_file_duration", 0) + n_samples = int(round(duration * sample_rate)) + data = np.linspace(v_begin, v_end, n_samples) + files = [] + for begin_time in pd.date_range( + date_begin, + periods=nb_files, + freq=pd.Timedelta(seconds=duration + inter_file_duration), + ): + time_str = begin_time.strftime(format=TIMESTAMP_FORMAT_TEST_FILES) + file = tmp_path / f"audio_{time_str}.wav" + files.append(file) + sf.write( + file=file, data=data, samplerate=sample_rate, subtype="DOUBLE", + ) + return files, request @pytest.fixture(autouse=True) diff --git a/tests/test_audio.py b/tests/test_audio.py new file mode 100644 index 00000000..fa3e0b4a --- /dev/null +++ b/tests/test_audio.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from OSmOSE.audio_file import AudioFile +from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES + + +@pytest.mark.parametrize( + "audio_files", + [ + pytest.param( + { + "duration": .05, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + id="basic_audio_file", + ), + pytest.param( + { + "duration": .06, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + id="different_duration", + ), + pytest.param( + { + "duration": .05, + "sample_rate": 44_100, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + id="different_sample_rate", + ), + ], + indirect=True, +) +def test_audio_file_timestamps( + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], +) -> None: + files, request = audio_files + duration = request.param["duration"] + date_begin = request.param["date_begin"] + + for file in files: + audio_file = AudioFile(file, strptime_format=TIMESTAMP_FORMAT_TEST_FILES) + + assert audio_file.begin == date_begin + assert audio_file.end == date_begin + pd.Timedelta(seconds=duration) + +@pytest.mark.parametrize( + ("audio_files", "start", "stop", "expected"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp("2024-01-01 12:00:00"), + pd.Timestamp("2024-01-01 12:00:01"), + np.linspace(0.,1.,48_000), + id="read_whole_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp("2024-01-01 12:00:00"), + pd.Timestamp(year = 2024, month=1, day=1, hour=12, minute=0, second=0, microsecond=100_000), + np.linspace(0.,1.,48_000)[:4_800], + id="read_begin_only", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp(year = 2024, month=1, day=1, hour=12, minute=0, second=0, microsecond=500_000), + pd.Timestamp(year = 2024, month=1, day=1, hour=12, minute=0, second=0, microsecond=600_000), + np.linspace(0.,1.,48_000)[24_000:28_800], + id="read_in_the_middle_of_the_file", + ), + ], + indirect=["audio_files"], +) +def test_audio_file_read( + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], start: pd.Timestamp, stop: pd.Timestamp, expected: np.ndarray) -> None: + files, request = audio_files + file = AudioFile(files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) + assert np.array_equal(file.read(start, stop), expected) From 2b60db83e6a43d3534c266ab0278c2bd40c2ecef Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 3 Dec 2024 17:38:03 +0100 Subject: [PATCH 004/118] add test read_end_file --- tests/conftest.py | 10 +++++-- tests/test_audio.py | 69 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 66 insertions(+), 13 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9b26ba97..680f4ff4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,13 +18,14 @@ @pytest.fixture def audio_files( - tmp_path: Path, request: pytest.fixtures.Subrequest, + tmp_path: Path, + request: pytest.fixtures.Subrequest, ) -> tuple[list[Path], pytest.fixtures.Subrequest]: v_begin = 0.0 v_end = 1.0 nb_files = request.param.get("nb_files", 1) sample_rate = request.param.get("sample_rate", 48_000) - duration = request.param.get("duration", 1.) + duration = request.param.get("duration", 1.0) date_begin = request.param.get("date_begin", pd.Timestamp("2000-01-01 00:00:00")) inter_file_duration = request.param.get("inter_file_duration", 0) n_samples = int(round(duration * sample_rate)) @@ -39,7 +40,10 @@ def audio_files( file = tmp_path / f"audio_{time_str}.wav" files.append(file) sf.write( - file=file, data=data, samplerate=sample_rate, subtype="DOUBLE", + file=file, + data=data, + samplerate=sample_rate, + subtype="DOUBLE", ) return files, request diff --git a/tests/test_audio.py b/tests/test_audio.py index fa3e0b4a..828c9f2f 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -15,7 +15,7 @@ [ pytest.param( { - "duration": .05, + "duration": 0.05, "sample_rate": 48_000, "nb_files": 1, "date_begin": pd.Timestamp("2024-01-01 12:00:00"), @@ -24,7 +24,7 @@ ), pytest.param( { - "duration": .06, + "duration": 0.06, "sample_rate": 48_000, "nb_files": 1, "date_begin": pd.Timestamp("2024-01-01 12:00:00"), @@ -33,7 +33,7 @@ ), pytest.param( { - "duration": .05, + "duration": 0.05, "sample_rate": 44_100, "nb_files": 1, "date_begin": pd.Timestamp("2024-01-01 12:00:00"), @@ -56,6 +56,7 @@ def test_audio_file_timestamps( assert audio_file.begin == date_begin assert audio_file.end == date_begin + pd.Timedelta(seconds=duration) + @pytest.mark.parametrize( ("audio_files", "start", "stop", "expected"), [ @@ -68,7 +69,7 @@ def test_audio_file_timestamps( }, pd.Timestamp("2024-01-01 12:00:00"), pd.Timestamp("2024-01-01 12:00:01"), - np.linspace(0.,1.,48_000), + np.linspace(0.0, 1.0, 48_000), id="read_whole_file", ), pytest.param( @@ -79,8 +80,16 @@ def test_audio_file_timestamps( "date_begin": pd.Timestamp("2024-01-01 12:00:00"), }, pd.Timestamp("2024-01-01 12:00:00"), - pd.Timestamp(year = 2024, month=1, day=1, hour=12, minute=0, second=0, microsecond=100_000), - np.linspace(0.,1.,48_000)[:4_800], + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=100_000, + ), + np.linspace(0.0, 1.0, 48_000)[:4_800], id="read_begin_only", ), pytest.param( @@ -90,16 +99,56 @@ def test_audio_file_timestamps( "nb_files": 1, "date_begin": pd.Timestamp("2024-01-01 12:00:00"), }, - pd.Timestamp(year = 2024, month=1, day=1, hour=12, minute=0, second=0, microsecond=500_000), - pd.Timestamp(year = 2024, month=1, day=1, hour=12, minute=0, second=0, microsecond=600_000), - np.linspace(0.,1.,48_000)[24_000:28_800], + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=500_000, + ), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=600_000, + ), + np.linspace(0.0, 1.0, 48_000)[24_000:28_800], id="read_in_the_middle_of_the_file", ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=900_000, + ), + pd.Timestamp("2024-01-01 12:00:01"), + np.linspace(0.0, 1.0, 48_000)[43_200:], + id="read_end_of_file", + ), ], indirect=["audio_files"], ) def test_audio_file_read( - audio_files: tuple[list[Path], pytest.fixtures.Subrequest], start: pd.Timestamp, stop: pd.Timestamp, expected: np.ndarray) -> None: + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + start: pd.Timestamp, + stop: pd.Timestamp, + expected: np.ndarray, +) -> None: files, request = audio_files file = AudioFile(files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) assert np.array_equal(file.read(start, stop), expected) From b39326ed8aec646275b17361e9139cc8c82ad13b Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 09:37:43 +0100 Subject: [PATCH 005/118] move data modules to new package --- src/OSmOSE/data/__init__.py | 0 src/OSmOSE/{ => data}/audio_file.py | 2 +- src/OSmOSE/{ => data}/file_base.py | 0 tests/test_audio.py | 7 +++++-- 4 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 src/OSmOSE/data/__init__.py rename src/OSmOSE/{ => data}/audio_file.py (98%) rename src/OSmOSE/{ => data}/file_base.py (100%) diff --git a/src/OSmOSE/data/__init__.py b/src/OSmOSE/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/OSmOSE/audio_file.py b/src/OSmOSE/data/audio_file.py similarity index 98% rename from src/OSmOSE/audio_file.py rename to src/OSmOSE/data/audio_file.py index d4b4dfd9..93880623 100644 --- a/src/OSmOSE/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -11,7 +11,7 @@ import soundfile as sf from pandas import Timedelta, Timestamp -from OSmOSE.file_base import FileBase +from OSmOSE.data.file_base import FileBase class AudioFile(FileBase): diff --git a/src/OSmOSE/file_base.py b/src/OSmOSE/data/file_base.py similarity index 100% rename from src/OSmOSE/file_base.py rename to src/OSmOSE/data/file_base.py diff --git a/tests/test_audio.py b/tests/test_audio.py index 828c9f2f..68409ea4 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -1,13 +1,16 @@ from __future__ import annotations -from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import pandas as pd import pytest -from OSmOSE.audio_file import AudioFile from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES +from OSmOSE.data.audio_file import AudioFile + +if TYPE_CHECKING: + from pathlib import Path @pytest.mark.parametrize( From 3d32b3fda9c4779395188e6e84fe94c261b41028 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 09:59:01 +0100 Subject: [PATCH 006/118] add item base and audio classes --- src/OSmOSE/data/audio_file.py | 2 +- src/OSmOSE/data/audio_item.py | 33 ++++++++++++++++++++++++++ src/OSmOSE/data/file_base.py | 14 ++++++++--- src/OSmOSE/data/item_base.py | 44 +++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+), 4 deletions(-) create mode 100644 src/OSmOSE/data/audio_item.py create mode 100644 src/OSmOSE/data/item_base.py diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index 93880623..0162ddcc 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -1,4 +1,4 @@ -""".""" +"""Audio file associated with timestamps.""" from __future__ import annotations diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py new file mode 100644 index 00000000..fe961300 --- /dev/null +++ b/src/OSmOSE/data/audio_item.py @@ -0,0 +1,33 @@ +"""AudioItem corresponding to a portion of an AudioFile object.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from OSmOSE.data.item_base import ItemBase + +if TYPE_CHECKING: + from pandas import Timestamp + + from OSmOSE.data.audio_file import AudioFile + + +class AudioItem(ItemBase): + """AudioItem corresponding to a portion of an AudioFile object.""" + + def __init__(self, file: AudioFile, begin: Timestamp | None = None, end: Timestamp | None = None) -> None: + """Initialize an AudioItem from an AudioItem and begin/end timestamps. + + Parameters + ---------- + file: OSmOSE.data.audio_file.AudioFile + The AudioFile in which this Item belongs. + begin: pandas.Timestamp (optional) + The timestamp at which this item begins. + It is defaulted to the AudioFile begin. + end: pandas.Timestamp (optional) + The timestamp at which this item ends. + It is defaulted to the AudioFile end. + + """ + super().__init__(file, begin, end) diff --git a/src/OSmOSE/data/file_base.py b/src/OSmOSE/data/file_base.py index b97cbf98..35f8b5bf 100644 --- a/src/OSmOSE/data/file_base.py +++ b/src/OSmOSE/data/file_base.py @@ -1,4 +1,7 @@ -"""FileBase: Base class for the File objects (e.g. AudioFile), which associated timestamps with file-written data.""" +"""FileBase: Base class for the File objects (e.g. AudioFile). + +A File object associates file-written data to timestamps. +""" from __future__ import annotations @@ -16,7 +19,10 @@ class FileBase: - """Base class for the File objects (e.g. AudioFile), which associated timestamps with file-written data.""" + """Base class for the File objects (e.g. AudioFile). + + A File object associates file-written data to timestamps. + """ def __init__( self, @@ -54,7 +60,7 @@ def __init__( begin if begin is not None else strptime_from_text( - text=self.path.name, datetime_template=strptime_format + text=self.path.name, datetime_template=strptime_format, ) ) self.end = end if end is not None else self.begin @@ -62,6 +68,8 @@ def __init__( def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the data that is between start and stop from the file. + This is an abstract method and should be overridden with actual implementations. + Parameters ---------- start: pandas.Timestamp diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py new file mode 100644 index 00000000..055ef924 --- /dev/null +++ b/src/OSmOSE/data/item_base.py @@ -0,0 +1,44 @@ +"""ItemBase: Base class for the Item objects (e.g. AudioItem). + +Items correspond to a portion of a File object. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + from pandas import Timestamp + + from OSmOSE.data.file_base import FileBase + + +class ItemBase: + """Base class for the Item objects (e.g. AudioItem). + + An Item correspond to a portion of a File object. + """ + + def __init__(self, file: FileBase, begin: Timestamp | None = None, end: Timestamp | None = None) -> None: + """Initialize an ItemBase from a File and begin/end timestamps. + + Parameters + ---------- + file: OSmOSE.data.file_base.FileBase + The File in which this Item belongs. + begin: pandas.Timestamp (optional) + The timestamp at which this item begins. + It is defaulted to the File begin. + end: pandas.Timestamp (optional) + The timestamp at which this item ends. + It is defaulted to the File end. + + """ + self.file = file + self.begin = begin if begin is not None else self.file.begin + self.end = end if end is not None else self.file.end + + def get_value(self) -> np.ndarray: + """Get the values from the File between the begin and stop timestamps.""" + return self.file.read(start = self.begin, stop = self.end) From 03d414ba7bf7f936f94ab95268ddc563f99a7018 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 10:37:40 +0100 Subject: [PATCH 007/118] add audio sample generation util --- src/OSmOSE/data/audio_item.py | 7 +++++- src/OSmOSE/data/file_base.py | 5 +++-- src/OSmOSE/data/item_base.py | 9 ++++++-- src/OSmOSE/utils/audio_utils.py | 40 +++++++++++++++++++++++++++++++++ tests/conftest.py | 23 +++++++++++-------- tests/test_audio.py | 9 ++++---- 6 files changed, 75 insertions(+), 18 deletions(-) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index fe961300..4adb5402 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -15,7 +15,12 @@ class AudioItem(ItemBase): """AudioItem corresponding to a portion of an AudioFile object.""" - def __init__(self, file: AudioFile, begin: Timestamp | None = None, end: Timestamp | None = None) -> None: + def __init__( + self, + file: AudioFile, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> None: """Initialize an AudioItem from an AudioItem and begin/end timestamps. Parameters diff --git a/src/OSmOSE/data/file_base.py b/src/OSmOSE/data/file_base.py index 35f8b5bf..65f2c190 100644 --- a/src/OSmOSE/data/file_base.py +++ b/src/OSmOSE/data/file_base.py @@ -60,7 +60,8 @@ def __init__( begin if begin is not None else strptime_from_text( - text=self.path.name, datetime_template=strptime_format, + text=self.path.name, + datetime_template=strptime_format, ) ) self.end = end if end is not None else self.begin @@ -69,7 +70,7 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the data that is between start and stop from the file. This is an abstract method and should be overridden with actual implementations. - + Parameters ---------- start: pandas.Timestamp diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index 055ef924..150c7856 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -20,7 +20,12 @@ class ItemBase: An Item correspond to a portion of a File object. """ - def __init__(self, file: FileBase, begin: Timestamp | None = None, end: Timestamp | None = None) -> None: + def __init__( + self, + file: FileBase, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> None: """Initialize an ItemBase from a File and begin/end timestamps. Parameters @@ -41,4 +46,4 @@ def __init__(self, file: FileBase, begin: Timestamp | None = None, end: Timestam def get_value(self) -> np.ndarray: """Get the values from the File between the begin and stop timestamps.""" - return self.file.read(start = self.begin, stop = self.end) + return self.file.read(start=self.begin, stop=self.end) diff --git a/src/OSmOSE/utils/audio_utils.py b/src/OSmOSE/utils/audio_utils.py index fb40d8d7..e329c017 100644 --- a/src/OSmOSE/utils/audio_utils.py +++ b/src/OSmOSE/utils/audio_utils.py @@ -1,5 +1,9 @@ +from __future__ import annotations + from pathlib import Path +from typing import Literal +import numpy as np import pandas as pd from OSmOSE.config import ( @@ -118,3 +122,39 @@ def check_audio( ): message = "Your audio files have large duration discrepancies." raise ValueError(message) + +def generate_sample_audio( + nb_files: int, + nb_samples: int, + series_type: Literal["repeat", "increase"] = "repeat", + min_value: float = 0.0, + max_value: float = 1.0, +) -> list[np.ndarray]: + """Generate sample audio data. + + Parameters + ---------- + nb_files: int + Number of audio data to generate. + nb_samples: int + Number of samples per audio data. + series_type: Literal["repeat", "increase"] (Optional) + "repeat": audio data contain the same linear values from min to max. + "increase": audio data contain increasing values from min to max. + Defaults to "repeat". + min_value: float + Minimum value of the audio data. + max_value: float + Maximum value of the audio data. + + Returns + ------- + list[numpy.ndarray]: + The generated audio data. + + """ + if series_type == "repeat": + return np.split(np.tile(np.linspace(min_value, max_value, nb_samples), nb_files), nb_files) + if series_type == "increase": + return np.split(np.linspace(min_value, max_value, nb_samples * nb_files), nb_files) + return np.split(np.empty(nb_samples * nb_files), nb_files) diff --git a/tests/conftest.py b/tests/conftest.py index 680f4ff4..5f17729f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ from scipy.signal import chirp from OSmOSE.config import OSMOSE_PATH, TIMESTAMP_FORMAT_TEST_FILES +from OSmOSE.utils.audio_utils import generate_sample_audio @pytest.fixture @@ -21,27 +22,31 @@ def audio_files( tmp_path: Path, request: pytest.fixtures.Subrequest, ) -> tuple[list[Path], pytest.fixtures.Subrequest]: - v_begin = 0.0 - v_end = 1.0 nb_files = request.param.get("nb_files", 1) sample_rate = request.param.get("sample_rate", 48_000) duration = request.param.get("duration", 1.0) date_begin = request.param.get("date_begin", pd.Timestamp("2000-01-01 00:00:00")) inter_file_duration = request.param.get("inter_file_duration", 0) - n_samples = int(round(duration * sample_rate)) - data = np.linspace(v_begin, v_end, n_samples) + series_type = request.param.get("series_type", "repeat") + + nb_samples = int(round(duration * sample_rate)) + data = generate_sample_audio( + nb_files=nb_files, nb_samples=nb_samples, series_type=series_type, + ) files = [] - for begin_time in pd.date_range( - date_begin, - periods=nb_files, - freq=pd.Timedelta(seconds=duration + inter_file_duration), + for index, begin_time in enumerate( + pd.date_range( + date_begin, + periods=nb_files, + freq=pd.Timedelta(seconds=duration + inter_file_duration), + ), ): time_str = begin_time.strftime(format=TIMESTAMP_FORMAT_TEST_FILES) file = tmp_path / f"audio_{time_str}.wav" files.append(file) sf.write( file=file, - data=data, + data=data[index], samplerate=sample_rate, subtype="DOUBLE", ) diff --git a/tests/test_audio.py b/tests/test_audio.py index 68409ea4..10be1a30 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -8,6 +8,7 @@ from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES from OSmOSE.data.audio_file import AudioFile +from OSmOSE.utils.audio_utils import generate_sample_audio if TYPE_CHECKING: from pathlib import Path @@ -72,7 +73,7 @@ def test_audio_file_timestamps( }, pd.Timestamp("2024-01-01 12:00:00"), pd.Timestamp("2024-01-01 12:00:01"), - np.linspace(0.0, 1.0, 48_000), + generate_sample_audio(nb_files=1, nb_samples=48_000)[0], id="read_whole_file", ), pytest.param( @@ -92,7 +93,7 @@ def test_audio_file_timestamps( second=0, microsecond=100_000, ), - np.linspace(0.0, 1.0, 48_000)[:4_800], + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][:4_800], id="read_begin_only", ), pytest.param( @@ -120,7 +121,7 @@ def test_audio_file_timestamps( second=0, microsecond=600_000, ), - np.linspace(0.0, 1.0, 48_000)[24_000:28_800], + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][24_000:28_800], id="read_in_the_middle_of_the_file", ), pytest.param( @@ -140,7 +141,7 @@ def test_audio_file_timestamps( microsecond=900_000, ), pd.Timestamp("2024-01-01 12:00:01"), - np.linspace(0.0, 1.0, 48_000)[43_200:], + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][43_200:], id="read_end_of_file", ), ], From c9b25e59f888a286088d3c133277ebbd7eaf4454 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 10:58:23 +0100 Subject: [PATCH 008/118] add AudioItem tests --- tests/test_audio.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/test_audio.py b/tests/test_audio.py index 10be1a30..b0c43a94 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -8,6 +8,7 @@ from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES from OSmOSE.data.audio_file import AudioFile +from OSmOSE.data.audio_item import AudioItem from OSmOSE.utils.audio_utils import generate_sample_audio if TYPE_CHECKING: @@ -156,3 +157,61 @@ def test_audio_file_read( files, request = audio_files file = AudioFile(files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) assert np.array_equal(file.read(start, stop), expected) + + +@pytest.mark.parametrize( + ("audio_files", "start", "stop", "expected"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + None, + None, + generate_sample_audio(nb_files=1, nb_samples=48_000)[0], + id="whole_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=500_000, + ), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=600_000, + ), + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][24_000:28_800], + id="mid_file", + ), + ], + indirect=["audio_files"], +) +def test_audio_item( + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + start: pd.Timestamp | None, + stop: pd.Timestamp | None, + expected: np.ndarray, +) -> None: + files, request = audio_files + file = AudioFile(files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) + item = AudioItem(file, start, stop) + assert np.array_equal(item.get_value(), expected) From cd88fad30b8258a713ca9d18c83188bd24b7b616 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 12:01:31 +0100 Subject: [PATCH 009/118] add DataBase and AudioData --- src/OSmOSE/Spectrogram.py | 1 - src/OSmOSE/Weather.py | 1 + src/OSmOSE/__init__.py | 1 - src/OSmOSE/cluster/audio_reshaper.py | 2 - src/OSmOSE/data/audio_data.py | 26 ++++++++++++ src/OSmOSE/data/data_base.py | 61 ++++++++++++++++++++++++++++ src/OSmOSE/utils/audio_utils.py | 9 +++- tests/conftest.py | 4 +- tests/test_audio.py | 10 ++--- tests/test_logging.py | 2 - tests/test_permissions.py | 7 ---- tests/test_timestamp_utils.py | 1 - 12 files changed, 103 insertions(+), 22 deletions(-) create mode 100644 src/OSmOSE/data/audio_data.py create mode 100644 src/OSmOSE/data/data_base.py diff --git a/src/OSmOSE/Spectrogram.py b/src/OSmOSE/Spectrogram.py index 3738db4c..6b977eba 100644 --- a/src/OSmOSE/Spectrogram.py +++ b/src/OSmOSE/Spectrogram.py @@ -750,7 +750,6 @@ def initialize( i_max = -1 for batch in range(self.batch_number): - i_min = i_max + 1 i_max = ( i_min + batch_size diff --git a/src/OSmOSE/Weather.py b/src/OSmOSE/Weather.py index 67c549ab..1b82048c 100644 --- a/src/OSmOSE/Weather.py +++ b/src/OSmOSE/Weather.py @@ -3,6 +3,7 @@ @author: cazaudo """ + import itertools import os import sys diff --git a/src/OSmOSE/__init__.py b/src/OSmOSE/__init__.py index f72857c4..21907def 100755 --- a/src/OSmOSE/__init__.py +++ b/src/OSmOSE/__init__.py @@ -24,7 +24,6 @@ def _setup_logging( config_file: FileName = "logging_config.yaml", default_level: int = logging.INFO, ) -> None: - user_config_file_path = Path(os.getenv("OSMOSE_USER_CONFIG", ".")) / config_file default_config_file_path = Path(__file__).parent / config_file diff --git a/src/OSmOSE/cluster/audio_reshaper.py b/src/OSmOSE/cluster/audio_reshaper.py index ef5bf40b..30b769ae 100644 --- a/src/OSmOSE/cluster/audio_reshaper.py +++ b/src/OSmOSE/cluster/audio_reshaper.py @@ -229,7 +229,6 @@ def reshape( result = [] timestamp_list = [] for i in range(batch_ind_max - batch_ind_min + 1): - audio_data = np.zeros(shape=segment_size * new_sr) if concat: @@ -252,7 +251,6 @@ def reshape( len_sig = 0 for index, row in file_metadata.iterrows(): - file_datetime_begin = row["timestamp"] file_datetime_end = row["timestamp"] + pd.Timedelta(seconds=row["duration"]) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py new file mode 100644 index 00000000..a088badd --- /dev/null +++ b/src/OSmOSE/data/audio_data.py @@ -0,0 +1,26 @@ +"""AudioData encapsulating to a collection of AudioItem objects.""" + +from __future__ import annotations + +from OSmOSE.data.audio_item import AudioItem +from OSmOSE.data.data_base import DataBase + + +class AudioData(DataBase): + """AudioData encapsulating to a collection of AudioItem objects. + + The audio data can be retrieved from several Files through the Items. + """ + + item_cls = AudioItem + + def __init__(self, items: list[AudioItem]) -> None: + """Initialize an AudioData from a list of AudioItems. + + Parameters + ---------- + items: list[AudioItem] + List of the AudioItem constituting the AudioData. + + """ + super().__init__(items) diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py new file mode 100644 index 00000000..c9fd28b6 --- /dev/null +++ b/src/OSmOSE/data/data_base.py @@ -0,0 +1,61 @@ +"""DataBase: Base class for the Data objects (e.g. AudioData). + +Data corresponds to a collection of Items. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from OSmOSE.data.item_base import ItemBase + +if TYPE_CHECKING: + from OSmOSE.data.file_base import FileBase + + +class DataBase: + """Base class for Data objects. + + A Data object is a collection of Item objects. + Data can be retrieved from several Files through the Items. + """ + + item_cls = ItemBase + + def __init__(self, items: list[ItemBase]) -> None: + """Initialize an DataBase from a list of Items. + + Parameters + ---------- + items: list[ItemBase] + List of the Items constituting the Data. + + """ + self.items = items + + def get_value(self) -> np.ndarray: + """Get the concatenated values from all Items.""" + return np.concatenate([item.get_value() for item in self.items]) + + @classmethod + def from_file(cls, file: FileBase) -> DataBase: + """Initialize a DataBase from a single File. + + The resulting Data object will contain a single Item. + This single Item will correspond to the whole File. + + Parameters + ---------- + file: OSmOSE.data.file_base.FileBase + The File encapsulated in the Data object. + + Returns + ------- + OSmOSE.data.data_base.DataBase + The Data object. + + """ + item = cls.item_cls(file) + return cls(items=[item]) diff --git a/src/OSmOSE/utils/audio_utils.py b/src/OSmOSE/utils/audio_utils.py index e329c017..45544d4c 100644 --- a/src/OSmOSE/utils/audio_utils.py +++ b/src/OSmOSE/utils/audio_utils.py @@ -123,6 +123,7 @@ def check_audio( message = "Your audio files have large duration discrepancies." raise ValueError(message) + def generate_sample_audio( nb_files: int, nb_samples: int, @@ -154,7 +155,11 @@ def generate_sample_audio( """ if series_type == "repeat": - return np.split(np.tile(np.linspace(min_value, max_value, nb_samples), nb_files), nb_files) + return np.split( + np.tile(np.linspace(min_value, max_value, nb_samples), nb_files), nb_files + ) if series_type == "increase": - return np.split(np.linspace(min_value, max_value, nb_samples * nb_files), nb_files) + return np.split( + np.linspace(min_value, max_value, nb_samples * nb_files), nb_files + ) return np.split(np.empty(nb_samples * nb_files), nb_files) diff --git a/tests/conftest.py b/tests/conftest.py index 5f17729f..3fbb7dda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,9 @@ def audio_files( nb_samples = int(round(duration * sample_rate)) data = generate_sample_audio( - nb_files=nb_files, nb_samples=nb_samples, series_type=series_type, + nb_files=nb_files, + nb_samples=nb_samples, + series_type=series_type, ) files = [] for index, begin_time in enumerate( diff --git a/tests/test_audio.py b/tests/test_audio.py index b0c43a94..f917270b 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -174,7 +174,7 @@ def test_audio_file_read( generate_sample_audio(nb_files=1, nb_samples=48_000)[0], id="whole_file", ), - pytest.param( + pytest.param( { "duration": 1, "sample_rate": 48_000, @@ -206,10 +206,10 @@ def test_audio_file_read( indirect=["audio_files"], ) def test_audio_item( - audio_files: tuple[list[Path], pytest.fixtures.Subrequest], - start: pd.Timestamp | None, - stop: pd.Timestamp | None, - expected: np.ndarray, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + start: pd.Timestamp | None, + stop: pd.Timestamp | None, + expected: np.ndarray, ) -> None: files, request = audio_files file = AudioFile(files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) diff --git a/tests/test_logging.py b/tests/test_logging.py index 27833645..753360b7 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -87,7 +87,6 @@ def set_user_config_env(temp_user_logging_config): @pytest.mark.allow_log_write_to_file def test_user_logging_config(set_user_config_env, caplog, tmp_path): - assert ( len(logging.getLogger("test_user_logger").handlers) > 0 ) # This is a tweaky way of checking if the test_user_logger logger has already been created @@ -101,7 +100,6 @@ def test_user_logging_config(set_user_config_env, caplog, tmp_path): def test_default_logging_config(caplog, tmp_path): - assert ( len(logging.getLogger("dataset").handlers) > 0 ) # This is a tweaky way of checking if the test_user_logger logger has already been created diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 989f24a5..da9ccad8 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -16,7 +16,6 @@ @pytest.mark.unit def test_no_error_o_non_unix_os(tmp_path: Path) -> None: - OSmOSE.utils.core_utils._is_grp_supported = False try: change_owner_group(path=tmp_path, owner_group="test") @@ -30,7 +29,6 @@ def test_no_chmod_attempt_if_not_needed( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: - OSmOSE.utils.core_utils._is_grp_supported = True monkeypatch.setattr(os, "access", lambda path, mode: mode in [os.R_OK, os.W_OK]) @@ -72,7 +70,6 @@ def test_chmod_called_if_missing_permissions( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: - OSmOSE.utils.core_utils._is_grp_supported = True file_mode = 0o664 @@ -94,7 +91,6 @@ def test_error_logged_if_no_chmod_permission( monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, ) -> None: - OSmOSE.utils.core_utils._is_grp_supported = True monkeypatch.setattr(os, "access", lambda path, mode: mode in []) @@ -134,7 +130,6 @@ def test_change_owner_group( tmp_path: Path, patch_grp_module: MagicMock, ) -> None: - OSmOSE.utils.core_utils._is_grp_supported = True patch_grp_module.groups = existing_groups @@ -151,7 +146,6 @@ def test_change_owner_group_keyerror_is_logged( caplog: pytest.LogCaptureFixture, patch_grp_module: MagicMock, ) -> None: - OSmOSE.utils.core_utils._is_grp_supported = True patch_grp_module.groups = ["ensta", "gosmose", "other"] @@ -171,7 +165,6 @@ def test_change_owner_group_permission_error_is_logged( patch_grp_module: MagicMock, monkeypatch: pytest.MonkeyPatch, ) -> None: - OSmOSE.utils.core_utils._is_grp_supported = True existing_groups = ["ensta", "gosmose", "other"] diff --git a/tests/test_timestamp_utils.py b/tests/test_timestamp_utils.py index a8270fce..1422afa6 100644 --- a/tests/test_timestamp_utils.py +++ b/tests/test_timestamp_utils.py @@ -590,7 +590,6 @@ def test_localize_timestamp_warns_when_changing_timezone( expected: Timestamp, caplog: pytest.LogCaptureFixture, ) -> None: - with caplog.at_level(logging.WARNING): result = localize_timestamp(timestamp=timestamp, timezone=timezone) From e81c384b620441668af4bd250d2fdb067a3d3b1f Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 15:17:57 +0100 Subject: [PATCH 010/118] cap item boundaries --- src/OSmOSE/data/item_base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index 150c7856..1c838d3c 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -34,15 +34,15 @@ def __init__( The File in which this Item belongs. begin: pandas.Timestamp (optional) The timestamp at which this item begins. - It is defaulted to the File begin. + It is defaulted or maxed to the File begin. end: pandas.Timestamp (optional) The timestamp at which this item ends. - It is defaulted to the File end. + It is defaulted or mined to the File end. """ self.file = file - self.begin = begin if begin is not None else self.file.begin - self.end = end if end is not None else self.file.end + self.begin = max(begin, self.file.begin) if begin is not None else self.file.begin + self.end = min(end, self.file.end) if end is not None else self.file.end def get_value(self) -> np.ndarray: """Get the values from the File between the begin and stop timestamps.""" From 8dd587c7e60c69e33e34e07cded0579c9ed74928 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 15:36:30 +0100 Subject: [PATCH 011/118] add is_overlapping function --- src/OSmOSE/utils/timestamp_utils.py | 28 +++++++++ tests/test_timestamp_utils.py | 98 +++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/src/OSmOSE/utils/timestamp_utils.py b/src/OSmOSE/utils/timestamp_utils.py index 03178cbf..fe4015b2 100644 --- a/src/OSmOSE/utils/timestamp_utils.py +++ b/src/OSmOSE/utils/timestamp_utils.py @@ -416,6 +416,34 @@ def is_osmose_format_timestamp(timestamp: str) -> bool: return True +def is_overlapping( + event1: tuple[Timestamp, Timestamp], event2: tuple[Timestamp, Timestamp] +) -> bool: + """Return True if the two events are overlapping, False otherwise. + + Parameters + ---------- + event1: tuple[pandas.Timestamp, pandas.Timestamp] + The first event. + event2: tuple[pandas.Timestamp, pandas.Timestamp] + The second event. + + Returns + ------- + bool: + True if the two events are overlapping, False otherwise. + + Examples + -------- + >>> is_overlapping((Timestamp("2024-01-01 00:00:00"),(Timestamp("2024-01-02 00:00:00"))), (Timestamp("2024-01-01 12:00:00"),(Timestamp("2024-01-02 12:00:00")))) + True + >>> is_overlapping((Timestamp("2024-01-01 00:00:00"),(Timestamp("2024-01-02 00:00:00"))), (Timestamp("2024-01-02 00:00:00"),(Timestamp("2024-01-02 12:00:00")))) + False + + """ + return event1[0] < event2[1] and event1[1] > event2[0] + + def get_timestamps( path_osmose_dataset: str, campaign_name: str, diff --git a/tests/test_timestamp_utils.py b/tests/test_timestamp_utils.py index 1422afa6..fb088464 100644 --- a/tests/test_timestamp_utils.py +++ b/tests/test_timestamp_utils.py @@ -12,6 +12,7 @@ associate_timestamps, build_regex_from_datetime_template, is_datetime_template_valid, + is_overlapping, localize_timestamp, parse_timestamps_csv, reformat_timestamp, @@ -1052,3 +1053,100 @@ def test_adapt_timestamp_csv_to_osmose( assert adapt_timestamp_csv_to_osmose(timestamps, date_template, timezone).equals( expected, ) + + +@pytest.mark.parametrize( + ("event1", "event2", "expected"), + [ + pytest.param( + ( + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), + ), + ( + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), + ), + True, + id="same_event", + ), + pytest.param( + ( + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), + ), + ( + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-02 12:00:00"), + ), + True, + id="overlapping_events", + ), + pytest.param( + ( + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-02 12:00:00"), + ), + ( + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), + ), + True, + id="overlapping_events_reversed", + ), + pytest.param( + ( + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), + ), + ( + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-01 12:01:00"), + ), + True, + id="embedded_events", + ), + pytest.param( + ( + Timestamp("2024-01-01 0:00:00"), + Timestamp("2024-01-01 12:00:00"), + ), + ( + Timestamp("2024-01-02 00:00:00"), + Timestamp("2024-01-02 12:00:00"), + ), + False, + id="non_overlapping_events", + ), + pytest.param( + ( + Timestamp("2024-01-02 0:00:00"), + Timestamp("2024-01-02 12:00:00"), + ), + ( + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-01 12:00:00"), + ), + False, + id="non_overlapping_events_reversed", + ), + pytest.param( + ( + Timestamp("2024-01-01 0:00:00"), + Timestamp("2024-01-01 12:00:00"), + ), + ( + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-02 00:00:00"), + ), + False, + id="border_sharing_isnt_overlapping", + ), + ], +) +def test_overlapping_events( + event1: tuple[Timestamp, Timestamp], + event2: tuple[Timestamp, Timestamp], + expected: bool, +) -> None: + assert is_overlapping(event1, event2) == expected From 56097a9105769d00aa7fba9c682020a6c068046d Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 16:46:05 +0100 Subject: [PATCH 012/118] parse Data over several Files --- src/OSmOSE/data/data_base.py | 25 +++++++++++--- src/OSmOSE/data/item_base.py | 4 ++- tests/test_audio.py | 63 ++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index c9fd28b6..eaa44c32 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -8,8 +8,10 @@ from typing import TYPE_CHECKING import numpy as np +from pandas import Timestamp from OSmOSE.data.item_base import ItemBase +from OSmOSE.utils.timestamp_utils import is_overlapping if TYPE_CHECKING: from OSmOSE.data.file_base import FileBase @@ -34,13 +36,20 @@ def __init__(self, items: list[ItemBase]) -> None: """ self.items = items + self.begin = min(item.begin for item in self.items) + self.end = max(item.end for item in self.items) def get_value(self) -> np.ndarray: """Get the concatenated values from all Items.""" return np.concatenate([item.get_value() for item in self.items]) @classmethod - def from_file(cls, file: FileBase) -> DataBase: + def from_files( + cls, + files: list[FileBase], + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> DataBase: """Initialize a DataBase from a single File. The resulting Data object will contain a single Item. @@ -48,8 +57,8 @@ def from_file(cls, file: FileBase) -> DataBase: Parameters ---------- - file: OSmOSE.data.file_base.FileBase - The File encapsulated in the Data object. + files: list[OSmOSE.data.file_base.FileBase] + The Files encapsulated in the Data object. Returns ------- @@ -57,5 +66,11 @@ def from_file(cls, file: FileBase) -> DataBase: The Data object. """ - item = cls.item_cls(file) - return cls(items=[item]) + begin = min(file.begin for file in files) if begin is None else begin + end = max(file.end for file in files) if end is None else end + + overlapping_files = [file for file in files if is_overlapping((file.begin, file.end), (begin, end))] + + items = [cls.item_cls(file, begin, end) for file in overlapping_files] + + return cls(items=items) diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index 1c838d3c..4b6a0215 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -41,7 +41,9 @@ def __init__( """ self.file = file - self.begin = max(begin, self.file.begin) if begin is not None else self.file.begin + self.begin = ( + max(begin, self.file.begin) if begin is not None else self.file.begin + ) self.end = min(end, self.file.end) if end is not None else self.file.end def get_value(self) -> np.ndarray: diff --git a/tests/test_audio.py b/tests/test_audio.py index f917270b..83c087d8 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -7,6 +7,7 @@ import pytest from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES +from OSmOSE.data.audio_data import AudioData from OSmOSE.data.audio_file import AudioFile from OSmOSE.data.audio_item import AudioItem from OSmOSE.utils.audio_utils import generate_sample_audio @@ -215,3 +216,65 @@ def test_audio_item( file = AudioFile(files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) item = AudioItem(file, start, stop) assert np.array_equal(item.get_value(), expected) + + +@pytest.mark.parametrize( + ("audio_files", "start", "stop", "expected"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 2, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0], + id="all_files", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 2, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=800_000, + ), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=1, + microsecond=200_000, + ), + generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][38_400:57_600], + id="between_files", + ), + ], + indirect=["audio_files"], +) +def test_audio_data( + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + start: pd.Timestamp | None, + stop: pd.Timestamp | None, + expected: np.ndarray, +) -> None: + files, request = audio_files + audio_files = [ + AudioFile(file, strptime_format=TIMESTAMP_FORMAT_TEST_FILES) for file in files + ] + data = AudioData.from_files(audio_files, begin=start, end=stop) + assert np.array_equal(data.get_value(), expected) From 49f1bae257f27e2872c48eb04a0ed142c59f455a Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 4 Dec 2024 17:22:03 +0100 Subject: [PATCH 013/118] black --- src/OSmOSE/data/data_base.py | 15 +++++++++-- tests/test_timestamp_utils.py | 48 +++++++++++++++++------------------ 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index eaa44c32..ee61d1dd 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -8,12 +8,13 @@ from typing import TYPE_CHECKING import numpy as np -from pandas import Timestamp from OSmOSE.data.item_base import ItemBase from OSmOSE.utils.timestamp_utils import is_overlapping if TYPE_CHECKING: + from pandas import Timestamp + from OSmOSE.data.file_base import FileBase @@ -59,6 +60,12 @@ def from_files( ---------- files: list[OSmOSE.data.file_base.FileBase] The Files encapsulated in the Data object. + begin: pandas.Timestamp | None + The begin of the Data object. + defaulted to the begin of the first File. + end: pandas.Timestamp | None + The end of the Data object. + default to the end of the last File. Returns ------- @@ -69,7 +76,11 @@ def from_files( begin = min(file.begin for file in files) if begin is None else begin end = max(file.end for file in files) if end is None else end - overlapping_files = [file for file in files if is_overlapping((file.begin, file.end), (begin, end))] + overlapping_files = [ + file + for file in files + if is_overlapping((file.begin, file.end), (begin, end)) + ] items = [cls.item_cls(file, begin, end) for file in overlapping_files] diff --git a/tests/test_timestamp_utils.py b/tests/test_timestamp_utils.py index fb088464..df095ecd 100644 --- a/tests/test_timestamp_utils.py +++ b/tests/test_timestamp_utils.py @@ -1072,72 +1072,72 @@ def test_adapt_timestamp_csv_to_osmose( ), pytest.param( ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), ), ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-02 12:00:00"), + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-02 12:00:00"), ), True, id="overlapping_events", ), pytest.param( ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-02 12:00:00"), + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-02 12:00:00"), ), ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), ), True, id="overlapping_events_reversed", ), pytest.param( ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-02 00:00:00"), ), ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-01 12:01:00"), + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-01 12:01:00"), ), True, id="embedded_events", ), pytest.param( ( - Timestamp("2024-01-01 0:00:00"), - Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-01 0:00:00"), + Timestamp("2024-01-01 12:00:00"), ), ( - Timestamp("2024-01-02 00:00:00"), - Timestamp("2024-01-02 12:00:00"), + Timestamp("2024-01-02 00:00:00"), + Timestamp("2024-01-02 12:00:00"), ), False, id="non_overlapping_events", ), pytest.param( ( - Timestamp("2024-01-02 0:00:00"), - Timestamp("2024-01-02 12:00:00"), + Timestamp("2024-01-02 0:00:00"), + Timestamp("2024-01-02 12:00:00"), ), ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-01 00:00:00"), + Timestamp("2024-01-01 12:00:00"), ), False, id="non_overlapping_events_reversed", ), pytest.param( ( - Timestamp("2024-01-01 0:00:00"), - Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-01 0:00:00"), + Timestamp("2024-01-01 12:00:00"), ), ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-02 00:00:00"), + Timestamp("2024-01-01 12:00:00"), + Timestamp("2024-01-02 00:00:00"), ), False, id="border_sharing_isnt_overlapping", From 6591490209dc91478399cd1f092ef56142f21a2a Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 5 Dec 2024 17:46:27 +0100 Subject: [PATCH 014/118] add concatenate item method --- src/OSmOSE/data/audio_item.py | 2 +- src/OSmOSE/data/data_base.py | 1 + src/OSmOSE/data/item_base.py | 50 ++++++++++++++++++++++++++++++++--- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 4adb5402..0412b135 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -17,7 +17,7 @@ class AudioItem(ItemBase): def __init__( self, - file: AudioFile, + file: AudioFile | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, ) -> None: diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index ee61d1dd..70b74947 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -83,5 +83,6 @@ def from_files( ] items = [cls.item_cls(file, begin, end) for file in overlapping_files] + items = ItemBase.concatenate_items(items) return cls(items=items) diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index 4b6a0215..29325e61 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -7,6 +7,8 @@ from typing import TYPE_CHECKING +from OSmOSE.utils.timestamp_utils import is_overlapping + if TYPE_CHECKING: import numpy as np from pandas import Timestamp @@ -22,7 +24,7 @@ class ItemBase: def __init__( self, - file: FileBase, + file: FileBase | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, ) -> None: @@ -41,11 +43,53 @@ def __init__( """ self.file = file + + if file is None: + self.begin = begin + self.end = end + return + self.begin = ( max(begin, self.file.begin) if begin is not None else self.file.begin ) self.end = min(end, self.file.end) if end is not None else self.file.end def get_value(self) -> np.ndarray: - """Get the values from the File between the begin and stop timestamps.""" - return self.file.read(start=self.begin, stop=self.end) + """Get the values from the File between the begin and stop timestamps. + + If the Item is empty, return a single 0. + """ + return ( + np.zeros(1) + if self.is_empty + else self.file.read(start=self.begin, stop=self.end) + ) + + @property + def is_empty(self) -> bool: + return self.file is None + + @staticmethod + def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: + items = sorted(items, key=lambda item: (item.begin, item.end)) + concatenated_items: list[ItemBase] = [] + for item in items: + overlapping_items = [ + item2 + for item2 in items + if is_overlapping((item.begin, item.end), (item2.begin, item2.end)) + ] + if len(overlapping_items) == 1: + concatenated_items.append(item) + continue + kept_overlapping_item = max(overlapping_items, key=lambda item: item.end) + if kept_overlapping_item is not item: + item.end = kept_overlapping_item.begin + concatenated_items.append(item) + for dismissed_item in ( + item2 + for item2 in overlapping_items + if item2 not in (item, kept_overlapping_item) + ): + items.remove(dismissed_item) + return concatenated_items From 82cfcf67f163c1df5904c270290d7664d9c44330 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 6 Dec 2024 10:07:30 +0100 Subject: [PATCH 015/118] add docstrings for ItemBase methods --- src/OSmOSE/data/item_base.py | 40 +++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index 29325e61..ff20f287 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -5,16 +5,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import numpy as np +from pandas import Timestamp +from OSmOSE.data.file_base import FileBase from OSmOSE.utils.timestamp_utils import is_overlapping -if TYPE_CHECKING: - import numpy as np - from pandas import Timestamp - - from OSmOSE.data.file_base import FileBase - class ItemBase: """Base class for the Item objects (e.g. AudioItem). @@ -67,10 +63,40 @@ def get_value(self) -> np.ndarray: @property def is_empty(self) -> bool: + """Return True if no File is attached to this Item.""" return self.file is None @staticmethod def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: + """Resolve overlaps between Items. + + If two Items overlap within the sequence (that is if one Item begins before the end of another, + the earliest Item's end is set to the begin of the latest Item. + If multiple items overlap with one earlier Item, only one is chosen as next. + The chosen next Item is the one that ends the latest. + The Items are concatenated in-place. + + Parameters + ---------- + items: list[ItemBase] + List of Items to concatenate. + + Returns + ------- + list[ItemBase]: + The list of Items with no overlapping Items. + + Examples + -------- + >>> items = [ItemBase(begin = Timestamp("00:00:00"), end = Timestamp("00:00:15")), ItemBase(begin = Timestamp("00:00:10"), end = Timestamp("00:00:20"))] + >>> items[0].end == items[1].begin + False + >>> ItemBase.concatenate_items(items) # doctest: +ELLIPSIS + [, ] + >>> items[0].end == items[1].begin + True + + """ items = sorted(items, key=lambda item: (item.begin, item.end)) concatenated_items: list[ItemBase] = [] for item in items: From 9f2ae2b20657b62addf1d7f77f15f49283d685eb Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 6 Dec 2024 11:04:06 +0100 Subject: [PATCH 016/118] remove in-place item modification --- src/OSmOSE/data/item_base.py | 32 ++++++++++++++++++++++---------- tests/test_item.py | 8 ++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) create mode 100644 tests/test_item.py diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index ff20f287..b7c382dc 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -5,6 +5,8 @@ from __future__ import annotations +import copy + import numpy as np from pandas import Timestamp @@ -66,6 +68,16 @@ def is_empty(self) -> bool: """Return True if no File is attached to this Item.""" return self.file is None + def __eq__(self, other: ItemBase) -> bool: + """Override the default implementation.""" + if not isinstance(other, ItemBase): + return False + if self.file != other.file: + return False + if self.begin != other.begin: + return False + return not self.end != other.end + @staticmethod def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: """Resolve overlaps between Items. @@ -74,7 +86,6 @@ def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: the earliest Item's end is set to the begin of the latest Item. If multiple items overlap with one earlier Item, only one is chosen as next. The chosen next Item is the one that ends the latest. - The Items are concatenated in-place. Parameters ---------- @@ -91,31 +102,32 @@ def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: >>> items = [ItemBase(begin = Timestamp("00:00:00"), end = Timestamp("00:00:15")), ItemBase(begin = Timestamp("00:00:10"), end = Timestamp("00:00:20"))] >>> items[0].end == items[1].begin False - >>> ItemBase.concatenate_items(items) # doctest: +ELLIPSIS - [, ] + >>> items = ItemBase.concatenate_items(items) >>> items[0].end == items[1].begin True """ - items = sorted(items, key=lambda item: (item.begin, item.end)) + items = sorted([copy.copy(item) for item in items], key=lambda item: (item.begin, item.begin-item.end)) concatenated_items: list[ItemBase] = [] for item in items: + concatenated_items.append(item) overlapping_items = [ item2 for item2 in items - if is_overlapping((item.begin, item.end), (item2.begin, item2.end)) + if item2 is not item and + is_overlapping((item.begin, item.end), (item2.begin, item2.end)) ] - if len(overlapping_items) == 1: - concatenated_items.append(item) + if not overlapping_items: continue kept_overlapping_item = max(overlapping_items, key=lambda item: item.end) - if kept_overlapping_item is not item: + if kept_overlapping_item.end > item.end: item.end = kept_overlapping_item.begin - concatenated_items.append(item) + else: + kept_overlapping_item = None for dismissed_item in ( item2 for item2 in overlapping_items - if item2 not in (item, kept_overlapping_item) + if item2 is not kept_overlapping_item ): items.remove(dismissed_item) return concatenated_items diff --git a/tests/test_item.py b/tests/test_item.py new file mode 100644 index 00000000..052b47e9 --- /dev/null +++ b/tests/test_item.py @@ -0,0 +1,8 @@ +import unittest + +class MyTestCase(unittest.TestCase): + def test_something(self): + self.assertEqual(True, False) # add assertion here + +if __name__ == '__main__': + unittest.main() From 56a4490594573c3c89de4ddb3f29d749d8c23d61 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 6 Dec 2024 11:22:52 +0100 Subject: [PATCH 017/118] add tests for item concatenation --- tests/test_item.py | 54 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/tests/test_item.py b/tests/test_item.py index 052b47e9..81260c2a 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -1,8 +1,50 @@ -import unittest +from __future__ import annotations -class MyTestCase(unittest.TestCase): - def test_something(self): - self.assertEqual(True, False) # add assertion here +import pytest +from pandas import Timestamp -if __name__ == '__main__': - unittest.main() +from OSmOSE.data.item_base import ItemBase + + +@pytest.mark.parametrize( + ("item_list", "expected"), + [ + pytest.param( + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + id="only_one_item_is_unchanged", + ), + pytest.param( + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")),ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + id="doubled_item_is_removed", + ), + pytest.param( + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], + id="overlapping_item_is_truncated", + ), + pytest.param( + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], + id="longest_item_is_prioritized", + ), + pytest.param( + [ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:25")), + ItemBase(begin=Timestamp("00:00:0"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35"))], + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35"))], + id="items_are_reordered", + ), + ], +) +def test_item_base(item_list: list[ItemBase], expected: list[ItemBase]) -> None: + cleaned_items = ItemBase.concatenate_items(item_list) + assert cleaned_items == expected From 592f2aa9d0905bf71f9b47ca1adfe94c9e028275 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 6 Dec 2024 11:43:23 +0100 Subject: [PATCH 018/118] add item fill_gaps method --- src/OSmOSE/data/data_base.py | 2 +- src/OSmOSE/data/item_base.py | 41 +++++++++++++- tests/test_item.py | 105 +++++++++++++++++++++++++++++------ 3 files changed, 127 insertions(+), 21 deletions(-) diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index 70b74947..cc502986 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -84,5 +84,5 @@ def from_files( items = [cls.item_cls(file, begin, end) for file in overlapping_files] items = ItemBase.concatenate_items(items) - + items = ItemBase.fill_gaps(items) return cls(items=items) diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index b7c382dc..ce390350 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -107,15 +107,18 @@ def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: True """ - items = sorted([copy.copy(item) for item in items], key=lambda item: (item.begin, item.begin-item.end)) + items = sorted( + [copy.copy(item) for item in items], + key=lambda item: (item.begin, item.begin - item.end), + ) concatenated_items: list[ItemBase] = [] for item in items: concatenated_items.append(item) overlapping_items = [ item2 for item2 in items - if item2 is not item and - is_overlapping((item.begin, item.end), (item2.begin, item2.end)) + if item2 is not item + and is_overlapping((item.begin, item.end), (item2.begin, item2.end)) ] if not overlapping_items: continue @@ -131,3 +134,35 @@ def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: ): items.remove(dismissed_item) return concatenated_items + + @staticmethod + def fill_gaps(items: list[ItemBase]) -> list[ItemBase]: + """Return a list with empty items added in the gaps between items. + + Parameters + ---------- + items: list[ItemBase] + List of Items to fill. + + Returns + ------- + list[ItemBase]: + List of Items with no gaps. + + Examples + -------- + >>> items = [ItemBase(begin = Timestamp("00:00:00"), end = Timestamp("00:00:10")), ItemBase(begin = Timestamp("00:00:15"), end = Timestamp("00:00:25"))] + >>> items = ItemBase.fill_gaps(items) + >>> [(item.begin.second, item.end.second) for item in items] + [(0, 10), (10, 15), (15, 25)] + + """ + items = sorted([copy.copy(item) for item in items], key=lambda item: item.begin) + filled_item_list = [] + for index, item in enumerate(items[:-1]): + next_item = items[index + 1] + filled_item_list.append(item) + if next_item.begin > item.end: + filled_item_list.append(ItemBase(begin=item.end, end=next_item.begin)) + filled_item_list.append(items[-1]) + return filled_item_list diff --git a/tests/test_item.py b/tests/test_item.py index 81260c2a..9b27b080 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -15,36 +15,107 @@ id="only_one_item_is_unchanged", ), pytest.param( - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")),ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ], [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], id="doubled_item_is_removed", ), pytest.param( - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], id="overlapping_item_is_truncated", ), pytest.param( - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20"))], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], id="longest_item_is_prioritized", ), pytest.param( - [ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:25")), - ItemBase(begin=Timestamp("00:00:0"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35"))], - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35"))], + [ + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:25")), + ItemBase(begin=Timestamp("00:00:0"), end=Timestamp("00:00:15")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), + ], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), + ], id="items_are_reordered", ), ], ) -def test_item_base(item_list: list[ItemBase], expected: list[ItemBase]) -> None: +def test_concatenate_item(item_list: list[ItemBase], expected: list[ItemBase]) -> None: cleaned_items = ItemBase.concatenate_items(item_list) assert cleaned_items == expected + + +@pytest.mark.parametrize( + ("item_list", "expected"), + [ + pytest.param( + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + id="only_one_item_is_unchanged", + ), + pytest.param( + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + id="consecutive_items_are_unchanged", + ), + pytest.param( + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + ], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + ], + id="one_gap_is_filled", + ), + pytest.param( + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + ItemBase(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), + ItemBase(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), + ], + [ + ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + ItemBase(begin=Timestamp("00:00:30"), end=Timestamp("00:00:35")), + ItemBase(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), + ItemBase(begin=Timestamp("00:00:45"), end=Timestamp("00:01:00")), + ItemBase(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), + ], + id="multiple_gaps_are_filled", + ), + ], +) +def test_fill_item_gaps(item_list: list[ItemBase], expected: list[ItemBase]) -> None: + filled_items = ItemBase.fill_gaps(item_list) + assert filled_items == expected From ce28be2eb8b4ac965cf87278d7b4349be4fa05b6 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 14:21:51 +0100 Subject: [PATCH 019/118] add resamble util function --- src/OSmOSE/utils/audio_utils.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/utils/audio_utils.py b/src/OSmOSE/utils/audio_utils.py index 45544d4c..1bfc9ec9 100644 --- a/src/OSmOSE/utils/audio_utils.py +++ b/src/OSmOSE/utils/audio_utils.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import soxr from OSmOSE.config import ( AUDIO_METADATA, @@ -156,10 +157,33 @@ def generate_sample_audio( """ if series_type == "repeat": return np.split( - np.tile(np.linspace(min_value, max_value, nb_samples), nb_files), nb_files + np.tile(np.linspace(min_value, max_value, nb_samples), nb_files), + nb_files, ) if series_type == "increase": return np.split( - np.linspace(min_value, max_value, nb_samples * nb_files), nb_files + np.linspace(min_value, max_value, nb_samples * nb_files), + nb_files, ) return np.split(np.empty(nb_samples * nb_files), nb_files) + + +def resample(data: np.ndarray, origin_sr: float, target_sr: float) -> np.ndarray: + """Resample the audio data using soxr. + + Parameters + ---------- + data: np.ndarray + The audio data to resample. + origin_sr: + The sampling rate of the audio data. + target_sr: + The sampling rate of the resampled audio data. + + Returns + ------- + np.ndarray + The resampled audio data. + + """ + return soxr.resample(data, origin_sr, target_sr) From 201e6fdea335ab56f124311640d83cec9eb536d5 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 14:23:05 +0100 Subject: [PATCH 020/118] add empty items until data boundaries --- src/OSmOSE/data/data_base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index cc502986..356c63df 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -40,6 +40,11 @@ def __init__(self, items: list[ItemBase]) -> None: self.begin = min(item.begin for item in self.items) self.end = max(item.end for item in self.items) + @property + def total_seconds(self) -> float: + """Return the total duration of the data in seconds.""" + return (self.end - self.begin).total_seconds() + def get_value(self) -> np.ndarray: """Get the concatenated values from all Items.""" return np.concatenate([item.get_value() for item in self.items]) @@ -83,6 +88,12 @@ def from_files( ] items = [cls.item_cls(file, begin, end) for file in overlapping_files] + if not items: + items.append(cls.item_cls(begin=begin, end=end)) + if (first_item := sorted(items, key=lambda item: item.begin)[0]).begin > begin: + items.append(cls.item_cls(begin=begin, end=first_item.begin)) + if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: + items.append(cls.item_cls(begin=last_item.end, end=end)) items = ItemBase.concatenate_items(items) items = ItemBase.fill_gaps(items) return cls(items=items) From 8917403afc619c959d0648291c4543b56ca928f6 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 14:23:47 +0100 Subject: [PATCH 021/118] add total_seconds item property --- src/OSmOSE/data/item_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index ce390350..7733ef74 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -68,6 +68,11 @@ def is_empty(self) -> bool: """Return True if no File is attached to this Item.""" return self.file is None + @property + def total_seconds(self) -> float: + """Return the total duration of the item in seconds.""" + return (self.end - self.begin).total_seconds() + def __eq__(self, other: ItemBase) -> bool: """Override the default implementation.""" if not isinstance(other, ItemBase): From 0a9896d4f0eaf42e5237565215fd2d5b239cf870 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 14:25:00 +0100 Subject: [PATCH 022/118] add sr and channels properties --- src/OSmOSE/data/audio_item.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 0412b135..7707578b 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -36,3 +36,11 @@ def __init__( """ super().__init__(file, begin, end) + + @property + def sample_rate(self): + return None if self.is_empty else self.file.metadata.samplerate + + @property + def nb_channels(self): + return 0 if self.is_empty else self.file.metadata.channels From 4a4976a7b154aad2b11f2ab0e76eb113519deeab Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 14:25:16 +0100 Subject: [PATCH 023/118] resample audio data on get_value() call --- src/OSmOSE/data/audio_data.py | 46 ++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index a088badd..f09c5e18 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -2,8 +2,11 @@ from __future__ import annotations +import numpy as np + from OSmOSE.data.audio_item import AudioItem from OSmOSE.data.data_base import DataBase +from OSmOSE.utils.audio_utils import resample class AudioData(DataBase): @@ -14,13 +17,54 @@ class AudioData(DataBase): item_cls = AudioItem - def __init__(self, items: list[AudioItem]) -> None: + def __init__(self, items: list[AudioItem], sample_rate: int | None = None) -> None: """Initialize an AudioData from a list of AudioItems. Parameters ---------- items: list[AudioItem] List of the AudioItem constituting the AudioData. + sample_rate: int + The sample rate of the audio data. """ super().__init__(items) + self._check_sample_rates(sample_rate=sample_rate) + + @property + def nb_channels(self) -> int: + return max( + [1] + [item.nb_channels for item in self.items if type(item) is AudioItem] + ) + + @property + def shape(self): + data_length = int(self.sample_rate * self.total_seconds) + return data_length if self.nb_channels <= 1 else (data_length, self.nb_channels) + + def _check_sample_rates(self, sample_rate: int | None = None) -> None: + if sample_rate is not None or any( + sample_rate := item.sample_rate + for item in self.items + if item.sample_rate is not None + ): + self.sample_rate = sample_rate + else: + self.sample_rate = None + + def get_value(self): + data = np.empty(shape=self.shape) + idx = 0 + for item in self.items: + item_data = self._get_item_value(item) + data[idx : idx + len(item_data)] = item_data + idx += len(item_data) + return data + + def _get_item_value(self, item: AudioItem) -> np.ndarray: + item_data = item.get_value() + if item.is_empty: + return item_data.repeat(int(item.total_seconds * self.sample_rate)) + if item.sample_rate != self.sample_rate: + return resample(item_data, item.sample_rate, self.sample_rate) + return item_data From 3a92fc0ef65a25b02b2723118e9cf68aacb35931 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 15:09:17 +0100 Subject: [PATCH 024/118] add audio_data write method --- src/OSmOSE/config.py | 1 + src/OSmOSE/data/audio_data.py | 21 +++++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/config.py b/src/OSmOSE/config.py index 431c3a59..2b841134 100755 --- a/src/OSmOSE/config.py +++ b/src/OSmOSE/config.py @@ -35,6 +35,7 @@ TIMESTAMP_FORMAT_AUDIO_FILE = "%Y-%m-%dT%H:%M:%S.%f%z" TIMESTAMP_FORMAT_TEST_FILES = "%y%m%d%H%M%S" +TIMESTAMP_FORMAT_EXPORTED_FILES = "%Y_%m_%d_%H_%M_%S" FPDEFAULT = 0o664 # Default file permissions DPDEFAULT = stat.S_ISGID | 0o775 # Default directory permissions diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index f09c5e18..64128363 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -4,10 +4,12 @@ import numpy as np +from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES from OSmOSE.data.audio_item import AudioItem from OSmOSE.data.data_base import DataBase from OSmOSE.utils.audio_utils import resample - +import soundfile as sf +from pathlib import Path class AudioData(DataBase): """AudioData encapsulating to a collection of AudioItem objects. @@ -34,7 +36,7 @@ def __init__(self, items: list[AudioItem], sample_rate: int | None = None) -> No @property def nb_channels(self) -> int: return max( - [1] + [item.nb_channels for item in self.items if type(item) is AudioItem] + [1] + [item.nb_channels for item in self.items if type(item) is AudioItem], ) @property @@ -42,6 +44,10 @@ def shape(self): data_length = int(self.sample_rate * self.total_seconds) return data_length if self.nb_channels <= 1 else (data_length, self.nb_channels) + def __str__(self) -> str: + """Overwrite __str__.""" + return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES) + def _check_sample_rates(self, sample_rate: int | None = None) -> None: if sample_rate is not None or any( sample_rate := item.sample_rate @@ -61,6 +67,17 @@ def get_value(self): idx += len(item_data) return data + def write(self, folder: Path) -> None: + """Write the audio data to file. + + Parameters + ---------- + folder: pathlib.Path + Folder in which to write the audio file. + + """ + sf.write(folder / f"{self}.wav" , self.get_value(), self.sample_rate) + def _get_item_value(self, item: AudioItem) -> np.ndarray: item_data = item.get_value() if item.is_empty: From 227cb0c3e4408ccd295753244230bc7cd720e2c9 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 17:06:45 +0100 Subject: [PATCH 025/118] rename remove_overlaps item method --- src/OSmOSE/data/data_base.py | 2 +- src/OSmOSE/data/item_base.py | 4 ++-- tests/test_item.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index 356c63df..ee8dd39d 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -94,6 +94,6 @@ def from_files( items.append(cls.item_cls(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: items.append(cls.item_cls(begin=last_item.end, end=end)) - items = ItemBase.concatenate_items(items) + items = ItemBase.remove_overlaps(items) items = ItemBase.fill_gaps(items) return cls(items=items) diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index 7733ef74..2f8f5c1d 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -84,7 +84,7 @@ def __eq__(self, other: ItemBase) -> bool: return not self.end != other.end @staticmethod - def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: + def remove_overlaps(items: list[ItemBase]) -> list[ItemBase]: """Resolve overlaps between Items. If two Items overlap within the sequence (that is if one Item begins before the end of another, @@ -107,7 +107,7 @@ def concatenate_items(items: list[ItemBase]) -> list[ItemBase]: >>> items = [ItemBase(begin = Timestamp("00:00:00"), end = Timestamp("00:00:15")), ItemBase(begin = Timestamp("00:00:10"), end = Timestamp("00:00:20"))] >>> items[0].end == items[1].begin False - >>> items = ItemBase.concatenate_items(items) + >>> items = ItemBase.remove_overlaps(items) >>> items[0].end == items[1].begin True diff --git a/tests/test_item.py b/tests/test_item.py index 9b27b080..7c478f58 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -61,7 +61,7 @@ ], ) def test_concatenate_item(item_list: list[ItemBase], expected: list[ItemBase]) -> None: - cleaned_items = ItemBase.concatenate_items(item_list) + cleaned_items = ItemBase.remove_overlaps(item_list) assert cleaned_items == expected From 3c5a03f464e23bbcd5bf410627c1eaeea013bf81 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 17:29:09 +0100 Subject: [PATCH 026/118] add AudioData empty filling test --- src/OSmOSE/data/audio_data.py | 3 ++- tests/test_audio.py | 24 ++++++++++++++++++++++++ tests/test_item.py | 2 +- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 64128363..67be7869 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -11,6 +11,7 @@ import soundfile as sf from pathlib import Path + class AudioData(DataBase): """AudioData encapsulating to a collection of AudioItem objects. @@ -76,7 +77,7 @@ def write(self, folder: Path) -> None: Folder in which to write the audio file. """ - sf.write(folder / f"{self}.wav" , self.get_value(), self.sample_rate) + sf.write(folder / f"{self}.wav", self.get_value(), self.sample_rate) def _get_item_value(self, item: AudioItem) -> np.ndarray: item_data = item.get_value() diff --git a/tests/test_audio.py b/tests/test_audio.py index 83c087d8..0aa6fb3e 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -263,6 +263,30 @@ def test_audio_item( generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][38_400:57_600], id="between_files", ), + pytest.param( + { + "duration": 1, + "inter_file_duration": 1, + "sample_rate": 48_000, + "nb_files": 2, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + np.array( + list( + generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][ + 0:48_000 + ] + ) + + [0.0] * 48_000 + + list( + generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][48_000:] + ) + ), + id="empty_space_is_filled", + ), ], indirect=["audio_files"], ) diff --git a/tests/test_item.py b/tests/test_item.py index 7c478f58..97cc9011 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -60,7 +60,7 @@ ), ], ) -def test_concatenate_item(item_list: list[ItemBase], expected: list[ItemBase]) -> None: +def test_remove_overlaps(item_list: list[ItemBase], expected: list[ItemBase]) -> None: cleaned_items = ItemBase.remove_overlaps(item_list) assert cleaned_items == expected From 04ffbfe4762f1d6d40c0c829b226ee9acca79320 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 17:46:38 +0100 Subject: [PATCH 027/118] add test out_of_range AudioData --- tests/test_audio.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/test_audio.py b/tests/test_audio.py index 0aa6fb3e..a4948593 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -278,15 +278,29 @@ def test_audio_item( list( generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][ 0:48_000 - ] + ], ) + [0.0] * 48_000 + list( - generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][48_000:] - ) + generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][48_000:], + ), ), id="empty_space_is_filled", ), + pytest.param( + { + "duration": 1, + "inter_file_duration": 1, + "sample_rate": 48_000, + "nb_files": 2, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + pd.Timestamp("2024-01-01 10:00:00"), + pd.Timestamp("2024-01-01 10:00:01"), + np.zeros(shape=48_000), + id="out_of_range_is_zeros", + ), ], indirect=["audio_files"], ) @@ -301,4 +315,6 @@ def test_audio_data( AudioFile(file, strptime_format=TIMESTAMP_FORMAT_TEST_FILES) for file in files ] data = AudioData.from_files(audio_files, begin=start, end=stop) + if all(item.is_empty for item in data.items): + data.sample_rate = 48_000 assert np.array_equal(data.get_value(), expected) From 1cfd18fe68bf67327a24996bd55f833e98053490 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 17:59:30 +0100 Subject: [PATCH 028/118] add tests for resampling sample count --- tests/test_audio.py | 83 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/tests/test_audio.py b/tests/test_audio.py index a4948593..041f7781 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -282,7 +282,9 @@ def test_audio_item( ) + [0.0] * 48_000 + list( - generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][48_000:], + generate_sample_audio(nb_files=1, nb_samples=48_000 * 2)[0][ + 48_000: + ], ), ), id="empty_space_is_filled", @@ -318,3 +320,82 @@ def test_audio_data( if all(item.is_empty for item in data.items): data.sample_rate = 48_000 assert np.array_equal(data.get_value(), expected) + + +@pytest.mark.parametrize( + ("audio_files", "start", "stop", "sample_rate", "expected_nb_samples"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + 24_000, + 24_000, + id="downsampling", + ), + pytest.param( + { + "duration": 0.5, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + 96_000, + 48_000, + id="upsampling", + ), + pytest.param( + { + "duration": 2, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + pd.Timestamp("2024-01-01 12:00:01"), + None, + 96_000, + 96_000, + id="upsampling_file_part", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 2, + "inter_file_duration": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + 16_000, + 48_000, + id="downsampling_with_gaps", + ), + ], + indirect=["audio_files"], +) +def test_audio_resample_sample_count( + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + start: pd.Timestamp | None, + stop: pd.Timestamp | None, + sample_rate: int, + expected_nb_samples: int, +) -> None: + files, request = audio_files + audio_files = [ + AudioFile(file, strptime_format=TIMESTAMP_FORMAT_TEST_FILES) for file in files + ] + data = AudioData.from_files(audio_files, begin=start, end=stop) + data.sample_rate = sample_rate + assert data.get_value().shape[0] == expected_nb_samples From 8ab61691ca13f9aeeedb019c2390bb03c65e52b6 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 9 Dec 2024 18:20:00 +0100 Subject: [PATCH 029/118] add docstrings --- src/OSmOSE/data/audio_data.py | 27 +++++++++++++++++++++------ src/OSmOSE/data/audio_item.py | 6 ++++-- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 67be7869..ac9fcdfd 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -2,14 +2,15 @@ from __future__ import annotations +from pathlib import Path + import numpy as np +import soundfile as sf from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES from OSmOSE.data.audio_item import AudioItem from OSmOSE.data.data_base import DataBase from OSmOSE.utils.audio_utils import resample -import soundfile as sf -from pathlib import Path class AudioData(DataBase): @@ -32,16 +33,18 @@ def __init__(self, items: list[AudioItem], sample_rate: int | None = None) -> No """ super().__init__(items) - self._check_sample_rates(sample_rate=sample_rate) + self._set_sample_rate(sample_rate=sample_rate) @property def nb_channels(self) -> int: + """Number of channels of the audio data.""" return max( [1] + [item.nb_channels for item in self.items if type(item) is AudioItem], ) @property - def shape(self): + def shape(self) -> tuple[int, ...]: + """Shape of the audio data.""" data_length = int(self.sample_rate * self.total_seconds) return data_length if self.nb_channels <= 1 else (data_length, self.nb_channels) @@ -49,7 +52,14 @@ def __str__(self) -> str: """Overwrite __str__.""" return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES) - def _check_sample_rates(self, sample_rate: int | None = None) -> None: + def _set_sample_rate(self, sample_rate: int | None = None) -> None: + """Set the AudioFile sample rate. + + If the sample_rate is specified, it is set. + If it is not specified, it is set to the sampling rate of the + first item that has one. + Else, it is set to None. + """ if sample_rate is not None or any( sample_rate := item.sample_rate for item in self.items @@ -59,7 +69,11 @@ def _check_sample_rates(self, sample_rate: int | None = None) -> None: else: self.sample_rate = None - def get_value(self): + def get_value(self) -> np.ndarray: + """Return the value of the audio data. + + The data from the audio file will be resampled if necessary. + """ data = np.empty(shape=self.shape) idx = 0 for item in self.items: @@ -80,6 +94,7 @@ def write(self, folder: Path) -> None: sf.write(folder / f"{self}.wav", self.get_value(), self.sample_rate) def _get_item_value(self, item: AudioItem) -> np.ndarray: + """Return the resampled (if needed) data from the audio item.""" item_data = item.get_value() if item.is_empty: return item_data.repeat(int(item.total_seconds * self.sample_rate)) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 7707578b..ab1d0e58 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -38,9 +38,11 @@ def __init__( super().__init__(file, begin, end) @property - def sample_rate(self): + def sample_rate(self) -> float: + """Sample rate of the associated AudioFile.""" return None if self.is_empty else self.file.metadata.samplerate @property - def nb_channels(self): + def nb_channels(self) -> int: + """Number of channels in the associated AudioFile.""" return 0 if self.is_empty else self.file.metadata.channels From b178a4ca098488b15f1670255c5149ec82adbb0b Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 11 Dec 2024 11:19:00 +0100 Subject: [PATCH 030/118] use generic types --- src/OSmOSE/data/audio_data.py | 2 +- src/OSmOSE/data/audio_item.py | 5 ++--- src/OSmOSE/data/data_base.py | 8 +++++--- src/OSmOSE/data/item_base.py | 19 ++++++++++++------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index ac9fcdfd..7cfca515 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -13,7 +13,7 @@ from OSmOSE.utils.audio_utils import resample -class AudioData(DataBase): +class AudioData(DataBase[AudioItem]): """AudioData encapsulating to a collection of AudioItem objects. The audio data can be retrieved from several Files through the Items. diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index ab1d0e58..c31bec59 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -4,15 +4,14 @@ from typing import TYPE_CHECKING +from OSmOSE.data.audio_file import AudioFile from OSmOSE.data.item_base import ItemBase if TYPE_CHECKING: from pandas import Timestamp - from OSmOSE.data.audio_file import AudioFile - -class AudioItem(ItemBase): +class AudioItem(ItemBase[AudioFile]): """AudioItem corresponding to a portion of an AudioFile object.""" def __init__( diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index ee8dd39d..6a0d0f38 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np @@ -17,8 +17,10 @@ from OSmOSE.data.file_base import FileBase +TItem = TypeVar("TItem", bound=ItemBase) -class DataBase: + +class DataBase(Generic[TItem]): """Base class for Data objects. A Data object is a collection of Item objects. @@ -27,7 +29,7 @@ class DataBase: item_cls = ItemBase - def __init__(self, items: list[ItemBase]) -> None: + def __init__(self, items: list[TItem]) -> None: """Initialize an DataBase from a list of Items. Parameters diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/item_base.py index 2f8f5c1d..3538e6b5 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/item_base.py @@ -6,15 +6,20 @@ from __future__ import annotations import copy +from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np -from pandas import Timestamp from OSmOSE.data.file_base import FileBase from OSmOSE.utils.timestamp_utils import is_overlapping +if TYPE_CHECKING: + from pandas import Timestamp -class ItemBase: +TFile = TypeVar("TFile", bound=FileBase) + + +class ItemBase(Generic[TFile]): """Base class for the Item objects (e.g. AudioItem). An Item correspond to a portion of a File object. @@ -22,7 +27,7 @@ class ItemBase: def __init__( self, - file: FileBase | None = None, + file: TFile | None = None, begin: Timestamp | None = None, end: Timestamp | None = None, ) -> None: @@ -73,7 +78,7 @@ def total_seconds(self) -> float: """Return the total duration of the item in seconds.""" return (self.end - self.begin).total_seconds() - def __eq__(self, other: ItemBase) -> bool: + def __eq__(self, other: ItemBase[TFile]) -> bool: """Override the default implementation.""" if not isinstance(other, ItemBase): return False @@ -84,7 +89,7 @@ def __eq__(self, other: ItemBase) -> bool: return not self.end != other.end @staticmethod - def remove_overlaps(items: list[ItemBase]) -> list[ItemBase]: + def remove_overlaps(items: list[ItemBase[TFile]]) -> list[ItemBase[TFile]]: """Resolve overlaps between Items. If two Items overlap within the sequence (that is if one Item begins before the end of another, @@ -116,7 +121,7 @@ def remove_overlaps(items: list[ItemBase]) -> list[ItemBase]: [copy.copy(item) for item in items], key=lambda item: (item.begin, item.begin - item.end), ) - concatenated_items: list[ItemBase] = [] + concatenated_items = [] for item in items: concatenated_items.append(item) overlapping_items = [ @@ -141,7 +146,7 @@ def remove_overlaps(items: list[ItemBase]) -> list[ItemBase]: return concatenated_items @staticmethod - def fill_gaps(items: list[ItemBase]) -> list[ItemBase]: + def fill_gaps(items: list[ItemBase[TFile]]) -> list[ItemBase[TFile]]: """Return a list with empty items added in the gaps between items. Parameters From 9603d78ffde5d191a8b43e27b21773a65d1d01ec Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 11 Dec 2024 17:41:45 +0100 Subject: [PATCH 031/118] fix relations between generic and derived classes --- src/OSmOSE/data/audio_data.py | 45 +++++++++++++++++++-- src/OSmOSE/data/audio_dataset.py | 0 src/OSmOSE/data/data_base.py | 67 ++++++++++++++++++++++---------- src/OSmOSE/data/dataset_base.py | 0 4 files changed, 88 insertions(+), 24 deletions(-) create mode 100644 src/OSmOSE/data/audio_dataset.py create mode 100644 src/OSmOSE/data/dataset_base.py diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 7cfca515..22dabcac 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -2,25 +2,29 @@ from __future__ import annotations -from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import soundfile as sf from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES +from OSmOSE.data.audio_file import AudioFile from OSmOSE.data.audio_item import AudioItem from OSmOSE.data.data_base import DataBase from OSmOSE.utils.audio_utils import resample +if TYPE_CHECKING: + from pathlib import Path -class AudioData(DataBase[AudioItem]): + from pandas import Timestamp + + +class AudioData(DataBase[AudioItem, AudioFile]): """AudioData encapsulating to a collection of AudioItem objects. The audio data can be retrieved from several Files through the Items. """ - item_cls = AudioItem - def __init__(self, items: list[AudioItem], sample_rate: int | None = None) -> None: """Initialize an AudioData from a list of AudioItems. @@ -101,3 +105,36 @@ def _get_item_value(self, item: AudioItem) -> np.ndarray: if item.sample_rate != self.sample_rate: return resample(item_data, item.sample_rate, self.sample_rate) return item_data + + @classmethod + def from_files( + cls, + files: list[AudioFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> AudioData: + """Return an AudioData object from a list of AudioFiles. + + Parameters + ---------- + files: list[AudioFile] + List of AudioFiles containing the data. + begin: Timestamp | None + Begin of the data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the data object. + Defaulted to the end of the last file. + + Returns + ------- + DataBase[AudioItem, AudioFile]: + The AudioData object. + + """ + items_base = DataBase.items_from_files(files, begin, end) + audio_items = [ + AudioItem(file=item.file, begin=item.begin, end=item.end) + for item in items_base + ] + return cls(audio_items) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py new file mode 100644 index 00000000..e69de29b diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index 6a0d0f38..4ecc1787 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -9,26 +9,25 @@ import numpy as np +from OSmOSE.data.file_base import FileBase from OSmOSE.data.item_base import ItemBase from OSmOSE.utils.timestamp_utils import is_overlapping if TYPE_CHECKING: from pandas import Timestamp - from OSmOSE.data.file_base import FileBase TItem = TypeVar("TItem", bound=ItemBase) +TFile = TypeVar("TFile", bound=FileBase) -class DataBase(Generic[TItem]): +class DataBase(Generic[TItem, TFile]): """Base class for Data objects. A Data object is a collection of Item objects. Data can be retrieved from several Files through the Items. """ - item_cls = ItemBase - def __init__(self, items: list[TItem]) -> None: """Initialize an DataBase from a list of Items. @@ -54,48 +53,76 @@ def get_value(self) -> np.ndarray: @classmethod def from_files( cls, - files: list[FileBase], + files: list[TFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> DataBase[TItem, TFile]: + """Return a base DataBase object from a list of Files. + + Parameters + ---------- + files: list[TFile] + List of Files containing the data. + begin: Timestamp | None + Begin of the data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the data object. + Defaulted to the end of the last file. + + Returns + ------- + DataBase[TItem, TFile]: + The DataBase object. + + """ + items = cls.items_from_files(files, begin, end) + return cls(items) + + @classmethod + def items_from_files( + cls, + files: list[TFile], begin: Timestamp | None = None, end: Timestamp | None = None, - ) -> DataBase: - """Initialize a DataBase from a single File. + ) -> list[ItemBase]: + """Return a list of Items from a list of Files and timestamps. - The resulting Data object will contain a single Item. - This single Item will correspond to the whole File. + The Items range from begin to end. + They point to the files that match their timestamps. Parameters ---------- - files: list[OSmOSE.data.file_base.FileBase] + files: list[TFile] The Files encapsulated in the Data object. begin: pandas.Timestamp | None The begin of the Data object. defaulted to the begin of the first File. end: pandas.Timestamp | None The end of the Data object. - default to the end of the last File. + defaulted to the end of the last File. Returns ------- - OSmOSE.data.data_base.DataBase - The Data object. + list[ItemBase] + The list of Items that point to the files. """ begin = min(file.begin for file in files) if begin is None else begin end = max(file.end for file in files) if end is None else end - overlapping_files = [ + included_files = [ file for file in files if is_overlapping((file.begin, file.end), (begin, end)) ] - items = [cls.item_cls(file, begin, end) for file in overlapping_files] + items = [ItemBase(file, begin, end) for file in included_files] if not items: - items.append(cls.item_cls(begin=begin, end=end)) + items.append(ItemBase(begin=begin, end=end)) if (first_item := sorted(items, key=lambda item: item.begin)[0]).begin > begin: - items.append(cls.item_cls(begin=begin, end=first_item.begin)) + items.append(ItemBase(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: - items.append(cls.item_cls(begin=last_item.end, end=end)) + items.append(ItemBase(begin=last_item.end, end=end)) items = ItemBase.remove_overlaps(items) - items = ItemBase.fill_gaps(items) - return cls(items=items) + return ItemBase.fill_gaps(items) diff --git a/src/OSmOSE/data/dataset_base.py b/src/OSmOSE/data/dataset_base.py new file mode 100644 index 00000000..e69de29b From 796ea73842c2ec8e1834acd2968c880861e6c997 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 11 Dec 2024 17:49:33 +0100 Subject: [PATCH 032/118] add dataset classes --- src/OSmOSE/data/audio_dataset.py | 33 ++++++++++++++++++++++++++++ src/OSmOSE/data/dataset_base.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index e69de29b..b4717860 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from pathlib import Path + +from pandas import Timedelta, Timestamp + +from OSmOSE.data.audio_data import AudioData +from OSmOSE.data.audio_file import AudioFile +from OSmOSE.data.dataset_base import DatasetBase + + +class AudioDataset(DatasetBase[AudioData, AudioFile]): + def __init__(self, data: list[AudioData]): + super().__init__(data) + + @classmethod + def from_folder( + cls, + folder: Path, + begin: Timestamp, + end: Timestamp, + data_duration: Timedelta, + strptime_format: str, + ) -> AudioDataset: + files = [ + AudioFile(file, strptime_format=strptime_format) + for file in folder.glob("*.wav") + ] + data_base = DatasetBase.data_from_files(files, begin, end, data_duration) + audio_data = [ + AudioData.from_files(files, data.begin, data.end) for data in data_base + ] + return cls(audio_data) diff --git a/src/OSmOSE/data/dataset_base.py b/src/OSmOSE/data/dataset_base.py index e69de29b..b7f48e94 100644 --- a/src/OSmOSE/data/dataset_base.py +++ b/src/OSmOSE/data/dataset_base.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + +from pandas import Timedelta, Timestamp, date_range + +from OSmOSE.data.data_base import DataBase +from OSmOSE.data.file_base import FileBase + +TData = TypeVar("TData", bound=DataBase) +TFile = TypeVar("TFile", bound=FileBase) + + +class DatasetBase(Generic[TData, TFile]): + def __init__(self, data: list[TData]): + self.data = data + + @property + def begin(self): + return min(data.begin for data in self.data) + + @property + def end(self): + return max(data.end for data in self.data) + + @classmethod + def data_from_files( + cls, + files: list[TFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + data_duration: Timedelta | None = None, + ) -> list[DataBase]: + return [ + DataBase.from_files(files, begin=b, end=b + data_duration) + for b in date_range(begin, end, freq=data_duration)[:-1] + ] From d4677b64207f662cba007bd5b6d0e32b7b7c89dd Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 12 Dec 2024 11:58:06 +0100 Subject: [PATCH 033/118] add docstrings in dataset_base --- src/OSmOSE/data/dataset_base.py | 42 ++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/OSmOSE/data/dataset_base.py b/src/OSmOSE/data/dataset_base.py index b7f48e94..e0c8c035 100644 --- a/src/OSmOSE/data/dataset_base.py +++ b/src/OSmOSE/data/dataset_base.py @@ -12,15 +12,24 @@ class DatasetBase(Generic[TData, TFile]): - def __init__(self, data: list[TData]): + """Base class for Dataset objects. + + Datasets are collections of Data, with methods + that simplify repeated operations on the data. + """ + + def __init__(self, data: list[TData]) -> None: + """Instantiate a Dataset object from the Data objects.""" self.data = data @property - def begin(self): + def begin(self) -> Timestamp: + """Begin of the first data object.""" return min(data.begin for data in self.data) @property - def end(self): + def end(self) -> Timestamp: + """End of the last data object.""" return max(data.end for data in self.data) @classmethod @@ -31,6 +40,33 @@ def data_from_files( end: Timestamp | None = None, data_duration: Timedelta | None = None, ) -> list[DataBase]: + """Return a list of DataBase objects from File objects. + + These DataBase are linked to the file through ItemBase objects. + Specialized Dataset classes can use these DataBase objects parameters to + instantiate specialized Data objects. + + Parameters + ---------- + files: list[TFile] + The list of files from which the Data objects are built. + begin: Timestamp | None + Begin of the first data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the last data object. + Defaulted to the end of the last file. + data_duration: Timedelta | None + Duration of the data objects. + If provided, data will be evenly distributed between begin and end. + Else, one data object will cover the whole period. + + Returns + ------- + list[DataBase]: + A list of DataBase objects. + + """ return [ DataBase.from_files(files, begin=b, end=b + data_duration) for b in date_range(begin, end, freq=data_duration)[:-1] From dae7b210d0dabf4a43e8c1b981b99307507a0f10 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 12 Dec 2024 17:57:46 +0100 Subject: [PATCH 034/118] add specialized constructors from generics --- src/OSmOSE/data/audio_data.py | 8 ++++++++ src/OSmOSE/data/audio_dataset.py | 11 ++++++----- src/OSmOSE/data/audio_file.py | 4 ++++ src/OSmOSE/data/audio_item.py | 12 ++++++++++++ src/OSmOSE/data/dataset_base.py | 14 ++++++++++++++ 5 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 22dabcac..e3db5f3a 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -112,6 +112,7 @@ def from_files( files: list[AudioFile], begin: Timestamp | None = None, end: Timestamp | None = None, + sample_rate: float | None = None, ) -> AudioData: """Return an AudioData object from a list of AudioFiles. @@ -132,9 +133,16 @@ def from_files( The AudioData object. """ + return cls.from_base_data(DataBase.from_files(files, begin, end), sample_rate) items_base = DataBase.items_from_files(files, begin, end) audio_items = [ AudioItem(file=item.file, begin=item.begin, end=item.end) for item in items_base ] return cls(audio_items) + + @classmethod + def from_base_data( + cls, data: DataBase, sample_rate: float | None = None + ) -> AudioData: + return cls([AudioItem.from_base_item(item) for item in data.items], sample_rate) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index b4717860..84ecbe13 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -26,8 +26,9 @@ def from_folder( AudioFile(file, strptime_format=strptime_format) for file in folder.glob("*.wav") ] - data_base = DatasetBase.data_from_files(files, begin, end, data_duration) - audio_data = [ - AudioData.from_files(files, data.begin, data.end) for data in data_base - ] - return cls(audio_data) + base_dataset = DatasetBase.from_files(files, begin, end, data_duration) + return cls.from_base_dataset(base_dataset) + + @classmethod + def from_base_dataset(cls, base_dataset: DatasetBase) -> AudioDataset: + return cls([AudioData.from_base_data(data) for data in base_dataset.data]) diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index 0162ddcc..d689a8d9 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -65,3 +65,7 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: start_sample = round((start - self.begin).total_seconds() * sample_rate) stop_sample = round((stop - self.begin).total_seconds() * sample_rate) return sf.read(self.path, start=start_sample, stop=stop_sample)[0] + + @classmethod + def from_base_file(cls, file: FileBase) -> AudioFile: + return cls(path=file.path, begin=file.begin) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index c31bec59..a223dfd9 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from OSmOSE.data.audio_file import AudioFile +from OSmOSE.data.file_base import FileBase from OSmOSE.data.item_base import ItemBase if TYPE_CHECKING: @@ -45,3 +46,14 @@ def sample_rate(self) -> float: def nb_channels(self) -> int: """Number of channels in the associated AudioFile.""" return 0 if self.is_empty else self.file.metadata.channels + + @classmethod + def from_base_item(cls, item: ItemBase) -> AudioItem: + file = item.file + if not file or isinstance(file, AudioFile): + return cls(file=file, begin=item.begin, end=item.end) + if isinstance(file, FileBase): + return cls( + file=AudioFile.from_base_file(file), begin=item.begin, end=item.end + ) + raise TypeError diff --git a/src/OSmOSE/data/dataset_base.py b/src/OSmOSE/data/dataset_base.py index e0c8c035..b3d848e4 100644 --- a/src/OSmOSE/data/dataset_base.py +++ b/src/OSmOSE/data/dataset_base.py @@ -71,3 +71,17 @@ def data_from_files( DataBase.from_files(files, begin=b, end=b + data_duration) for b in date_range(begin, end, freq=data_duration)[:-1] ] + + @classmethod + def from_files( + cls, + files: list[TFile], + begin: Timestamp, + end: Timestamp, + data_duration: Timedelta, + ) -> DatasetBase: + data_base = [ + DataBase.from_files(files, begin=b, end=b + data_duration) + for b in date_range(begin, end, freq=data_duration)[:-1] + ] + return cls(data_base) From bea4aa4ba37ce2cd009e6fc0946323c1eed48c1e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 13 Dec 2024 10:16:09 +0100 Subject: [PATCH 035/118] consider default dataset attributes --- src/OSmOSE/data/audio_dataset.py | 6 ++--- src/OSmOSE/data/dataset_base.py | 46 +++++++++++++------------------- 2 files changed, 21 insertions(+), 31 deletions(-) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index 84ecbe13..e04d73e1 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -17,10 +17,10 @@ def __init__(self, data: list[AudioData]): def from_folder( cls, folder: Path, - begin: Timestamp, - end: Timestamp, - data_duration: Timedelta, strptime_format: str, + begin: Timestamp | None = None, + end: Timestamp | None = None, + data_duration: Timedelta | None = None, ) -> AudioDataset: files = [ AudioFile(file, strptime_format=strptime_format) diff --git a/src/OSmOSE/data/dataset_base.py b/src/OSmOSE/data/dataset_base.py index b3d848e4..91d220e6 100644 --- a/src/OSmOSE/data/dataset_base.py +++ b/src/OSmOSE/data/dataset_base.py @@ -33,23 +33,19 @@ def end(self) -> Timestamp: return max(data.end for data in self.data) @classmethod - def data_from_files( + def from_files( cls, files: list[TFile], begin: Timestamp | None = None, end: Timestamp | None = None, data_duration: Timedelta | None = None, - ) -> list[DataBase]: - """Return a list of DataBase objects from File objects. - - These DataBase are linked to the file through ItemBase objects. - Specialized Dataset classes can use these DataBase objects parameters to - instantiate specialized Data objects. + ) -> DatasetBase: + """Return a base DatasetBase object from a list of Files. Parameters ---------- files: list[TFile] - The list of files from which the Data objects are built. + The list of files contained in the Dataset. begin: Timestamp | None Begin of the first data object. Defaulted to the begin of the first file. @@ -59,29 +55,23 @@ def data_from_files( data_duration: Timedelta | None Duration of the data objects. If provided, data will be evenly distributed between begin and end. - Else, one data object will cover the whole period. + Else, one data object will cover the whole time period. Returns ------- - list[DataBase]: - A list of DataBase objects. + DataBase[TItem, TFile]: + The DataBase object. """ - return [ - DataBase.from_files(files, begin=b, end=b + data_duration) - for b in date_range(begin, end, freq=data_duration)[:-1] - ] - - @classmethod - def from_files( - cls, - files: list[TFile], - begin: Timestamp, - end: Timestamp, - data_duration: Timedelta, - ) -> DatasetBase: - data_base = [ - DataBase.from_files(files, begin=b, end=b + data_duration) - for b in date_range(begin, end, freq=data_duration)[:-1] - ] + if not begin: + begin = min(file.begin for file in files) + if not end: + end = max(file.end for file in files) + if data_duration: + data_base = [ + DataBase.from_files(files, begin=b, end=b + data_duration) + for b in date_range(begin, end, freq=data_duration)[:-1] + ] + else: + data_base = [DataBase.from_files(files, begin=begin, end=end)] return cls(data_base) From bafa2c378c35a0f1146ee5298e01a7878d0dce87 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 13 Dec 2024 10:39:57 +0100 Subject: [PATCH 036/118] update docstrings --- src/OSmOSE/data/audio_data.py | 38 +++++++++++++++++------- src/OSmOSE/data/audio_dataset.py | 51 +++++++++++++++++++++++++++++--- src/OSmOSE/data/audio_file.py | 1 + src/OSmOSE/data/audio_item.py | 5 +++- src/OSmOSE/data/data_base.py | 3 +- src/OSmOSE/data/dataset_base.py | 6 ++++ 6 files changed, 88 insertions(+), 16 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index e3db5f3a..47016caa 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -1,4 +1,8 @@ -"""AudioData encapsulating to a collection of AudioItem objects.""" +"""AudioData represent audio data scattered through different AudioFiles. + +The AudioData has a collection of AudioItem. +The data is accessed via an AudioItem object per AudioFile. +""" from __future__ import annotations @@ -20,9 +24,10 @@ class AudioData(DataBase[AudioItem, AudioFile]): - """AudioData encapsulating to a collection of AudioItem objects. + """AudioData represent audio data scattered through different AudioFiles. - The audio data can be retrieved from several Files through the Items. + The AudioData has a collection of AudioItem. + The data is accessed via an AudioItem object per AudioFile. """ def __init__(self, items: list[AudioItem], sample_rate: int | None = None) -> None: @@ -126,6 +131,8 @@ def from_files( end: Timestamp | None End of the data object. Defaulted to the end of the last file. + sample_rate: float | None + Sample rate of the AudioData. Returns ------- @@ -134,15 +141,26 @@ def from_files( """ return cls.from_base_data(DataBase.from_files(files, begin, end), sample_rate) - items_base = DataBase.items_from_files(files, begin, end) - audio_items = [ - AudioItem(file=item.file, begin=item.begin, end=item.end) - for item in items_base - ] - return cls(audio_items) @classmethod def from_base_data( - cls, data: DataBase, sample_rate: float | None = None + cls, + data: DataBase, + sample_rate: float | None = None, ) -> AudioData: + """Return an AudioData object from a DataBase object. + + Parameters + ---------- + data: DataBase + DataBase object to convert to AudioData. + sample_rate: float | None + Sample rate of the AudioData. + + Returns + ------- + AudioData: + The AudioData object. + + """ return cls([AudioItem.from_base_item(item) for item in data.items], sample_rate) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index e04d73e1..b20e344c 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -1,16 +1,33 @@ -from __future__ import annotations +"""AudioDataset is a collection of AudioData objects. + +AudioDataset is a collection of AudioData, with methods +that simplify repeated operations on the audio data. +""" -from pathlib import Path +from __future__ import annotations -from pandas import Timedelta, Timestamp +from typing import TYPE_CHECKING from OSmOSE.data.audio_data import AudioData from OSmOSE.data.audio_file import AudioFile from OSmOSE.data.dataset_base import DatasetBase +if TYPE_CHECKING: + from pathlib import Path + + from pandas import Timedelta, Timestamp + class AudioDataset(DatasetBase[AudioData, AudioFile]): - def __init__(self, data: list[AudioData]): + """AudioDataset is a collection of AudioData objects. + + AudioDataset is a collection of AudioData, with methods + that simplify repeated operations on the audio data. + + """ + + def __init__(self, data: list[AudioData]) -> None: + """Initialize an AudioDataset.""" super().__init__(data) @classmethod @@ -22,6 +39,31 @@ def from_folder( end: Timestamp | None = None, data_duration: Timedelta | None = None, ) -> AudioDataset: + """Return an AudioDataset from a folder containing the audio files. + + Parameters + ---------- + folder: Path + The folder containing the audio files. + strptime_format: str + The strptime format of the timestamps in the audio file names. + begin: Timestamp | None + The begin of the audio dataset. + Defaulted to the begin of the first file. + end: Timestamp | None + The end of the audio dataset. + Defaulted to the end of the last file. + data_duration: Timedelta | None + Duration of the audio data objects. + If provided, audio data will be evenly distributed between begin and end. + Else, one data object will cover the whole time period. + + Returns + ------- + Audiodataset: + The audio dataset. + + """ files = [ AudioFile(file, strptime_format=strptime_format) for file in folder.glob("*.wav") @@ -31,4 +73,5 @@ def from_folder( @classmethod def from_base_dataset(cls, base_dataset: DatasetBase) -> AudioDataset: + """Return an AudioDataset object from a DatasetBase object.""" return cls([AudioData.from_base_data(data) for data in base_dataset.data]) diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index d689a8d9..ac970e4d 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -68,4 +68,5 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: @classmethod def from_base_file(cls, file: FileBase) -> AudioFile: + """Return an AudioFile object from a FileBase object.""" return cls(path=file.path, begin=file.begin) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index a223dfd9..05fd1fad 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -49,11 +49,14 @@ def nb_channels(self) -> int: @classmethod def from_base_item(cls, item: ItemBase) -> AudioItem: + """Return an AudioItem object from an ItemBase object.""" file = item.file if not file or isinstance(file, AudioFile): return cls(file=file, begin=item.begin, end=item.end) if isinstance(file, FileBase): return cls( - file=AudioFile.from_base_file(file), begin=item.begin, end=item.end + file=AudioFile.from_base_file(file), + begin=item.begin, + end=item.end, ) raise TypeError diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index 4ecc1787..278a46db 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -1,6 +1,7 @@ """DataBase: Base class for the Data objects (e.g. AudioData). -Data corresponds to a collection of Items. +Data corresponds to data scattered through different AudioFiles. +The data is accessed via an AudioItem object per AudioFile. """ from __future__ import annotations diff --git a/src/OSmOSE/data/dataset_base.py b/src/OSmOSE/data/dataset_base.py index 91d220e6..3f9af366 100644 --- a/src/OSmOSE/data/dataset_base.py +++ b/src/OSmOSE/data/dataset_base.py @@ -1,3 +1,9 @@ +"""DatasetBase: Base class for the Dataset objects (e.g. AudioDataset). + +Datasets are collections of Data, with methods +that simplify repeated operations on the data. +""" + from __future__ import annotations from typing import Generic, TypeVar From 59badbe7098c57448dfabf89c32d79b0d9110f0d Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 13 Dec 2024 11:39:12 +0100 Subject: [PATCH 037/118] add audiodataset sample_rate property --- src/OSmOSE/data/audio_dataset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index b20e344c..af63044d 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -30,6 +30,16 @@ def __init__(self, data: list[AudioData]) -> None: """Initialize an AudioDataset.""" super().__init__(data) + @property + def sample_rate(self) -> set[float]: + """Return the sample rate of the audio data.""" + return {data.sample_rate for data in self.data} + + @sample_rate.setter + def sample_rate(self, sample_rate: float) -> None: + for data in self.data: + data.sample_rate = sample_rate + @classmethod def from_folder( cls, From 6d95f2cb94c29c97b2ac823e6ca67b89ca16e1f8 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 13 Dec 2024 11:52:19 +0100 Subject: [PATCH 038/118] add is_empty property to data base --- src/OSmOSE/data/data_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/data_base.py index 278a46db..ad1e8abc 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/data_base.py @@ -47,6 +47,11 @@ def total_seconds(self) -> float: """Return the total duration of the data in seconds.""" return (self.end - self.begin).total_seconds() + @property + def is_empty(self) -> bool: + """Return true if every item of this data object is empty.""" + return all(item.is_empty for item in self.items) + def get_value(self) -> np.ndarray: """Get the concatenated values from all Items.""" return np.concatenate([item.get_value() for item in self.items]) From 8f928e3e0d0f9188d96d5540885041041a7d1d88 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 13 Dec 2024 12:10:31 +0100 Subject: [PATCH 039/118] rename SthgBase to BaseSthg --- src/OSmOSE/data/audio_data.py | 18 ++-- src/OSmOSE/data/audio_dataset.py | 10 +- src/OSmOSE/data/audio_file.py | 8 +- src/OSmOSE/data/audio_item.py | 12 +-- .../data/{data_base.py => base_data.py} | 48 +++++----- .../data/{dataset_base.py => base_dataset.py} | 22 ++--- .../data/{file_base.py => base_file.py} | 6 +- .../data/{item_base.py => base_item.py} | 40 ++++---- tests/test_item.py | 94 +++++++++---------- 9 files changed, 129 insertions(+), 129 deletions(-) rename src/OSmOSE/data/{data_base.py => base_data.py} (73%) rename src/OSmOSE/data/{dataset_base.py => base_dataset.py} (79%) rename src/OSmOSE/data/{file_base.py => base_file.py} (94%) rename src/OSmOSE/data/{item_base.py => base_item.py} (81%) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 47016caa..9752e162 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -14,7 +14,7 @@ from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES from OSmOSE.data.audio_file import AudioFile from OSmOSE.data.audio_item import AudioItem -from OSmOSE.data.data_base import DataBase +from OSmOSE.data.base_data import BaseData from OSmOSE.utils.audio_utils import resample if TYPE_CHECKING: @@ -23,7 +23,7 @@ from pandas import Timestamp -class AudioData(DataBase[AudioItem, AudioFile]): +class AudioData(BaseData[AudioItem, AudioFile]): """AudioData represent audio data scattered through different AudioFiles. The AudioData has a collection of AudioItem. @@ -136,24 +136,24 @@ def from_files( Returns ------- - DataBase[AudioItem, AudioFile]: - The AudioData object. + AudioData: + The AudioData object. """ - return cls.from_base_data(DataBase.from_files(files, begin, end), sample_rate) + return cls.from_base_data(BaseData.from_files(files, begin, end), sample_rate) @classmethod def from_base_data( cls, - data: DataBase, + data: BaseData, sample_rate: float | None = None, ) -> AudioData: - """Return an AudioData object from a DataBase object. + """Return an AudioData object from a BaseData object. Parameters ---------- - data: DataBase - DataBase object to convert to AudioData. + data: BaseData + BaseData object to convert to AudioData. sample_rate: float | None Sample rate of the AudioData. diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index af63044d..939e5582 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -10,7 +10,7 @@ from OSmOSE.data.audio_data import AudioData from OSmOSE.data.audio_file import AudioFile -from OSmOSE.data.dataset_base import DatasetBase +from OSmOSE.data.base_dataset import BaseDataset if TYPE_CHECKING: from pathlib import Path @@ -18,7 +18,7 @@ from pandas import Timedelta, Timestamp -class AudioDataset(DatasetBase[AudioData, AudioFile]): +class AudioDataset(BaseDataset[AudioData, AudioFile]): """AudioDataset is a collection of AudioData objects. AudioDataset is a collection of AudioData, with methods @@ -78,10 +78,10 @@ def from_folder( AudioFile(file, strptime_format=strptime_format) for file in folder.glob("*.wav") ] - base_dataset = DatasetBase.from_files(files, begin, end, data_duration) + base_dataset = BaseDataset.from_files(files, begin, end, data_duration) return cls.from_base_dataset(base_dataset) @classmethod - def from_base_dataset(cls, base_dataset: DatasetBase) -> AudioDataset: - """Return an AudioDataset object from a DatasetBase object.""" + def from_base_dataset(cls, base_dataset: BaseDataset) -> AudioDataset: + """Return an AudioDataset object from a BaseDataset object.""" return cls([AudioData.from_base_data(data) for data in base_dataset.data]) diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index ac970e4d..65a9d28e 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -11,10 +11,10 @@ import soundfile as sf from pandas import Timedelta, Timestamp -from OSmOSE.data.file_base import FileBase +from OSmOSE.data.base_file import BaseFile -class AudioFile(FileBase): +class AudioFile(BaseFile): """Audio file associated with timestamps.""" def __init__( @@ -67,6 +67,6 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: return sf.read(self.path, start=start_sample, stop=stop_sample)[0] @classmethod - def from_base_file(cls, file: FileBase) -> AudioFile: - """Return an AudioFile object from a FileBase object.""" + def from_base_file(cls, file: BaseFile) -> AudioFile: + """Return an AudioFile object from a BaseFile object.""" return cls(path=file.path, begin=file.begin) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 05fd1fad..c730d930 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -5,14 +5,14 @@ from typing import TYPE_CHECKING from OSmOSE.data.audio_file import AudioFile -from OSmOSE.data.file_base import FileBase -from OSmOSE.data.item_base import ItemBase +from OSmOSE.data.base_file import BaseFile +from OSmOSE.data.base_item import BaseItem if TYPE_CHECKING: from pandas import Timestamp -class AudioItem(ItemBase[AudioFile]): +class AudioItem(BaseItem[AudioFile]): """AudioItem corresponding to a portion of an AudioFile object.""" def __init__( @@ -48,12 +48,12 @@ def nb_channels(self) -> int: return 0 if self.is_empty else self.file.metadata.channels @classmethod - def from_base_item(cls, item: ItemBase) -> AudioItem: - """Return an AudioItem object from an ItemBase object.""" + def from_base_item(cls, item: BaseItem) -> AudioItem: + """Return an AudioItem object from an BaseItem object.""" file = item.file if not file or isinstance(file, AudioFile): return cls(file=file, begin=item.begin, end=item.end) - if isinstance(file, FileBase): + if isinstance(file, BaseFile): return cls( file=AudioFile.from_base_file(file), begin=item.begin, diff --git a/src/OSmOSE/data/data_base.py b/src/OSmOSE/data/base_data.py similarity index 73% rename from src/OSmOSE/data/data_base.py rename to src/OSmOSE/data/base_data.py index ad1e8abc..ebd464a9 100644 --- a/src/OSmOSE/data/data_base.py +++ b/src/OSmOSE/data/base_data.py @@ -1,7 +1,7 @@ -"""DataBase: Base class for the Data objects (e.g. AudioData). +"""BaseData: Base class for the Data objects. -Data corresponds to data scattered through different AudioFiles. -The data is accessed via an AudioItem object per AudioFile. +Data corresponds to data scattered through different Files. +The data is accessed via an Item object per File. """ from __future__ import annotations @@ -10,31 +10,31 @@ import numpy as np -from OSmOSE.data.file_base import FileBase -from OSmOSE.data.item_base import ItemBase +from OSmOSE.data.base_file import BaseFile +from OSmOSE.data.base_item import BaseItem from OSmOSE.utils.timestamp_utils import is_overlapping if TYPE_CHECKING: from pandas import Timestamp -TItem = TypeVar("TItem", bound=ItemBase) -TFile = TypeVar("TFile", bound=FileBase) +TItem = TypeVar("TItem", bound=BaseItem) +TFile = TypeVar("TFile", bound=BaseFile) -class DataBase(Generic[TItem, TFile]): - """Base class for Data objects. +class BaseData(Generic[TItem, TFile]): + """Base class for the Data objects. - A Data object is a collection of Item objects. - Data can be retrieved from several Files through the Items. + Data corresponds to data scattered through different Files. + The data is accessed via an Item object per File. """ def __init__(self, items: list[TItem]) -> None: - """Initialize an DataBase from a list of Items. + """Initialize an BaseData from a list of Items. Parameters ---------- - items: list[ItemBase] + items: list[BaseItem] List of the Items constituting the Data. """ @@ -62,7 +62,7 @@ def from_files( files: list[TFile], begin: Timestamp | None = None, end: Timestamp | None = None, - ) -> DataBase[TItem, TFile]: + ) -> BaseData[TItem, TFile]: """Return a base DataBase object from a list of Files. Parameters @@ -78,8 +78,8 @@ def from_files( Returns ------- - DataBase[TItem, TFile]: - The DataBase object. + BaseData[TItem, TFile]: + The BaseData object. """ items = cls.items_from_files(files, begin, end) @@ -91,7 +91,7 @@ def items_from_files( files: list[TFile], begin: Timestamp | None = None, end: Timestamp | None = None, - ) -> list[ItemBase]: + ) -> list[BaseItem]: """Return a list of Items from a list of Files and timestamps. The Items range from begin to end. @@ -110,7 +110,7 @@ def items_from_files( Returns ------- - list[ItemBase] + list[BaseItem] The list of Items that point to the files. """ @@ -123,12 +123,12 @@ def items_from_files( if is_overlapping((file.begin, file.end), (begin, end)) ] - items = [ItemBase(file, begin, end) for file in included_files] + items = [BaseItem(file, begin, end) for file in included_files] if not items: - items.append(ItemBase(begin=begin, end=end)) + items.append(BaseItem(begin=begin, end=end)) if (first_item := sorted(items, key=lambda item: item.begin)[0]).begin > begin: - items.append(ItemBase(begin=begin, end=first_item.begin)) + items.append(BaseItem(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: - items.append(ItemBase(begin=last_item.end, end=end)) - items = ItemBase.remove_overlaps(items) - return ItemBase.fill_gaps(items) + items.append(BaseItem(begin=last_item.end, end=end)) + items = BaseItem.remove_overlaps(items) + return BaseItem.fill_gaps(items) diff --git a/src/OSmOSE/data/dataset_base.py b/src/OSmOSE/data/base_dataset.py similarity index 79% rename from src/OSmOSE/data/dataset_base.py rename to src/OSmOSE/data/base_dataset.py index 3f9af366..c3bea23e 100644 --- a/src/OSmOSE/data/dataset_base.py +++ b/src/OSmOSE/data/base_dataset.py @@ -1,4 +1,4 @@ -"""DatasetBase: Base class for the Dataset objects (e.g. AudioDataset). +"""BaseDataset: Base class for the Dataset objects. Datasets are collections of Data, with methods that simplify repeated operations on the data. @@ -10,14 +10,14 @@ from pandas import Timedelta, Timestamp, date_range -from OSmOSE.data.data_base import DataBase -from OSmOSE.data.file_base import FileBase +from OSmOSE.data.base_data import BaseData +from OSmOSE.data.base_file import BaseFile -TData = TypeVar("TData", bound=DataBase) -TFile = TypeVar("TFile", bound=FileBase) +TData = TypeVar("TData", bound=BaseData) +TFile = TypeVar("TFile", bound=BaseFile) -class DatasetBase(Generic[TData, TFile]): +class BaseDataset(Generic[TData, TFile]): """Base class for Dataset objects. Datasets are collections of Data, with methods @@ -45,8 +45,8 @@ def from_files( begin: Timestamp | None = None, end: Timestamp | None = None, data_duration: Timedelta | None = None, - ) -> DatasetBase: - """Return a base DatasetBase object from a list of Files. + ) -> BaseDataset: + """Return a base BaseDataset object from a list of Files. Parameters ---------- @@ -65,7 +65,7 @@ def from_files( Returns ------- - DataBase[TItem, TFile]: + BaseDataset[TItem, TFile]: The DataBase object. """ @@ -75,9 +75,9 @@ def from_files( end = max(file.end for file in files) if data_duration: data_base = [ - DataBase.from_files(files, begin=b, end=b + data_duration) + BaseData.from_files(files, begin=b, end=b + data_duration) for b in date_range(begin, end, freq=data_duration)[:-1] ] else: - data_base = [DataBase.from_files(files, begin=begin, end=end)] + data_base = [BaseData.from_files(files, begin=begin, end=end)] return cls(data_base) diff --git a/src/OSmOSE/data/file_base.py b/src/OSmOSE/data/base_file.py similarity index 94% rename from src/OSmOSE/data/file_base.py rename to src/OSmOSE/data/base_file.py index 65f2c190..9025ef62 100644 --- a/src/OSmOSE/data/file_base.py +++ b/src/OSmOSE/data/base_file.py @@ -1,4 +1,4 @@ -"""FileBase: Base class for the File objects (e.g. AudioFile). +"""BaseFile: Base class for the File objects. A File object associates file-written data to timestamps. """ @@ -18,8 +18,8 @@ from OSmOSE.utils.timestamp_utils import strptime_from_text -class FileBase: - """Base class for the File objects (e.g. AudioFile). +class BaseFile: + """Base class for the File objects. A File object associates file-written data to timestamps. """ diff --git a/src/OSmOSE/data/item_base.py b/src/OSmOSE/data/base_item.py similarity index 81% rename from src/OSmOSE/data/item_base.py rename to src/OSmOSE/data/base_item.py index 3538e6b5..8f5078c4 100644 --- a/src/OSmOSE/data/item_base.py +++ b/src/OSmOSE/data/base_item.py @@ -1,4 +1,4 @@ -"""ItemBase: Base class for the Item objects (e.g. AudioItem). +"""BaseItem: Base class for the Item objects. Items correspond to a portion of a File object. """ @@ -10,17 +10,17 @@ import numpy as np -from OSmOSE.data.file_base import FileBase +from OSmOSE.data.base_file import BaseFile from OSmOSE.utils.timestamp_utils import is_overlapping if TYPE_CHECKING: from pandas import Timestamp -TFile = TypeVar("TFile", bound=FileBase) +TFile = TypeVar("TFile", bound=BaseFile) -class ItemBase(Generic[TFile]): - """Base class for the Item objects (e.g. AudioItem). +class BaseItem(Generic[TFile]): + """Base class for the Item objects. An Item correspond to a portion of a File object. """ @@ -31,11 +31,11 @@ def __init__( begin: Timestamp | None = None, end: Timestamp | None = None, ) -> None: - """Initialize an ItemBase from a File and begin/end timestamps. + """Initialize an BaseItem from a File and begin/end timestamps. Parameters ---------- - file: OSmOSE.data.file_base.FileBase + file: TFile The File in which this Item belongs. begin: pandas.Timestamp (optional) The timestamp at which this item begins. @@ -78,9 +78,9 @@ def total_seconds(self) -> float: """Return the total duration of the item in seconds.""" return (self.end - self.begin).total_seconds() - def __eq__(self, other: ItemBase[TFile]) -> bool: + def __eq__(self, other: BaseItem[TFile]) -> bool: """Override the default implementation.""" - if not isinstance(other, ItemBase): + if not isinstance(other, BaseItem): return False if self.file != other.file: return False @@ -89,7 +89,7 @@ def __eq__(self, other: ItemBase[TFile]) -> bool: return not self.end != other.end @staticmethod - def remove_overlaps(items: list[ItemBase[TFile]]) -> list[ItemBase[TFile]]: + def remove_overlaps(items: list[BaseItem[TFile]]) -> list[BaseItem[TFile]]: """Resolve overlaps between Items. If two Items overlap within the sequence (that is if one Item begins before the end of another, @@ -99,20 +99,20 @@ def remove_overlaps(items: list[ItemBase[TFile]]) -> list[ItemBase[TFile]]: Parameters ---------- - items: list[ItemBase] + items: list[BaseItem] List of Items to concatenate. Returns ------- - list[ItemBase]: + list[BaseItem]: The list of Items with no overlapping Items. Examples -------- - >>> items = [ItemBase(begin = Timestamp("00:00:00"), end = Timestamp("00:00:15")), ItemBase(begin = Timestamp("00:00:10"), end = Timestamp("00:00:20"))] + >>> items = [BaseItem(begin = Timestamp("00:00:00"), end = Timestamp("00:00:15")), BaseItem(begin = Timestamp("00:00:10"), end = Timestamp("00:00:20"))] >>> items[0].end == items[1].begin False - >>> items = ItemBase.remove_overlaps(items) + >>> items = BaseItem.remove_overlaps(items) >>> items[0].end == items[1].begin True @@ -146,23 +146,23 @@ def remove_overlaps(items: list[ItemBase[TFile]]) -> list[ItemBase[TFile]]: return concatenated_items @staticmethod - def fill_gaps(items: list[ItemBase[TFile]]) -> list[ItemBase[TFile]]: + def fill_gaps(items: list[BaseItem[TFile]]) -> list[BaseItem[TFile]]: """Return a list with empty items added in the gaps between items. Parameters ---------- - items: list[ItemBase] + items: list[BaseItem] List of Items to fill. Returns ------- - list[ItemBase]: + list[BaseItem]: List of Items with no gaps. Examples -------- - >>> items = [ItemBase(begin = Timestamp("00:00:00"), end = Timestamp("00:00:10")), ItemBase(begin = Timestamp("00:00:15"), end = Timestamp("00:00:25"))] - >>> items = ItemBase.fill_gaps(items) + >>> items = [BaseItem(begin = Timestamp("00:00:00"), end = Timestamp("00:00:10")), BaseItem(begin = Timestamp("00:00:15"), end = Timestamp("00:00:25"))] + >>> items = BaseItem.fill_gaps(items) >>> [(item.begin.second, item.end.second) for item in items] [(0, 10), (10, 15), (15, 25)] @@ -173,6 +173,6 @@ def fill_gaps(items: list[ItemBase[TFile]]) -> list[ItemBase[TFile]]: next_item = items[index + 1] filled_item_list.append(item) if next_item.begin > item.end: - filled_item_list.append(ItemBase(begin=item.end, end=next_item.begin)) + filled_item_list.append(BaseItem(begin=item.end, end=next_item.begin)) filled_item_list.append(items[-1]) return filled_item_list diff --git a/tests/test_item.py b/tests/test_item.py index 97cc9011..fb464dc8 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -3,65 +3,65 @@ import pytest from pandas import Timestamp -from OSmOSE.data.item_base import ItemBase +from OSmOSE.data.base_item import BaseItem @pytest.mark.parametrize( ("item_list", "expected"), [ pytest.param( - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], id="only_one_item_is_unchanged", ), pytest.param( [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), ], - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], id="doubled_item_is_removed", ), pytest.param( [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), ], [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), ], id="overlapping_item_is_truncated", ), pytest.param( [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:15")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), ], [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), ], id="longest_item_is_prioritized", ), pytest.param( [ - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:25")), - ItemBase(begin=Timestamp("00:00:0"), end=Timestamp("00:00:15")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:25")), + BaseItem(begin=Timestamp("00:00:0"), end=Timestamp("00:00:15")), + BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), ], [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), ], id="items_are_reordered", ), ], ) -def test_remove_overlaps(item_list: list[ItemBase], expected: list[ItemBase]) -> None: - cleaned_items = ItemBase.remove_overlaps(item_list) +def test_remove_overlaps(item_list: list[BaseItem], expected: list[BaseItem]) -> None: + cleaned_items = BaseItem.remove_overlaps(item_list) assert cleaned_items == expected @@ -69,53 +69,53 @@ def test_remove_overlaps(item_list: list[ItemBase], expected: list[ItemBase]) -> ("item_list", "expected"), [ pytest.param( - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], - [ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], id="only_one_item_is_unchanged", ), pytest.param( [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), ], [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), ], id="consecutive_items_are_unchanged", ), pytest.param( [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), ], [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), ], id="one_gap_is_filled", ), pytest.param( [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), - ItemBase(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), - ItemBase(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + BaseItem(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), + BaseItem(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), ], [ - ItemBase(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ItemBase(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ItemBase(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), - ItemBase(begin=Timestamp("00:00:30"), end=Timestamp("00:00:35")), - ItemBase(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), - ItemBase(begin=Timestamp("00:00:45"), end=Timestamp("00:01:00")), - ItemBase(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), + BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + BaseItem(begin=Timestamp("00:00:30"), end=Timestamp("00:00:35")), + BaseItem(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), + BaseItem(begin=Timestamp("00:00:45"), end=Timestamp("00:01:00")), + BaseItem(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), ], id="multiple_gaps_are_filled", ), ], ) -def test_fill_item_gaps(item_list: list[ItemBase], expected: list[ItemBase]) -> None: - filled_items = ItemBase.fill_gaps(item_list) +def test_fill_item_gaps(item_list: list[BaseItem], expected: list[BaseItem]) -> None: + filled_items = BaseItem.fill_gaps(item_list) assert filled_items == expected From d9db7c60d417593ee2642ead37df72595593014d Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 13 Dec 2024 12:22:29 +0100 Subject: [PATCH 040/118] add dataset write method --- src/OSmOSE/data/base_data.py | 5 +++++ src/OSmOSE/data/base_dataset.py | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index ebd464a9..c6c9fc48 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -6,6 +6,7 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np @@ -56,6 +57,10 @@ def get_value(self) -> np.ndarray: """Get the concatenated values from all Items.""" return np.concatenate([item.get_value() for item in self.items]) + def write(self, path: Path) -> None: + """Abstract method for writing the data.""" + return + @classmethod def from_files( cls, diff --git a/src/OSmOSE/data/base_dataset.py b/src/OSmOSE/data/base_dataset.py index c3bea23e..60b1d914 100644 --- a/src/OSmOSE/data/base_dataset.py +++ b/src/OSmOSE/data/base_dataset.py @@ -6,6 +6,7 @@ from __future__ import annotations +from pathlib import Path from typing import Generic, TypeVar from pandas import Timedelta, Timestamp, date_range @@ -38,6 +39,18 @@ def end(self) -> Timestamp: """End of the last data object.""" return max(data.end for data in self.data) + def write(self, folder: Path) -> None: + """Write all data objects in the specified folder. + + Parameters + ---------- + folder: Path + Folder in which to write the data. + + """ + for data in self.data: + data.write(folder) + @classmethod def from_files( cls, From d5c8908d8170f04b584f9fdeadaf7adfa0595698 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 10:15:12 +0100 Subject: [PATCH 041/118] set sample_rate of empty data in dataset --- src/OSmOSE/data/audio_dataset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index 939e5582..95990b54 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -6,6 +6,7 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING from OSmOSE.data.audio_data import AudioData @@ -29,6 +30,18 @@ class AudioDataset(BaseDataset[AudioData, AudioFile]): def __init__(self, data: list[AudioData]) -> None: """Initialize an AudioDataset.""" super().__init__(data) + if ( + len( + sample_rates := { + data.sample_rate for data in data if data.sample_rate is not None + } + ) + != 1 + ): + logging.warning("Audio dataset contains different sample rates.") + else: + for empty_data in (data for data in data if data.sample_rate is None): + empty_data.sample_rate = min(sample_rates) @property def sample_rate(self) -> set[float]: From a2cdc4e63df14dceb1be38783be54b842aca9ea4 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 11:49:13 +0100 Subject: [PATCH 042/118] move remove_overlaps to util functions --- src/OSmOSE/data/audio_dataset.py | 4 +- src/OSmOSE/data/base_item.py | 58 -------------------------- src/OSmOSE/utils/timestamp_utils.py | 64 ++++++++++++++++++++++++++++- tests/test_item.py | 3 +- 4 files changed, 67 insertions(+), 62 deletions(-) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index 95990b54..2e2680bd 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -29,12 +29,11 @@ class AudioDataset(BaseDataset[AudioData, AudioFile]): def __init__(self, data: list[AudioData]) -> None: """Initialize an AudioDataset.""" - super().__init__(data) if ( len( sample_rates := { data.sample_rate for data in data if data.sample_rate is not None - } + }, ) != 1 ): @@ -42,6 +41,7 @@ def __init__(self, data: list[AudioData]) -> None: else: for empty_data in (data for data in data if data.sample_rate is None): empty_data.sample_rate = min(sample_rates) + super().__init__(data) @property def sample_rate(self) -> set[float]: diff --git a/src/OSmOSE/data/base_item.py b/src/OSmOSE/data/base_item.py index 8f5078c4..cb9d79e6 100644 --- a/src/OSmOSE/data/base_item.py +++ b/src/OSmOSE/data/base_item.py @@ -11,7 +11,6 @@ import numpy as np from OSmOSE.data.base_file import BaseFile -from OSmOSE.utils.timestamp_utils import is_overlapping if TYPE_CHECKING: from pandas import Timestamp @@ -88,63 +87,6 @@ def __eq__(self, other: BaseItem[TFile]) -> bool: return False return not self.end != other.end - @staticmethod - def remove_overlaps(items: list[BaseItem[TFile]]) -> list[BaseItem[TFile]]: - """Resolve overlaps between Items. - - If two Items overlap within the sequence (that is if one Item begins before the end of another, - the earliest Item's end is set to the begin of the latest Item. - If multiple items overlap with one earlier Item, only one is chosen as next. - The chosen next Item is the one that ends the latest. - - Parameters - ---------- - items: list[BaseItem] - List of Items to concatenate. - - Returns - ------- - list[BaseItem]: - The list of Items with no overlapping Items. - - Examples - -------- - >>> items = [BaseItem(begin = Timestamp("00:00:00"), end = Timestamp("00:00:15")), BaseItem(begin = Timestamp("00:00:10"), end = Timestamp("00:00:20"))] - >>> items[0].end == items[1].begin - False - >>> items = BaseItem.remove_overlaps(items) - >>> items[0].end == items[1].begin - True - - """ - items = sorted( - [copy.copy(item) for item in items], - key=lambda item: (item.begin, item.begin - item.end), - ) - concatenated_items = [] - for item in items: - concatenated_items.append(item) - overlapping_items = [ - item2 - for item2 in items - if item2 is not item - and is_overlapping((item.begin, item.end), (item2.begin, item2.end)) - ] - if not overlapping_items: - continue - kept_overlapping_item = max(overlapping_items, key=lambda item: item.end) - if kept_overlapping_item.end > item.end: - item.end = kept_overlapping_item.begin - else: - kept_overlapping_item = None - for dismissed_item in ( - item2 - for item2 in overlapping_items - if item2 is not kept_overlapping_item - ): - items.remove(dismissed_item) - return concatenated_items - @staticmethod def fill_gaps(items: list[BaseItem[TFile]]) -> list[BaseItem[TFile]]: """Return a list with empty items added in the gaps between items. diff --git a/src/OSmOSE/utils/timestamp_utils.py b/src/OSmOSE/utils/timestamp_utils.py index fe4015b2..a989085b 100644 --- a/src/OSmOSE/utils/timestamp_utils.py +++ b/src/OSmOSE/utils/timestamp_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations # Backwards compatibility with Python < 3.10 +import copy import os import re from datetime import datetime, timedelta @@ -417,7 +418,7 @@ def is_osmose_format_timestamp(timestamp: str) -> bool: def is_overlapping( - event1: tuple[Timestamp, Timestamp], event2: tuple[Timestamp, Timestamp] + event1: tuple[Timestamp, Timestamp], event2: tuple[Timestamp, Timestamp], ) -> bool: """Return True if the two events are overlapping, False otherwise. @@ -443,6 +444,67 @@ def is_overlapping( """ return event1[0] < event2[1] and event1[1] > event2[0] +def remove_overlaps(items: list) -> list: + """Resolve overlaps between objects that have begin and end attributes. + + If two objects overlap within the sequence + (that is if one object begins before the end of another), + the earliest object's end is set to the begin of the latest object. + If multiple objects overlap with one earlier object, only one is chosen as next. + The chosen next object is the one that ends the latest. + + Parameters + ---------- + items: list + List of objects to concatenate. + + Returns + ------- + list: + The list of objects with no overlap. + + Examples + -------- + >>> from dataclasses import dataclass + >>> @dataclass + ... class Item: + ... begin: Timestamp + ... end: Timestamp + >>> items = [Item(begin=Timestamp("00:00:00"),end=Timestamp("00:00:15")), Item(begin=Timestamp("00:00:10"),end=Timestamp("00:00:20"))] + >>> items[0].end == items[1].begin + False + >>> items = remove_overlaps(items) + >>> items[0].end == items[1].begin + True + + """ + items = sorted( + [copy.copy(item) for item in items], + key=lambda item: (item.begin, item.begin - item.end), + ) + concatenated_items = [] + for item in items: + concatenated_items.append(item) + overlapping_items = [ + item2 + for item2 in items + if item2 is not item + and is_overlapping((item.begin, item.end), (item2.begin, item2.end)) + ] + if not overlapping_items: + continue + kept_overlapping_item = max(overlapping_items, key=lambda item: item.end) + if kept_overlapping_item.end > item.end: + item.end = kept_overlapping_item.begin + else: + kept_overlapping_item = None + for dismissed_item in ( + item2 + for item2 in overlapping_items + if item2 is not kept_overlapping_item + ): + items.remove(dismissed_item) + return concatenated_items def get_timestamps( path_osmose_dataset: str, diff --git a/tests/test_item.py b/tests/test_item.py index fb464dc8..b6ac8c15 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -4,6 +4,7 @@ from pandas import Timestamp from OSmOSE.data.base_item import BaseItem +from OSmOSE.utils.timestamp_utils import remove_overlaps @pytest.mark.parametrize( @@ -61,7 +62,7 @@ ], ) def test_remove_overlaps(item_list: list[BaseItem], expected: list[BaseItem]) -> None: - cleaned_items = BaseItem.remove_overlaps(item_list) + cleaned_items = remove_overlaps(item_list) assert cleaned_items == expected From 447d8929534d3a3939775ddf5dff770fc3bbc07f Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 12:07:18 +0100 Subject: [PATCH 043/118] --amend --- src/OSmOSE/data/base_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index c6c9fc48..f9f7602f 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -13,7 +13,7 @@ from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem -from OSmOSE.utils.timestamp_utils import is_overlapping +from OSmOSE.utils.timestamp_utils import is_overlapping, remove_overlaps if TYPE_CHECKING: from pandas import Timestamp @@ -135,5 +135,5 @@ def items_from_files( items.append(BaseItem(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: items.append(BaseItem(begin=last_item.end, end=end)) - items = BaseItem.remove_overlaps(items) + items = remove_overlaps(items) return BaseItem.fill_gaps(items) From ac1d2e27ff1e48beb24a27dbe88766d720db8031 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 12:08:13 +0100 Subject: [PATCH 044/118] --amend --- src/OSmOSE/utils/timestamp_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/OSmOSE/utils/timestamp_utils.py b/src/OSmOSE/utils/timestamp_utils.py index a989085b..a5eec194 100644 --- a/src/OSmOSE/utils/timestamp_utils.py +++ b/src/OSmOSE/utils/timestamp_utils.py @@ -418,7 +418,8 @@ def is_osmose_format_timestamp(timestamp: str) -> bool: def is_overlapping( - event1: tuple[Timestamp, Timestamp], event2: tuple[Timestamp, Timestamp], + event1: tuple[Timestamp, Timestamp], + event2: tuple[Timestamp, Timestamp], ) -> bool: """Return True if the two events are overlapping, False otherwise. @@ -444,6 +445,7 @@ def is_overlapping( """ return event1[0] < event2[1] and event1[1] > event2[0] + def remove_overlaps(items: list) -> list: """Resolve overlaps between objects that have begin and end attributes. @@ -489,7 +491,7 @@ def remove_overlaps(items: list) -> list: item2 for item2 in items if item2 is not item - and is_overlapping((item.begin, item.end), (item2.begin, item2.end)) + and is_overlapping((item.begin, item.end), (item2.begin, item2.end)) ] if not overlapping_items: continue @@ -499,13 +501,12 @@ def remove_overlaps(items: list) -> list: else: kept_overlapping_item = None for dismissed_item in ( - item2 - for item2 in overlapping_items - if item2 is not kept_overlapping_item + item2 for item2 in overlapping_items if item2 is not kept_overlapping_item ): items.remove(dismissed_item) return concatenated_items + def get_timestamps( path_osmose_dataset: str, campaign_name: str, From c6b5746ce666b7b902849586df4bdf7785d27039 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 12:08:23 +0100 Subject: [PATCH 045/118] add dataset tests --- pyproject.toml | 1 + src/OSmOSE/config.py | 2 +- src/OSmOSE/data/base_dataset.py | 6 ++- tests/test_audio.py | 94 +++++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a3aef9a..eed7fb70 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,4 +70,5 @@ select = ["ALL"] "D", # Docstring-related stuff "SLF001", # Access to private variables "BLE001", # Blind exceptions + "PLR0913", # Too many arguments in methods ] \ No newline at end of file diff --git a/src/OSmOSE/config.py b/src/OSmOSE/config.py index 2b841134..90080649 100755 --- a/src/OSmOSE/config.py +++ b/src/OSmOSE/config.py @@ -34,7 +34,7 @@ OSMOSE_PATH = namedtuple("path_list", __global_path_dict.keys())(**__global_path_dict) TIMESTAMP_FORMAT_AUDIO_FILE = "%Y-%m-%dT%H:%M:%S.%f%z" -TIMESTAMP_FORMAT_TEST_FILES = "%y%m%d%H%M%S" +TIMESTAMP_FORMAT_TEST_FILES = "%y%m%d%H%M%S%f" TIMESTAMP_FORMAT_EXPORTED_FILES = "%Y_%m_%d_%H_%M_%S" FPDEFAULT = 0o664 # Default file permissions DPDEFAULT = stat.S_ISGID | 0o775 # Default directory permissions diff --git a/src/OSmOSE/data/base_dataset.py b/src/OSmOSE/data/base_dataset.py index 60b1d914..11d13c3c 100644 --- a/src/OSmOSE/data/base_dataset.py +++ b/src/OSmOSE/data/base_dataset.py @@ -6,14 +6,16 @@ from __future__ import annotations -from pathlib import Path -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from pandas import Timedelta, Timestamp, date_range from OSmOSE.data.base_data import BaseData from OSmOSE.data.base_file import BaseFile +if TYPE_CHECKING: + from pathlib import Path + TData = TypeVar("TData", bound=BaseData) TFile = TypeVar("TFile", bound=BaseFile) diff --git a/tests/test_audio.py b/tests/test_audio.py index 041f7781..01b1c9ea 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -8,6 +8,7 @@ from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES from OSmOSE.data.audio_data import AudioData +from OSmOSE.data.audio_dataset import AudioDataset from OSmOSE.data.audio_file import AudioFile from OSmOSE.data.audio_item import AudioItem from OSmOSE.utils.audio_utils import generate_sample_audio @@ -399,3 +400,96 @@ def test_audio_resample_sample_count( data = AudioData.from_files(audio_files, begin=start, end=stop) data.sample_rate = sample_rate assert data.get_value().shape[0] == expected_nb_samples + + +@pytest.mark.parametrize( + ("audio_files", "begin", "end", "duration", "expected_audio_data"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + None, + generate_sample_audio(1, 48_000), + id="one_entire_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 3, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + pd.Timedelta(seconds=1), + generate_sample_audio( + nb_files=3, nb_samples=48_000, series_type="increase" + ), + id="multiple_consecutive_files", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 2, + "inter_file_duration": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + None, + None, + pd.Timedelta(seconds=1), + [ + generate_sample_audio(nb_files=1, nb_samples=96_000)[0][0:48_000], + generate_sample_audio( + nb_files=1, nb_samples=48_000, min_value=0.0, max_value=0.0 + )[0], + generate_sample_audio(nb_files=1, nb_samples=96_000)[0][48_000:], + ], + id="two_separated_files", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 3, + "inter_file_duration": -0.5, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "repeat", + }, + None, + None, + pd.Timedelta(seconds=1), + generate_sample_audio(nb_files=2, nb_samples=48_000), + id="overlapping_files", + ), + ], + indirect=["audio_files"], +) +def test_audio_dataset_from_folder( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + begin: pd.Timestamp | None, + end: pd.Timestamp | None, + duration: pd.Timedelta | None, + expected_audio_data: list[tuple[int, bool]], +) -> None: + dataset = AudioDataset.from_folder( + tmp_path, + strptime_format=TIMESTAMP_FORMAT_TEST_FILES, + begin=begin, + end=end, + data_duration=duration, + ) + assert all( + np.array_equal(data.get_value(), expected) + for (data, expected) in zip(dataset.data, expected_audio_data) + ) From 1956e06a5ae9fb1efbf7fa4ea8eec8892a1c5c4e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 12:43:15 +0100 Subject: [PATCH 046/118] add overlap utils as data_utils --- src/OSmOSE/data/base_data.py | 4 +- src/OSmOSE/utils/data_utils.py | 107 ++++++++++++++++++++++++++++ src/OSmOSE/utils/timestamp_utils.py | 91 ----------------------- tests/test_data_utils.py | 101 ++++++++++++++++++++++++++ tests/test_item.py | 2 +- tests/test_timestamp_utils.py | 98 ------------------------- 6 files changed, 211 insertions(+), 192 deletions(-) create mode 100644 src/OSmOSE/utils/data_utils.py create mode 100644 tests/test_data_utils.py diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index f9f7602f..6b9e62e7 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -13,7 +13,7 @@ from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem -from OSmOSE.utils.timestamp_utils import is_overlapping, remove_overlaps +from OSmOSE.utils.data_utils import EventClass, is_overlapping, remove_overlaps if TYPE_CHECKING: from pandas import Timestamp @@ -125,7 +125,7 @@ def items_from_files( included_files = [ file for file in files - if is_overlapping((file.begin, file.end), (begin, end)) + if is_overlapping(file, EventClass(begin=begin, end=end)) ] items = [BaseItem(file, begin, end) for file in included_files] diff --git a/src/OSmOSE/utils/data_utils.py b/src/OSmOSE/utils/data_utils.py new file mode 100644 index 00000000..44fc9cac --- /dev/null +++ b/src/OSmOSE/utils/data_utils.py @@ -0,0 +1,107 @@ +import copy +from dataclasses import dataclass +from typing import Protocol + +from pandas import Timestamp + + +class Event(Protocol): + begin: Timestamp + end: Timestamp + + +@dataclass +class EventClass: + begin: Timestamp + end: Timestamp + + +def is_overlapping( + event1: Event, + event2: Event, +) -> bool: + """Return True if the two events are overlapping, False otherwise. + + Events are objects that have begin and end Timestamp attributes. + + Parameters + ---------- + event1: Event + The first event. + event2: Event + The second event. + + Returns + ------- + bool: + True if the two events are overlapping, False otherwise. + + Examples + -------- + >>> is_overlapping(EventClass(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), EventClass(begin=Timestamp("2024-01-01 12:00:00"),end=Timestamp("2024-01-02 12:00:00"))) + True + >>> is_overlapping(EventClass(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), EventClass(begin=Timestamp("2024-01-02 00:00:00"),end=Timestamp("2024-01-02 12:00:00"))) + False + + """ + return event1.begin < event2.end and event1.end > event2.begin + + +def remove_overlaps(items: list) -> list: + """Resolve overlaps between objects that have begin and end attributes. + + If two objects overlap within the sequence + (that is if one object begins before the end of another), + the earliest object's end is set to the begin of the latest object. + If multiple objects overlap with one earlier object, only one is chosen as next. + The chosen next object is the one that ends the latest. + + Parameters + ---------- + items: list + List of objects to concatenate. + + Returns + ------- + list: + The list of objects with no overlap. + + Examples + -------- + >>> from dataclasses import dataclass + >>> @dataclass + ... class Item: + ... begin: Timestamp + ... end: Timestamp + >>> items = [Item(begin=Timestamp("00:00:00"),end=Timestamp("00:00:15")), Item(begin=Timestamp("00:00:10"),end=Timestamp("00:00:20"))] + >>> items[0].end == items[1].begin + False + >>> items = remove_overlaps(items) + >>> items[0].end == items[1].begin + True + + """ + items = sorted( + [copy.copy(item) for item in items], + key=lambda item: (item.begin, item.begin - item.end), + ) + concatenated_items = [] + for item in items: + concatenated_items.append(item) + overlapping_items = [ + item2 + for item2 in items + if item2 is not item and is_overlapping(item, item2) + ] + if not overlapping_items: + continue + kept_overlapping_item = max(overlapping_items, key=lambda item: item.end) + if kept_overlapping_item.end > item.end: + item.end = kept_overlapping_item.begin + else: + kept_overlapping_item = None + for dismissed_item in ( + item2 for item2 in overlapping_items if item2 is not kept_overlapping_item + ): + items.remove(dismissed_item) + return concatenated_items diff --git a/src/OSmOSE/utils/timestamp_utils.py b/src/OSmOSE/utils/timestamp_utils.py index a5eec194..03178cbf 100644 --- a/src/OSmOSE/utils/timestamp_utils.py +++ b/src/OSmOSE/utils/timestamp_utils.py @@ -2,7 +2,6 @@ from __future__ import annotations # Backwards compatibility with Python < 3.10 -import copy import os import re from datetime import datetime, timedelta @@ -417,96 +416,6 @@ def is_osmose_format_timestamp(timestamp: str) -> bool: return True -def is_overlapping( - event1: tuple[Timestamp, Timestamp], - event2: tuple[Timestamp, Timestamp], -) -> bool: - """Return True if the two events are overlapping, False otherwise. - - Parameters - ---------- - event1: tuple[pandas.Timestamp, pandas.Timestamp] - The first event. - event2: tuple[pandas.Timestamp, pandas.Timestamp] - The second event. - - Returns - ------- - bool: - True if the two events are overlapping, False otherwise. - - Examples - -------- - >>> is_overlapping((Timestamp("2024-01-01 00:00:00"),(Timestamp("2024-01-02 00:00:00"))), (Timestamp("2024-01-01 12:00:00"),(Timestamp("2024-01-02 12:00:00")))) - True - >>> is_overlapping((Timestamp("2024-01-01 00:00:00"),(Timestamp("2024-01-02 00:00:00"))), (Timestamp("2024-01-02 00:00:00"),(Timestamp("2024-01-02 12:00:00")))) - False - - """ - return event1[0] < event2[1] and event1[1] > event2[0] - - -def remove_overlaps(items: list) -> list: - """Resolve overlaps between objects that have begin and end attributes. - - If two objects overlap within the sequence - (that is if one object begins before the end of another), - the earliest object's end is set to the begin of the latest object. - If multiple objects overlap with one earlier object, only one is chosen as next. - The chosen next object is the one that ends the latest. - - Parameters - ---------- - items: list - List of objects to concatenate. - - Returns - ------- - list: - The list of objects with no overlap. - - Examples - -------- - >>> from dataclasses import dataclass - >>> @dataclass - ... class Item: - ... begin: Timestamp - ... end: Timestamp - >>> items = [Item(begin=Timestamp("00:00:00"),end=Timestamp("00:00:15")), Item(begin=Timestamp("00:00:10"),end=Timestamp("00:00:20"))] - >>> items[0].end == items[1].begin - False - >>> items = remove_overlaps(items) - >>> items[0].end == items[1].begin - True - - """ - items = sorted( - [copy.copy(item) for item in items], - key=lambda item: (item.begin, item.begin - item.end), - ) - concatenated_items = [] - for item in items: - concatenated_items.append(item) - overlapping_items = [ - item2 - for item2 in items - if item2 is not item - and is_overlapping((item.begin, item.end), (item2.begin, item2.end)) - ] - if not overlapping_items: - continue - kept_overlapping_item = max(overlapping_items, key=lambda item: item.end) - if kept_overlapping_item.end > item.end: - item.end = kept_overlapping_item.begin - else: - kept_overlapping_item = None - for dismissed_item in ( - item2 for item2 in overlapping_items if item2 is not kept_overlapping_item - ): - items.remove(dismissed_item) - return concatenated_items - - def get_timestamps( path_osmose_dataset: str, campaign_name: str, diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py new file mode 100644 index 00000000..50c32d92 --- /dev/null +++ b/tests/test_data_utils.py @@ -0,0 +1,101 @@ +import pytest +from pandas import Timestamp + +from OSmOSE.utils.data_utils import EventClass, is_overlapping + + +@pytest.mark.parametrize( + ("event1", "event2", "expected"), + [ + pytest.param( + EventClass( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + EventClass( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + True, + id="same_event", + ), + pytest.param( + EventClass( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + EventClass( + begin=Timestamp("2024-01-01 12:00:00"), + end=Timestamp("2024-01-02 12:00:00"), + ), + True, + id="overlapping_events", + ), + pytest.param( + EventClass( + begin=Timestamp("2024-01-01 12:00:00"), + end=Timestamp("2024-01-02 12:00:00"), + ), + EventClass( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + True, + id="overlapping_events_reversed", + ), + pytest.param( + EventClass( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + EventClass( + begin=Timestamp("2024-01-01 12:00:00"), + end=Timestamp("2024-01-01 12:01:00"), + ), + True, + id="embedded_events", + ), + pytest.param( + EventClass( + begin=Timestamp("2024-01-01 0:00:00"), + end=Timestamp("2024-01-01 12:00:00"), + ), + EventClass( + begin=Timestamp("2024-01-02 00:00:00"), + end=Timestamp("2024-01-02 12:00:00"), + ), + False, + id="non_overlapping_events", + ), + pytest.param( + EventClass( + begin=Timestamp("2024-01-02 0:00:00"), + end=Timestamp("2024-01-02 12:00:00"), + ), + EventClass( + begin=Timestamp("2024-01-01 00:00:00"), + end=Timestamp("2024-01-01 12:00:00"), + ), + False, + id="non_overlapping_events_reversed", + ), + pytest.param( + EventClass( + begin=Timestamp("2024-01-01 0:00:00"), + end=Timestamp("2024-01-01 12:00:00"), + ), + EventClass( + begin=Timestamp("2024-01-01 12:00:00"), + end=Timestamp("2024-01-02 00:00:00"), + ), + False, + id="border_sharing_isnt_overlapping", + ), + ], +) +def test_overlapping_events( + event1: EventClass, + event2: EventClass, + expected: bool, +) -> None: + assert is_overlapping(event1, event2) == expected diff --git a/tests/test_item.py b/tests/test_item.py index b6ac8c15..1419d2d6 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -4,7 +4,7 @@ from pandas import Timestamp from OSmOSE.data.base_item import BaseItem -from OSmOSE.utils.timestamp_utils import remove_overlaps +from OSmOSE.utils.data_utils import remove_overlaps @pytest.mark.parametrize( diff --git a/tests/test_timestamp_utils.py b/tests/test_timestamp_utils.py index df095ecd..1422afa6 100644 --- a/tests/test_timestamp_utils.py +++ b/tests/test_timestamp_utils.py @@ -12,7 +12,6 @@ associate_timestamps, build_regex_from_datetime_template, is_datetime_template_valid, - is_overlapping, localize_timestamp, parse_timestamps_csv, reformat_timestamp, @@ -1053,100 +1052,3 @@ def test_adapt_timestamp_csv_to_osmose( assert adapt_timestamp_csv_to_osmose(timestamps, date_template, timezone).equals( expected, ) - - -@pytest.mark.parametrize( - ("event1", "event2", "expected"), - [ - pytest.param( - ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), - ), - ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), - ), - True, - id="same_event", - ), - pytest.param( - ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), - ), - ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-02 12:00:00"), - ), - True, - id="overlapping_events", - ), - pytest.param( - ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-02 12:00:00"), - ), - ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), - ), - True, - id="overlapping_events_reversed", - ), - pytest.param( - ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-02 00:00:00"), - ), - ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-01 12:01:00"), - ), - True, - id="embedded_events", - ), - pytest.param( - ( - Timestamp("2024-01-01 0:00:00"), - Timestamp("2024-01-01 12:00:00"), - ), - ( - Timestamp("2024-01-02 00:00:00"), - Timestamp("2024-01-02 12:00:00"), - ), - False, - id="non_overlapping_events", - ), - pytest.param( - ( - Timestamp("2024-01-02 0:00:00"), - Timestamp("2024-01-02 12:00:00"), - ), - ( - Timestamp("2024-01-01 00:00:00"), - Timestamp("2024-01-01 12:00:00"), - ), - False, - id="non_overlapping_events_reversed", - ), - pytest.param( - ( - Timestamp("2024-01-01 0:00:00"), - Timestamp("2024-01-01 12:00:00"), - ), - ( - Timestamp("2024-01-01 12:00:00"), - Timestamp("2024-01-02 00:00:00"), - ), - False, - id="border_sharing_isnt_overlapping", - ), - ], -) -def test_overlapping_events( - event1: tuple[Timestamp, Timestamp], - event2: tuple[Timestamp, Timestamp], - expected: bool, -) -> None: - assert is_overlapping(event1, event2) == expected From 7857ca36131c3828110ce1f90929b8ab9c02632d Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 14:33:20 +0100 Subject: [PATCH 047/118] add Event base class for all data classes --- src/OSmOSE/data/base_data.py | 13 ++- src/OSmOSE/data/base_dataset.py | 3 +- src/OSmOSE/data/base_file.py | 3 +- src/OSmOSE/data/base_item.py | 36 +------- src/OSmOSE/utils/data_utils.py | 137 ++++++++++++++++++++--------- tests/test_data_utils.py | 151 ++++++++++++++++++++++++++++---- tests/test_item.py | 122 -------------------------- 7 files changed, 240 insertions(+), 225 deletions(-) delete mode 100644 tests/test_item.py diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index 6b9e62e7..c7cd0c5c 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -6,16 +6,17 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem -from OSmOSE.utils.data_utils import EventClass, is_overlapping, remove_overlaps +from OSmOSE.utils.data_utils import Event, fill_gaps, is_overlapping, remove_overlaps if TYPE_CHECKING: + from pathlib import Path + from pandas import Timestamp @@ -23,7 +24,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseData(Generic[TItem, TFile]): +class BaseData(Generic[TItem, TFile], Event): """Base class for the Data objects. Data corresponds to data scattered through different Files. @@ -123,9 +124,7 @@ def items_from_files( end = max(file.end for file in files) if end is None else end included_files = [ - file - for file in files - if is_overlapping(file, EventClass(begin=begin, end=end)) + file for file in files if is_overlapping(file, Event(begin=begin, end=end)) ] items = [BaseItem(file, begin, end) for file in included_files] @@ -136,4 +135,4 @@ def items_from_files( if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: items.append(BaseItem(begin=last_item.end, end=end)) items = remove_overlaps(items) - return BaseItem.fill_gaps(items) + return fill_gaps(items, BaseItem) diff --git a/src/OSmOSE/data/base_dataset.py b/src/OSmOSE/data/base_dataset.py index 11d13c3c..97d3e28c 100644 --- a/src/OSmOSE/data/base_dataset.py +++ b/src/OSmOSE/data/base_dataset.py @@ -12,6 +12,7 @@ from OSmOSE.data.base_data import BaseData from OSmOSE.data.base_file import BaseFile +from OSmOSE.utils.data_utils import Event if TYPE_CHECKING: from pathlib import Path @@ -20,7 +21,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseDataset(Generic[TData, TFile]): +class BaseDataset(Generic[TData, TFile], Event): """Base class for Dataset objects. Datasets are collections of Data, with methods diff --git a/src/OSmOSE/data/base_file.py b/src/OSmOSE/data/base_file.py index 9025ef62..59861e93 100644 --- a/src/OSmOSE/data/base_file.py +++ b/src/OSmOSE/data/base_file.py @@ -15,10 +15,11 @@ from pathlib import Path +from OSmOSE.utils.data_utils import Event from OSmOSE.utils.timestamp_utils import strptime_from_text -class BaseFile: +class BaseFile(Event): """Base class for the File objects. A File object associates file-written data to timestamps. diff --git a/src/OSmOSE/data/base_item.py b/src/OSmOSE/data/base_item.py index cb9d79e6..b9ec8001 100644 --- a/src/OSmOSE/data/base_item.py +++ b/src/OSmOSE/data/base_item.py @@ -5,12 +5,12 @@ from __future__ import annotations -import copy from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np from OSmOSE.data.base_file import BaseFile +from OSmOSE.utils.data_utils import Event if TYPE_CHECKING: from pandas import Timestamp @@ -18,7 +18,7 @@ TFile = TypeVar("TFile", bound=BaseFile) -class BaseItem(Generic[TFile]): +class BaseItem(Generic[TFile], Event): """Base class for the Item objects. An Item correspond to a portion of a File object. @@ -86,35 +86,3 @@ def __eq__(self, other: BaseItem[TFile]) -> bool: if self.begin != other.begin: return False return not self.end != other.end - - @staticmethod - def fill_gaps(items: list[BaseItem[TFile]]) -> list[BaseItem[TFile]]: - """Return a list with empty items added in the gaps between items. - - Parameters - ---------- - items: list[BaseItem] - List of Items to fill. - - Returns - ------- - list[BaseItem]: - List of Items with no gaps. - - Examples - -------- - >>> items = [BaseItem(begin = Timestamp("00:00:00"), end = Timestamp("00:00:10")), BaseItem(begin = Timestamp("00:00:15"), end = Timestamp("00:00:25"))] - >>> items = BaseItem.fill_gaps(items) - >>> [(item.begin.second, item.end.second) for item in items] - [(0, 10), (10, 15), (15, 25)] - - """ - items = sorted([copy.copy(item) for item in items], key=lambda item: item.begin) - filled_item_list = [] - for index, item in enumerate(items[:-1]): - next_item = items[index + 1] - filled_item_list.append(item) - if next_item.begin > item.end: - filled_item_list.append(BaseItem(begin=item.end, end=next_item.begin)) - filled_item_list.append(items[-1]) - return filled_item_list diff --git a/src/OSmOSE/utils/data_utils.py b/src/OSmOSE/utils/data_utils.py index 44fc9cac..906fcc0e 100644 --- a/src/OSmOSE/utils/data_utils.py +++ b/src/OSmOSE/utils/data_utils.py @@ -1,24 +1,32 @@ +"""Util classes and functions for data objects.""" + +from __future__ import annotations + import copy from dataclasses import dataclass -from typing import Protocol +from typing import TYPE_CHECKING, TypeVar -from pandas import Timestamp +if TYPE_CHECKING: + from pandas import Timestamp -class Event(Protocol): - begin: Timestamp - end: Timestamp +@dataclass +class Event: + """Dataclass containing begin an end attributes. + Classes that have a begin and an end should inherit from Event. + """ -@dataclass -class EventClass: begin: Timestamp end: Timestamp +TEvent = TypeVar("TEvent", bound=Event) + + def is_overlapping( - event1: Event, - event2: Event, + event1: TEvent | Event, + event2: TEvent | Event, ) -> bool: """Return True if the two events are overlapping, False otherwise. @@ -38,33 +46,33 @@ def is_overlapping( Examples -------- - >>> is_overlapping(EventClass(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), EventClass(begin=Timestamp("2024-01-01 12:00:00"),end=Timestamp("2024-01-02 12:00:00"))) + >>> is_overlapping(Event(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), Event(begin=Timestamp("2024-01-01 12:00:00"),end=Timestamp("2024-01-02 12:00:00"))) True - >>> is_overlapping(EventClass(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), EventClass(begin=Timestamp("2024-01-02 00:00:00"),end=Timestamp("2024-01-02 12:00:00"))) + >>> is_overlapping(Event(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), Event(begin=Timestamp("2024-01-02 00:00:00"),end=Timestamp("2024-01-02 12:00:00"))) False """ return event1.begin < event2.end and event1.end > event2.begin -def remove_overlaps(items: list) -> list: - """Resolve overlaps between objects that have begin and end attributes. +def remove_overlaps(events: list[TEvent]) -> list[TEvent]: + """Resolve overlaps between events. - If two objects overlap within the sequence - (that is if one object begins before the end of another), - the earliest object's end is set to the begin of the latest object. - If multiple objects overlap with one earlier object, only one is chosen as next. - The chosen next object is the one that ends the latest. + If two events overlap within the whole events collection + (that is if one event begins before the end of another event), + the earliest event's end is set to the begin of the latest object. + If multiple events overlap with one earlier event, only one is chosen as next. + The chosen next event is the one that ends the latest. Parameters ---------- - items: list - List of objects to concatenate. + events: list + List of events in which to remove the overlaps. Returns ------- list: - The list of objects with no overlap. + The list of events with no overlap. Examples -------- @@ -74,34 +82,77 @@ def remove_overlaps(items: list) -> list: ... begin: Timestamp ... end: Timestamp >>> items = [Item(begin=Timestamp("00:00:00"),end=Timestamp("00:00:15")), Item(begin=Timestamp("00:00:10"),end=Timestamp("00:00:20"))] - >>> items[0].end == items[1].begin + >>> events[0].end == events[1].begin False - >>> items = remove_overlaps(items) - >>> items[0].end == items[1].begin + >>> items = remove_overlaps(events) + >>> events[0].end == events[1].begin True """ - items = sorted( - [copy.copy(item) for item in items], - key=lambda item: (item.begin, item.begin - item.end), + events = sorted( + [copy.copy(event) for event in events], + key=lambda event: (event.begin, event.begin - event.end), ) - concatenated_items = [] - for item in items: - concatenated_items.append(item) - overlapping_items = [ - item2 - for item2 in items - if item2 is not item and is_overlapping(item, item2) + concatenated_events = [] + for event in events: + concatenated_events.append(event) + overlapping_events = [ + event2 + for event2 in events + if event2 is not event and is_overlapping(event, event2) ] - if not overlapping_items: + if not overlapping_events: continue - kept_overlapping_item = max(overlapping_items, key=lambda item: item.end) - if kept_overlapping_item.end > item.end: - item.end = kept_overlapping_item.begin + kept_overlapping_event = max(overlapping_events, key=lambda item: item.end) + if kept_overlapping_event.end > event.end: + event.end = kept_overlapping_event.begin else: - kept_overlapping_item = None - for dismissed_item in ( - item2 for item2 in overlapping_items if item2 is not kept_overlapping_item + kept_overlapping_event = None + for dismissed_event in ( + event2 + for event2 in overlapping_events + if event2 is not kept_overlapping_event ): - items.remove(dismissed_item) - return concatenated_items + events.remove(dismissed_event) + return concatenated_events + + +def fill_gaps(events: list[TEvent], filling_class: type[TEvent]) -> list[TEvent]: + """Return a list with empty events added in the gaps between items. + + The created empty events are instantiated from the class filling_class. + + Parameters + ---------- + events: list[TEvent] + List of events to fill. + filling_class: type[TEvent] + The class used for instantiating empty events in the gaps. + + Returns + ------- + list[TEvent]: + List of events with no gaps. + + Examples + -------- + >>> events = [Event(begin = Timestamp("00:00:00"), end = Timestamp("00:00:10")), Event(begin = Timestamp("00:00:15"), end = Timestamp("00:00:25"))] + >>> events = fill_gaps(events, Event) + >>> [(event.begin.second, event.end.second) for event in events] + [(0, 10), (10, 15), (15, 25)] + + """ + events = sorted( + [copy.copy(event) for event in events], + key=lambda event: event.begin, + ) + filled_event_list = [] + for index, event in enumerate(events[:-1]): + next_event = events[index + 1] + filled_event_list.append(event) + if next_event.begin > event.end: + filled_event_list.append( + filling_class(begin=event.end, end=next_event.begin), + ) + filled_event_list.append(events[-1]) + return filled_event_list diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 50c32d92..60988361 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -1,18 +1,20 @@ +from __future__ import annotations + import pytest from pandas import Timestamp -from OSmOSE.utils.data_utils import EventClass, is_overlapping +from OSmOSE.utils.data_utils import Event, fill_gaps, is_overlapping, remove_overlaps @pytest.mark.parametrize( ("event1", "event2", "expected"), [ pytest.param( - EventClass( + Event( begin=Timestamp("2024-01-01 00:00:00"), end=Timestamp("2024-01-02 00:00:00"), ), - EventClass( + Event( begin=Timestamp("2024-01-01 00:00:00"), end=Timestamp("2024-01-02 00:00:00"), ), @@ -20,11 +22,11 @@ id="same_event", ), pytest.param( - EventClass( + Event( begin=Timestamp("2024-01-01 00:00:00"), end=Timestamp("2024-01-02 00:00:00"), ), - EventClass( + Event( begin=Timestamp("2024-01-01 12:00:00"), end=Timestamp("2024-01-02 12:00:00"), ), @@ -32,11 +34,11 @@ id="overlapping_events", ), pytest.param( - EventClass( + Event( begin=Timestamp("2024-01-01 12:00:00"), end=Timestamp("2024-01-02 12:00:00"), ), - EventClass( + Event( begin=Timestamp("2024-01-01 00:00:00"), end=Timestamp("2024-01-02 00:00:00"), ), @@ -44,11 +46,11 @@ id="overlapping_events_reversed", ), pytest.param( - EventClass( + Event( begin=Timestamp("2024-01-01 00:00:00"), end=Timestamp("2024-01-02 00:00:00"), ), - EventClass( + Event( begin=Timestamp("2024-01-01 12:00:00"), end=Timestamp("2024-01-01 12:01:00"), ), @@ -56,11 +58,11 @@ id="embedded_events", ), pytest.param( - EventClass( + Event( begin=Timestamp("2024-01-01 0:00:00"), end=Timestamp("2024-01-01 12:00:00"), ), - EventClass( + Event( begin=Timestamp("2024-01-02 00:00:00"), end=Timestamp("2024-01-02 12:00:00"), ), @@ -68,11 +70,11 @@ id="non_overlapping_events", ), pytest.param( - EventClass( + Event( begin=Timestamp("2024-01-02 0:00:00"), end=Timestamp("2024-01-02 12:00:00"), ), - EventClass( + Event( begin=Timestamp("2024-01-01 00:00:00"), end=Timestamp("2024-01-01 12:00:00"), ), @@ -80,11 +82,11 @@ id="non_overlapping_events_reversed", ), pytest.param( - EventClass( + Event( begin=Timestamp("2024-01-01 0:00:00"), end=Timestamp("2024-01-01 12:00:00"), ), - EventClass( + Event( begin=Timestamp("2024-01-01 12:00:00"), end=Timestamp("2024-01-02 00:00:00"), ), @@ -94,8 +96,123 @@ ], ) def test_overlapping_events( - event1: EventClass, - event2: EventClass, + event1: Event, + event2: Event, expected: bool, ) -> None: assert is_overlapping(event1, event2) == expected + + +@pytest.mark.parametrize( + ("events", "expected"), + [ + pytest.param( + [Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + id="only_one_event_is_unchanged", + ), + pytest.param( + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + ], + [Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + id="doubled_event_is_removed", + ), + pytest.param( + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + id="overlapping_event_is_truncated", + ), + pytest.param( + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:15")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + id="longest_event_is_prioritized", + ), + pytest.param( + [ + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:25")), + Event(begin=Timestamp("00:00:0"), end=Timestamp("00:00:15")), + Event(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), + ], + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + Event(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), + ], + id="events_are_reordered", + ), + ], +) +def test_remove_overlaps(events: list[Event], expected: list[Event]) -> None: + cleaned_events = remove_overlaps(events) + assert cleaned_events == expected + + +@pytest.mark.parametrize( + ("events", "expected"), + [ + pytest.param( + [Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + [Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], + id="only_one_event_is_unchanged", + ), + pytest.param( + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + ], + id="consecutive_events_are_unchanged", + ), + pytest.param( + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + ], + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + Event(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + ], + id="one_gap_is_filled", + ), + pytest.param( + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + Event(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), + Event(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), + ], + [ + Event(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), + Event(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), + Event(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), + Event(begin=Timestamp("00:00:30"), end=Timestamp("00:00:35")), + Event(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), + Event(begin=Timestamp("00:00:45"), end=Timestamp("00:01:00")), + Event(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), + ], + id="multiple_gaps_are_filled", + ), + ], +) +def test_fill_event_gaps(events: list[Event], expected: list[Event]) -> None: + filled_events = fill_gaps(events, Event) + assert filled_events == expected diff --git a/tests/test_item.py b/tests/test_item.py deleted file mode 100644 index 1419d2d6..00000000 --- a/tests/test_item.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -import pytest -from pandas import Timestamp - -from OSmOSE.data.base_item import BaseItem -from OSmOSE.utils.data_utils import remove_overlaps - - -@pytest.mark.parametrize( - ("item_list", "expected"), - [ - pytest.param( - [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], - [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], - id="only_one_item_is_unchanged", - ), - pytest.param( - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - ], - [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], - id="doubled_item_is_removed", - ), - pytest.param( - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ], - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ], - id="overlapping_item_is_truncated", - ), - pytest.param( - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:15")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:15")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ], - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ], - id="longest_item_is_prioritized", - ), - pytest.param( - [ - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:25")), - BaseItem(begin=Timestamp("00:00:0"), end=Timestamp("00:00:15")), - BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), - ], - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:35")), - ], - id="items_are_reordered", - ), - ], -) -def test_remove_overlaps(item_list: list[BaseItem], expected: list[BaseItem]) -> None: - cleaned_items = remove_overlaps(item_list) - assert cleaned_items == expected - - -@pytest.mark.parametrize( - ("item_list", "expected"), - [ - pytest.param( - [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], - [BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10"))], - id="only_one_item_is_unchanged", - ), - pytest.param( - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ], - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - ], - id="consecutive_items_are_unchanged", - ), - pytest.param( - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), - ], - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), - ], - id="one_gap_is_filled", - ), - pytest.param( - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), - BaseItem(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), - BaseItem(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), - ], - [ - BaseItem(begin=Timestamp("00:00:00"), end=Timestamp("00:00:10")), - BaseItem(begin=Timestamp("00:00:10"), end=Timestamp("00:00:20")), - BaseItem(begin=Timestamp("00:00:20"), end=Timestamp("00:00:30")), - BaseItem(begin=Timestamp("00:00:30"), end=Timestamp("00:00:35")), - BaseItem(begin=Timestamp("00:00:35"), end=Timestamp("00:00:45")), - BaseItem(begin=Timestamp("00:00:45"), end=Timestamp("00:01:00")), - BaseItem(begin=Timestamp("00:01:00"), end=Timestamp("00:02:00")), - ], - id="multiple_gaps_are_filled", - ), - ], -) -def test_fill_item_gaps(item_list: list[BaseItem], expected: list[BaseItem]) -> None: - filled_items = BaseItem.fill_gaps(item_list) - assert filled_items == expected From 5554aa44642ee6a2ac23d8194f85dbc591abc366 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 14:57:09 +0100 Subject: [PATCH 048/118] move all data_utils stuff to Event class --- src/OSmOSE/data/base_data.py | 8 +- src/OSmOSE/data/base_dataset.py | 2 +- src/OSmOSE/data/base_file.py | 2 +- src/OSmOSE/data/base_item.py | 2 +- src/OSmOSE/data/event.py | 150 +++++++++++++++++++ src/OSmOSE/utils/data_utils.py | 158 -------------------- tests/{test_data_utils.py => test_event.py} | 8 +- 7 files changed, 161 insertions(+), 169 deletions(-) create mode 100644 src/OSmOSE/data/event.py delete mode 100644 src/OSmOSE/utils/data_utils.py rename tests/{test_data_utils.py => test_event.py} (97%) diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index c7cd0c5c..e6477ae1 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -12,7 +12,7 @@ from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem -from OSmOSE.utils.data_utils import Event, fill_gaps, is_overlapping, remove_overlaps +from OSmOSE.data.event import Event if TYPE_CHECKING: from pathlib import Path @@ -124,7 +124,7 @@ def items_from_files( end = max(file.end for file in files) if end is None else end included_files = [ - file for file in files if is_overlapping(file, Event(begin=begin, end=end)) + file for file in files if file.overlaps(Event(begin=begin, end=end)) ] items = [BaseItem(file, begin, end) for file in included_files] @@ -134,5 +134,5 @@ def items_from_files( items.append(BaseItem(begin=begin, end=first_item.begin)) if (last_item := sorted(items, key=lambda item: item.end)[-1]).end < end: items.append(BaseItem(begin=last_item.end, end=end)) - items = remove_overlaps(items) - return fill_gaps(items, BaseItem) + items = Event.remove_overlaps(items) + return Event.fill_gaps(items, BaseItem) diff --git a/src/OSmOSE/data/base_dataset.py b/src/OSmOSE/data/base_dataset.py index 97d3e28c..deb57b84 100644 --- a/src/OSmOSE/data/base_dataset.py +++ b/src/OSmOSE/data/base_dataset.py @@ -12,7 +12,7 @@ from OSmOSE.data.base_data import BaseData from OSmOSE.data.base_file import BaseFile -from OSmOSE.utils.data_utils import Event +from OSmOSE.data.event import Event if TYPE_CHECKING: from pathlib import Path diff --git a/src/OSmOSE/data/base_file.py b/src/OSmOSE/data/base_file.py index 59861e93..37899ad8 100644 --- a/src/OSmOSE/data/base_file.py +++ b/src/OSmOSE/data/base_file.py @@ -15,7 +15,7 @@ from pathlib import Path -from OSmOSE.utils.data_utils import Event +from OSmOSE.data.event import Event from OSmOSE.utils.timestamp_utils import strptime_from_text diff --git a/src/OSmOSE/data/base_item.py b/src/OSmOSE/data/base_item.py index b9ec8001..3ba3204f 100644 --- a/src/OSmOSE/data/base_item.py +++ b/src/OSmOSE/data/base_item.py @@ -10,7 +10,7 @@ import numpy as np from OSmOSE.data.base_file import BaseFile -from OSmOSE.utils.data_utils import Event +from OSmOSE.data.event import Event if TYPE_CHECKING: from pandas import Timestamp diff --git a/src/OSmOSE/data/event.py b/src/OSmOSE/data/event.py new file mode 100644 index 00000000..21a6ef3e --- /dev/null +++ b/src/OSmOSE/data/event.py @@ -0,0 +1,150 @@ +"""Event class""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from pandas import Timestamp + + +@dataclass +class Event: + """Events are bounded between begin an end attributes. + + Classes that have a begin and an end should inherit from Event. + """ + + begin: Timestamp + end: Timestamp + + def overlaps(self, other: type[Event] | Event) -> bool: + """Return True if the other event shares time with the current event. + + Parameters + ---------- + other: type[Event] | Event + The other event. + + Returns + ------- + bool: + True if the two events are overlapping, False otherwise. + + Examples + -------- + >>> from pandas import Timestamp + >>> Event(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")).overlaps(Event(begin=Timestamp("2024-01-01 12:00:00"),end=Timestamp("2024-01-02 12:00:00"))) + True + >>> Event(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")).overlaps(Event(begin=Timestamp("2024-01-02 00:00:00"),end=Timestamp("2024-01-02 12:00:00"))) + False + + """ + return self.begin < other.end and self.end > other.begin + + @classmethod + def remove_overlaps(cls, events: list[TEvent]) -> list[TEvent]: + """Resolve overlaps between events. + + If two events overlap within the whole events collection + (that is if one event begins before the end of another event), + the earliest event's end is set to the begin of the latest object. + If multiple events overlap with one earlier event, only one is chosen as next. + The chosen next event is the one that ends the latest. + + Parameters + ---------- + events: list + List of events in which to remove the overlaps. + + Returns + ------- + list: + The list of events with no overlap. + + Examples + -------- + >>> from pandas import Timestamp + >>> events = [Event(begin=Timestamp("00:00:00"),end=Timestamp("00:00:15")), Event(begin=Timestamp("00:00:10"),end=Timestamp("00:00:20"))] + >>> events[0].end == events[1].begin + False + >>> events = Event.remove_overlaps(events) + >>> events[0].end == events[1].begin + True + + """ + events = sorted( + [copy.copy(event) for event in events], + key=lambda event: (event.begin, event.begin - event.end), + ) + concatenated_events = [] + for event in events: + concatenated_events.append(event) + overlapping_events = [ + event2 + for event2 in events + if event2 is not event and event.overlaps(event2) + ] + if not overlapping_events: + continue + kept_overlapping_event = max(overlapping_events, key=lambda item: item.end) + if kept_overlapping_event.end > event.end: + event.end = kept_overlapping_event.begin + else: + kept_overlapping_event = None + for dismissed_event in ( + event2 + for event2 in overlapping_events + if event2 is not kept_overlapping_event + ): + events.remove(dismissed_event) + return concatenated_events + + @classmethod + def fill_gaps( + cls, events: list[TEvent], filling_class: type[TEvent] + ) -> list[TEvent]: + """Return a list with empty events added in the gaps between items. + + The created empty events are instantiated from the class filling_class. + + Parameters + ---------- + events: list[TEvent] + List of events to fill. + filling_class: type[TEvent] + The class used for instantiating empty events in the gaps. + + Returns + ------- + list[TEvent]: + List of events with no gaps. + + Examples + -------- + >>> from pandas import Timestamp + >>> events = [Event(begin = Timestamp("00:00:00"), end = Timestamp("00:00:10")), Event(begin = Timestamp("00:00:15"), end = Timestamp("00:00:25"))] + >>> events = Event.fill_gaps(events, Event) + >>> [(event.begin.second, event.end.second) for event in events] + [(0, 10), (10, 15), (15, 25)] + + """ + events = sorted( + [copy.copy(event) for event in events], + key=lambda event: event.begin, + ) + filled_event_list = [] + for index, event in enumerate(events[:-1]): + next_event = events[index + 1] + filled_event_list.append(event) + if next_event.begin > event.end: + filled_event_list.append( + filling_class(begin=event.end, end=next_event.begin), + ) + filled_event_list.append(events[-1]) + return filled_event_list + + +TEvent = TypeVar("TEvent", bound=Event) diff --git a/src/OSmOSE/utils/data_utils.py b/src/OSmOSE/utils/data_utils.py deleted file mode 100644 index 906fcc0e..00000000 --- a/src/OSmOSE/utils/data_utils.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Util classes and functions for data objects.""" - -from __future__ import annotations - -import copy -from dataclasses import dataclass -from typing import TYPE_CHECKING, TypeVar - -if TYPE_CHECKING: - from pandas import Timestamp - - -@dataclass -class Event: - """Dataclass containing begin an end attributes. - - Classes that have a begin and an end should inherit from Event. - """ - - begin: Timestamp - end: Timestamp - - -TEvent = TypeVar("TEvent", bound=Event) - - -def is_overlapping( - event1: TEvent | Event, - event2: TEvent | Event, -) -> bool: - """Return True if the two events are overlapping, False otherwise. - - Events are objects that have begin and end Timestamp attributes. - - Parameters - ---------- - event1: Event - The first event. - event2: Event - The second event. - - Returns - ------- - bool: - True if the two events are overlapping, False otherwise. - - Examples - -------- - >>> is_overlapping(Event(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), Event(begin=Timestamp("2024-01-01 12:00:00"),end=Timestamp("2024-01-02 12:00:00"))) - True - >>> is_overlapping(Event(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")), Event(begin=Timestamp("2024-01-02 00:00:00"),end=Timestamp("2024-01-02 12:00:00"))) - False - - """ - return event1.begin < event2.end and event1.end > event2.begin - - -def remove_overlaps(events: list[TEvent]) -> list[TEvent]: - """Resolve overlaps between events. - - If two events overlap within the whole events collection - (that is if one event begins before the end of another event), - the earliest event's end is set to the begin of the latest object. - If multiple events overlap with one earlier event, only one is chosen as next. - The chosen next event is the one that ends the latest. - - Parameters - ---------- - events: list - List of events in which to remove the overlaps. - - Returns - ------- - list: - The list of events with no overlap. - - Examples - -------- - >>> from dataclasses import dataclass - >>> @dataclass - ... class Item: - ... begin: Timestamp - ... end: Timestamp - >>> items = [Item(begin=Timestamp("00:00:00"),end=Timestamp("00:00:15")), Item(begin=Timestamp("00:00:10"),end=Timestamp("00:00:20"))] - >>> events[0].end == events[1].begin - False - >>> items = remove_overlaps(events) - >>> events[0].end == events[1].begin - True - - """ - events = sorted( - [copy.copy(event) for event in events], - key=lambda event: (event.begin, event.begin - event.end), - ) - concatenated_events = [] - for event in events: - concatenated_events.append(event) - overlapping_events = [ - event2 - for event2 in events - if event2 is not event and is_overlapping(event, event2) - ] - if not overlapping_events: - continue - kept_overlapping_event = max(overlapping_events, key=lambda item: item.end) - if kept_overlapping_event.end > event.end: - event.end = kept_overlapping_event.begin - else: - kept_overlapping_event = None - for dismissed_event in ( - event2 - for event2 in overlapping_events - if event2 is not kept_overlapping_event - ): - events.remove(dismissed_event) - return concatenated_events - - -def fill_gaps(events: list[TEvent], filling_class: type[TEvent]) -> list[TEvent]: - """Return a list with empty events added in the gaps between items. - - The created empty events are instantiated from the class filling_class. - - Parameters - ---------- - events: list[TEvent] - List of events to fill. - filling_class: type[TEvent] - The class used for instantiating empty events in the gaps. - - Returns - ------- - list[TEvent]: - List of events with no gaps. - - Examples - -------- - >>> events = [Event(begin = Timestamp("00:00:00"), end = Timestamp("00:00:10")), Event(begin = Timestamp("00:00:15"), end = Timestamp("00:00:25"))] - >>> events = fill_gaps(events, Event) - >>> [(event.begin.second, event.end.second) for event in events] - [(0, 10), (10, 15), (15, 25)] - - """ - events = sorted( - [copy.copy(event) for event in events], - key=lambda event: event.begin, - ) - filled_event_list = [] - for index, event in enumerate(events[:-1]): - next_event = events[index + 1] - filled_event_list.append(event) - if next_event.begin > event.end: - filled_event_list.append( - filling_class(begin=event.end, end=next_event.begin), - ) - filled_event_list.append(events[-1]) - return filled_event_list diff --git a/tests/test_data_utils.py b/tests/test_event.py similarity index 97% rename from tests/test_data_utils.py rename to tests/test_event.py index 60988361..330543c8 100644 --- a/tests/test_data_utils.py +++ b/tests/test_event.py @@ -3,7 +3,7 @@ import pytest from pandas import Timestamp -from OSmOSE.utils.data_utils import Event, fill_gaps, is_overlapping, remove_overlaps +from OSmOSE.data.event import Event @pytest.mark.parametrize( @@ -100,7 +100,7 @@ def test_overlapping_events( event2: Event, expected: bool, ) -> None: - assert is_overlapping(event1, event2) == expected + assert event1.overlaps(event2) is expected @pytest.mark.parametrize( @@ -158,7 +158,7 @@ def test_overlapping_events( ], ) def test_remove_overlaps(events: list[Event], expected: list[Event]) -> None: - cleaned_events = remove_overlaps(events) + cleaned_events = Event.remove_overlaps(events) assert cleaned_events == expected @@ -214,5 +214,5 @@ def test_remove_overlaps(events: list[Event], expected: list[Event]) -> None: ], ) def test_fill_event_gaps(events: list[Event], expected: list[Event]) -> None: - filled_events = fill_gaps(events, Event) + filled_events = Event.fill_gaps(events, Event) assert filled_events == expected From 7f7b88942f5edd3f7cc48f0292cf05c3c0ae95bd Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 15:06:29 +0100 Subject: [PATCH 049/118] linting --- src/OSmOSE/data/audio_file.py | 6 ++++-- src/OSmOSE/data/base_data.py | 2 +- src/OSmOSE/data/base_file.py | 6 ++++-- src/OSmOSE/data/event.py | 12 +++++++----- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index 65a9d28e..c47160cb 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -25,7 +25,8 @@ def __init__( ) -> None: """Initialize an AudioFile object with a path and a begin timestamp. - The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided strptime_format. + The begin timestamp can either be provided as a parameter + or parsed from the filename according to the provided strptime_format. Parameters ---------- @@ -34,7 +35,8 @@ def __init__( begin: pandas.Timestamp | None Timestamp corresponding to the first data point in the file. If it is not provided, strptime_format is mandatory. - If both begin and strptime_format are provided, begin will overrule the timestamp embedded in the filename. + If both begin and strptime_format are provided, + begin will overrule the timestamp embedded in the filename. strptime_format: str | None The strptime format used in the text. It should use valid strftime codes (https://strftime.org/). diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index e6477ae1..2a18eb02 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -58,7 +58,7 @@ def get_value(self) -> np.ndarray: """Get the concatenated values from all Items.""" return np.concatenate([item.get_value() for item in self.items]) - def write(self, path: Path) -> None: + def write(self, path: Path) -> None: # noqa: ARG002 """Abstract method for writing the data.""" return diff --git a/src/OSmOSE/data/base_file.py b/src/OSmOSE/data/base_file.py index 37899ad8..8ceda5c9 100644 --- a/src/OSmOSE/data/base_file.py +++ b/src/OSmOSE/data/base_file.py @@ -34,7 +34,8 @@ def __init__( ) -> None: """Initialize a File object with a path and timestamps. - The begin timestamp can either be provided as a parameter or parsed from the filename according to the provided strptime_format. + The begin timestamp can either be provided as a parameter + or parsed from the filename according to the provided strptime_format. Parameters ---------- @@ -43,7 +44,8 @@ def __init__( begin: pandas.Timestamp | None Timestamp corresponding to the first data point in the file. If it is not provided, strptime_format is mandatory. - If both begin and strptime_format are provided, begin will overrule the timestamp embedded in the filename. + If both begin and strptime_format are provided, + begin will overrule the timestamp embedded in the filename. end: pandas.Timestamp | None (Optional) Timestamp after the last data point in the file. strptime_format: str | None diff --git a/src/OSmOSE/data/event.py b/src/OSmOSE/data/event.py index 21a6ef3e..33700f7d 100644 --- a/src/OSmOSE/data/event.py +++ b/src/OSmOSE/data/event.py @@ -1,4 +1,4 @@ -"""Event class""" +"""Event class.""" from __future__ import annotations @@ -41,7 +41,7 @@ def overlaps(self, other: type[Event] | Event) -> bool: >>> Event(begin=Timestamp("2024-01-01 00:00:00"),end=Timestamp("2024-01-02 00:00:00")).overlaps(Event(begin=Timestamp("2024-01-02 00:00:00"),end=Timestamp("2024-01-02 12:00:00"))) False - """ + """ # noqa: E501 return self.begin < other.end and self.end > other.begin @classmethod @@ -74,7 +74,7 @@ def remove_overlaps(cls, events: list[TEvent]) -> list[TEvent]: >>> events[0].end == events[1].begin True - """ + """ # noqa: E501 events = sorted( [copy.copy(event) for event in events], key=lambda event: (event.begin, event.begin - event.end), @@ -104,7 +104,9 @@ def remove_overlaps(cls, events: list[TEvent]) -> list[TEvent]: @classmethod def fill_gaps( - cls, events: list[TEvent], filling_class: type[TEvent] + cls, + events: list[TEvent], + filling_class: type[TEvent], ) -> list[TEvent]: """Return a list with empty events added in the gaps between items. @@ -130,7 +132,7 @@ def fill_gaps( >>> [(event.begin.second, event.end.second) for event in events] [(0, 10), (10, 15), (15, 25)] - """ + """ # noqa: E501 events = sorted( [copy.copy(event) for event in events], key=lambda event: event.begin, From 92446ce82067d1d1b09683c2eed160ac9d492a27 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 15:08:53 +0100 Subject: [PATCH 050/118] move total_seconds property to Event --- src/OSmOSE/data/base_data.py | 5 ----- src/OSmOSE/data/base_item.py | 5 ----- src/OSmOSE/data/event.py | 5 +++++ 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index 2a18eb02..14bc4e00 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -44,11 +44,6 @@ def __init__(self, items: list[TItem]) -> None: self.begin = min(item.begin for item in self.items) self.end = max(item.end for item in self.items) - @property - def total_seconds(self) -> float: - """Return the total duration of the data in seconds.""" - return (self.end - self.begin).total_seconds() - @property def is_empty(self) -> bool: """Return true if every item of this data object is empty.""" diff --git a/src/OSmOSE/data/base_item.py b/src/OSmOSE/data/base_item.py index 3ba3204f..c329fdc9 100644 --- a/src/OSmOSE/data/base_item.py +++ b/src/OSmOSE/data/base_item.py @@ -72,11 +72,6 @@ def is_empty(self) -> bool: """Return True if no File is attached to this Item.""" return self.file is None - @property - def total_seconds(self) -> float: - """Return the total duration of the item in seconds.""" - return (self.end - self.begin).total_seconds() - def __eq__(self, other: BaseItem[TFile]) -> bool: """Override the default implementation.""" if not isinstance(other, BaseItem): diff --git a/src/OSmOSE/data/event.py b/src/OSmOSE/data/event.py index 33700f7d..85b57c65 100644 --- a/src/OSmOSE/data/event.py +++ b/src/OSmOSE/data/event.py @@ -20,6 +20,11 @@ class Event: begin: Timestamp end: Timestamp + @property + def total_seconds(self) -> float: + """Return the total duration of the data in seconds.""" + return (self.end - self.begin).total_seconds() + def overlaps(self, other: type[Event] | Event) -> bool: """Return True if the other event shares time with the current event. From 1998e528fb67189a3405832b0643dfeed69a10ec Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 16 Dec 2024 17:08:23 +0100 Subject: [PATCH 051/118] add kwargs to Event.fill_gaps --- src/OSmOSE/data/audio_data.py | 16 ++++++++++++++-- src/OSmOSE/data/base_data.py | 20 +++++++++++++++++--- src/OSmOSE/data/event.py | 5 ++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 9752e162..86c6c9c7 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -30,7 +30,13 @@ class AudioData(BaseData[AudioItem, AudioFile]): The data is accessed via an AudioItem object per AudioFile. """ - def __init__(self, items: list[AudioItem], sample_rate: int | None = None) -> None: + def __init__( + self, + items: list[AudioItem] | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + sample_rate: int | None = None, + ) -> None: """Initialize an AudioData from a list of AudioItems. Parameters @@ -39,9 +45,15 @@ def __init__(self, items: list[AudioItem], sample_rate: int | None = None) -> No List of the AudioItem constituting the AudioData. sample_rate: int The sample rate of the audio data. + begin: Timestamp | None + Only effective if items is None. + Set the begin of the empty data. + end: Timestamp | None + Only effective if items is None. + Set the end of the empty data. """ - super().__init__(items) + super().__init__(items=items, begin=begin, end=end) self._set_sample_rate(sample_rate=sample_rate) @property diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index 14bc4e00..a00c8bb8 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -31,15 +31,29 @@ class BaseData(Generic[TItem, TFile], Event): The data is accessed via an Item object per File. """ - def __init__(self, items: list[TItem]) -> None: - """Initialize an BaseData from a list of Items. + def __init__( + self, + items: list[TItem] | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> None: + """Initialize a BaseData from a list of Items. Parameters ---------- - items: list[BaseItem] + items: list[BaseItem] | None List of the Items constituting the Data. + Defaulted to an empty item ranging from begin to end. + begin: Timestamp | None + Only effective if items is None. + Set the begin of the empty data. + end: Timestamp | None + Only effective if items is None. + Set the end of the empty data. """ + if not items: + items = [BaseItem(begin=begin, end=end)] self.items = items self.begin = min(item.begin for item in self.items) self.end = max(item.end for item in self.items) diff --git a/src/OSmOSE/data/event.py b/src/OSmOSE/data/event.py index 85b57c65..d8b7cc4a 100644 --- a/src/OSmOSE/data/event.py +++ b/src/OSmOSE/data/event.py @@ -112,6 +112,7 @@ def fill_gaps( cls, events: list[TEvent], filling_class: type[TEvent], + **kwargs: any, ) -> list[TEvent]: """Return a list with empty events added in the gaps between items. @@ -123,6 +124,8 @@ def fill_gaps( List of events to fill. filling_class: type[TEvent] The class used for instantiating empty events in the gaps. + kwargs: any + Additional parameters to pass to the filling instance constructor. Returns ------- @@ -148,7 +151,7 @@ def fill_gaps( filled_event_list.append(event) if next_event.begin > event.end: filled_event_list.append( - filling_class(begin=event.end, end=next_event.begin), + filling_class(begin=event.end, end=next_event.begin, **kwargs), ) filled_event_list.append(events[-1]) return filled_event_list From 6f27ac24e2b4bc09d6d5bdcad61825aab527aa34 Mon Sep 17 00:00:00 2001 From: Gauthier BERTHOMIEU Date: Tue, 17 Dec 2024 09:46:34 +0100 Subject: [PATCH 052/118] create output dirs before writing data --- src/OSmOSE/data/audio_data.py | 1 + src/OSmOSE/data/base_data.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 86c6c9c7..c9dc6bec 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -112,6 +112,7 @@ def write(self, folder: Path) -> None: Folder in which to write the audio file. """ + super().write(path=folder) sf.write(folder / f"{self}.wav", self.get_value(), self.sample_rate) def _get_item_value(self, item: AudioItem) -> np.ndarray: diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index a00c8bb8..21882094 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -10,6 +10,7 @@ import numpy as np +from OSmOSE.config import DPDEFAULT from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem from OSmOSE.data.event import Event @@ -67,9 +68,12 @@ def get_value(self) -> np.ndarray: """Get the concatenated values from all Items.""" return np.concatenate([item.get_value() for item in self.items]) - def write(self, path: Path) -> None: # noqa: ARG002 - """Abstract method for writing the data.""" - return + def write(self, path: Path) -> None: + """Create the directory in which the data will be written. + + The actual data writing is left to the specified classes. + """ + path.mkdir(parents=True, exist_ok=True, mode=DPDEFAULT) @classmethod def from_files( From ef8ddcf149000d7b046cb1037777deea05aca965 Mon Sep 17 00:00:00 2001 From: Gauthier BERTHOMIEU Date: Tue, 17 Dec 2024 15:17:56 +0100 Subject: [PATCH 053/118] add spectro file --- src/OSmOSE/data/audio_item.py | 4 +- src/OSmOSE/data/spectro_file.py | 91 +++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 src/OSmOSE/data/spectro_file.py diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index c730d930..0b4128e4 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -21,7 +21,7 @@ def __init__( begin: Timestamp | None = None, end: Timestamp | None = None, ) -> None: - """Initialize an AudioItem from an AudioItem and begin/end timestamps. + """Initialize an AudioItem from an AudioFile and begin/end timestamps. Parameters ---------- @@ -49,7 +49,7 @@ def nb_channels(self) -> int: @classmethod def from_base_item(cls, item: BaseItem) -> AudioItem: - """Return an AudioItem object from an BaseItem object.""" + """Return an AudioItem object from a BaseItem object.""" file = item.file if not file or isinstance(file, AudioFile): return cls(file=file, begin=item.begin, end=item.end) diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py new file mode 100644 index 00000000..66d21864 --- /dev/null +++ b/src/OSmOSE/data/spectro_file.py @@ -0,0 +1,91 @@ +"""Spectro file associated with timestamps. + +Spectro files are npz files with Time and Sxx arrays. +Metadata (time_resolution) are stored as separate arrays. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from pandas import Timedelta, Timestamp + +from OSmOSE.data.base_file import BaseFile + +if TYPE_CHECKING: + from os import PathLike + + +class SpectroFile(BaseFile): + """Spectro file associated with timestamps. + + Spectro files are npz files with Time and Sxx arrays. + Metadata (time_resolution) are stored as separate arrays. + """ + + def __init__( + self, + path: PathLike | str, + begin: Timestamp | None = None, + strptime_format: str | None = None, + ) -> None: + """Initialize a SpectroFile object with a path and a begin timestamp. + + The begin timestamp can either be provided as a parameter + or parsed from the filename according to the provided strptime_format. + + Parameters + ---------- + path: PathLike | str + Full path to the file. + begin: pandas.Timestamp | None + Timestamp corresponding to the first data bin in the file. + If it is not provided, strptime_format is mandatory. + If both begin and strptime_format are provided, + begin will overrule the timestamp embedded in the filename. + strptime_format: str | None + The strptime format used in the text. + It should use valid strftime codes (https://strftime.org/). + Example: '%y%m%d_%H:%M:%S'. + + """ + super().__init__(path=path, begin=begin, strptime_format=strptime_format) + self._read_metadata(path=path) + self.end = self.begin + self.duration + + def _read_metadata(self, path: PathLike) -> None: + with np.load(path) as data: + time_resolution = float(data["time_resolution"]) + nb_points = data["Time"].shape[0] + self.time_resolution = Timedelta(seconds=time_resolution) + self.duration = self.time_resolution * nb_points + self.nb_points = nb_points + + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: + """Return the spectro data between start and stop from the file. + + The data is a 2D array representing the sxx values of the spectrogram. + + Parameters + ---------- + start: pandas.Timestamp + Timestamp corresponding to the first time bin to read. + stop: pandas.Timestamp + Timestamp after the last time bin to read. + + Returns + ------- + numpy.ndarray: + The spectrogram data between start and stop. + + """ + start_bin = round((start - self.begin) / self.time_resolution) + stop_bin = round((stop - self.begin) / self.time_resolution) + with np.load(self.path) as data: + return data["Sxx"][:, start_bin:stop_bin] + + @classmethod + def from_base_file(cls, file: BaseFile) -> SpectroFile: + """Return a SpectroFile object from a BaseFile object.""" + return cls(path=file.path, begin=file.begin) From b3471f866bf452269478868488b97ab465d9b8e0 Mon Sep 17 00:00:00 2001 From: Gauthier BERTHOMIEU Date: Tue, 17 Dec 2024 15:26:34 +0100 Subject: [PATCH 054/118] add spectro item --- src/OSmOSE/data/spectro_item.py | 57 +++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/OSmOSE/data/spectro_item.py diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py new file mode 100644 index 00000000..b722100d --- /dev/null +++ b/src/OSmOSE/data/spectro_item.py @@ -0,0 +1,57 @@ +"""SpectroItem corresponding to a portion of a SpectroFile object.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from OSmOSE.data.base_file import BaseFile +from OSmOSE.data.base_item import BaseItem +from OSmOSE.data.spectro_file import SpectroFile + +if TYPE_CHECKING: + from pandas import Timedelta, Timestamp + + +class SpectroItem(BaseItem[SpectroFile]): + """SpectroItem corresponding to a portion of a SpectroFile object.""" + + def __init__( + self, + file: SpectroFile | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> None: + """Initialize a SpectroItem from a SpectroFile and begin/end timestamps. + + Parameters + ---------- + file: OSmOSE.data.spectro_file.SpectroFile + The SpectroFile in which this Item belongs. + begin: pandas.Timestamp (optional) + The timestamp at which this item begins. + It is defaulted to the SpectroFile begin. + end: pandas.Timestamp (optional) + The timestamp at which this item ends. + It is defaulted to the SpectroFile end. + + """ + super().__init__(file, begin, end) + + @property + def time_resolution(self) -> Timedelta: + """Time resolution of the associated SpectroFile.""" + return None if self.is_empty else self.file.time_resolution + + @classmethod + def from_base_item(cls, item: BaseItem) -> SpectroItem: + """Return a SpectroItem object from a BaseItem object.""" + file = item.file + if not file or isinstance(file, SpectroFile): + return cls(file=file, begin=item.begin, end=item.end) + if isinstance(file, BaseFile): + return cls( + file=SpectroFile.from_base_file(file), + begin=item.begin, + end=item.end, + ) + raise TypeError From 08c9e3a6e34f391d3846c63ebc340ac0bc3d864a Mon Sep 17 00:00:00 2001 From: Gauthier BERTHOMIEU Date: Tue, 17 Dec 2024 16:14:13 +0100 Subject: [PATCH 055/118] replace event.total_seconds by event.duration --- src/OSmOSE/data/audio_data.py | 4 ++-- src/OSmOSE/data/event.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index c9dc6bec..936d2661 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -66,7 +66,7 @@ def nb_channels(self) -> int: @property def shape(self) -> tuple[int, ...]: """Shape of the audio data.""" - data_length = int(self.sample_rate * self.total_seconds) + data_length = int(self.sample_rate * self.duration.total_seconds()) return data_length if self.nb_channels <= 1 else (data_length, self.nb_channels) def __str__(self) -> str: @@ -119,7 +119,7 @@ def _get_item_value(self, item: AudioItem) -> np.ndarray: """Return the resampled (if needed) data from the audio item.""" item_data = item.get_value() if item.is_empty: - return item_data.repeat(int(item.total_seconds * self.sample_rate)) + return item_data.repeat(int(item.duration.total_seconds() * self.sample_rate)) if item.sample_rate != self.sample_rate: return resample(item_data, item.sample_rate, self.sample_rate) return item_data diff --git a/src/OSmOSE/data/event.py b/src/OSmOSE/data/event.py index d8b7cc4a..afe80df5 100644 --- a/src/OSmOSE/data/event.py +++ b/src/OSmOSE/data/event.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, TypeVar if TYPE_CHECKING: - from pandas import Timestamp + from pandas import Timedelta, Timestamp @dataclass @@ -21,9 +21,9 @@ class Event: end: Timestamp @property - def total_seconds(self) -> float: + def duration(self) -> Timedelta: """Return the total duration of the data in seconds.""" - return (self.end - self.begin).total_seconds() + return self.end - self.begin def overlaps(self, other: type[Event] | Event) -> bool: """Return True if the other event shares time with the current event. From f18277e8d919d391a74306439e9d8ddb7fd27ade Mon Sep 17 00:00:00 2001 From: Gauthier BERTHOMIEU Date: Tue, 17 Dec 2024 16:39:04 +0100 Subject: [PATCH 056/118] add spectro data --- src/OSmOSE/data/audio_data.py | 4 +- src/OSmOSE/data/spectro_data.py | 155 ++++++++++++++++++++++++++++++++ src/OSmOSE/data/spectro_file.py | 2 +- src/OSmOSE/data/spectro_item.py | 1 + 4 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 src/OSmOSE/data/spectro_data.py diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 936d2661..3904a16b 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -119,7 +119,9 @@ def _get_item_value(self, item: AudioItem) -> np.ndarray: """Return the resampled (if needed) data from the audio item.""" item_data = item.get_value() if item.is_empty: - return item_data.repeat(int(item.duration.total_seconds() * self.sample_rate)) + return item_data.repeat( + int(item.duration.total_seconds() * self.sample_rate) + ) if item.sample_rate != self.sample_rate: return resample(item_data, item.sample_rate, self.sample_rate) return item_data diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py new file mode 100644 index 00000000..653671ca --- /dev/null +++ b/src/OSmOSE/data/spectro_data.py @@ -0,0 +1,155 @@ +"""SpectroData represent spectrogram data retrieved from SpectroFiles. + +The SpectroData has a collection of SpectroItem. +The data is accessed via a SpectroItem object per SpectroFile. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES +from OSmOSE.data.base_data import BaseData +from OSmOSE.data.spectro_file import SpectroFile +from OSmOSE.data.spectro_item import SpectroItem + +if TYPE_CHECKING: + from pathlib import Path + + from pandas import Timedelta, Timestamp + + +class SpectroData(BaseData[SpectroItem, SpectroFile]): + """SpectroData represent Spectro data scattered through different SpectroFiles. + + The SpectroData has a collection of SpectroItem. + The data is accessed via a SpectroItem object per SpectroFile. + """ + + def __init__( + self, + items: list[SpectroItem] | None = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + time_resolution: Timedelta | None = None, + ) -> None: + """Initialize a SpectroData from a list of SpectroItems. + + Parameters + ---------- + items: list[SpectroItem] + List of the SpectroItem constituting the SpectroData. + time_resolution: Timedelta + The time resolution of the Spectro data. + begin: Timestamp | None + Only effective if items is None. + Set the begin of the empty data. + end: Timestamp | None + Only effective if items is None. + Set the end of the empty data. + + """ + super().__init__(items=items, begin=begin, end=end) + self._set_time_resolution(time_resolution=time_resolution) + + @property + def shape(self) -> tuple[int, ...]: + """Shape of the Spectro data.""" + return max(item.shape[0] for item in self.items), sum( + item.shape[1] for item in self.items + ) + + def __str__(self) -> str: + """Overwrite __str__.""" + return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES) + + def _set_time_resolution(self, time_resolution: Timedelta) -> None: + """Set the SpectroFile time resolution.""" + if len(tr := {item.time_resolution for item in self.items}) > 1: + raise ValueError("Items don't have the same time resolution") + self.time_resolution = tr.pop() if len(tr) == 1 else time_resolution + + def get_value(self) -> np.ndarray: + """Return the value of the Spectro data. + + The data from the Spectro file will be resampled if necessary. + """ + data = np.zeros(shape=self.shape) + idx = 0 + for item in self.items: + item_data = self._get_item_value(item) + time_bins = item_data.shape[1] + data[:, idx : idx + time_bins] = item_data + idx += time_bins + return data + + def write(self, folder: Path) -> None: + """Write the Spectro data to file. + + Parameters + ---------- + folder: pathlib.Path + Folder in which to write the Spectro file. + + """ + super().write(path=folder) + # TODO: implement npz write + + def _get_item_value(self, item: SpectroItem) -> np.ndarray: + """Return the resampled (if needed) data from the Spectro item.""" + item_data = item.get_value() + if item.is_empty: + return item_data.repeat(round(item.duration / self.time_resolution)) + if item.time_resolution != self.time_resolution: + raise ValueError("Time resolutions don't match.") + return item_data + + @classmethod + def from_files( + cls, + files: list[SpectroFile], + begin: Timestamp | None = None, + end: Timestamp | None = None, + ) -> SpectroData: + """Return a SpectroData object from a list of SpectroFiles. + + Parameters + ---------- + files: list[SpectroFile] + List of SpectroFiles containing the data. + begin: Timestamp | None + Begin of the data object. + Defaulted to the begin of the first file. + end: Timestamp | None + End of the data object. + Defaulted to the end of the last file. + + Returns + ------- + SpectroData: + The SpectroData object. + + """ + return cls.from_base_data(BaseData.from_files(files, begin, end)) + + @classmethod + def from_base_data( + cls, + data: BaseData, + ) -> SpectroData: + """Return an SpectroData object from a BaseData object. + + Parameters + ---------- + data: BaseData + BaseData object to convert to SpectroData. + + Returns + ------- + SpectroData: + The SpectroData object. + + """ + return cls([SpectroItem.from_base_item(item) for item in data.items]) diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index 66d21864..5006ac43 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -59,7 +59,7 @@ def _read_metadata(self, path: PathLike) -> None: time_resolution = float(data["time_resolution"]) nb_points = data["Time"].shape[0] self.time_resolution = Timedelta(seconds=time_resolution) - self.duration = self.time_resolution * nb_points + self.end = self.begin + self.time_resolution * nb_points self.nb_points = nb_points def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py index b722100d..ca239ddc 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/data/spectro_item.py @@ -36,6 +36,7 @@ def __init__( """ super().__init__(file, begin, end) + self.shape = self.get_value().shape @property def time_resolution(self) -> Timedelta: From 66d22f8b124ceca37a5e5fe969dde49cf643b756 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 19 Dec 2024 09:58:02 +0100 Subject: [PATCH 057/118] first spectrogram prototype --- src/OSmOSE/data/spectro_data.py | 81 +++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 9 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 653671ca..4dd61e6b 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -8,9 +8,12 @@ from typing import TYPE_CHECKING +import matplotlib.pyplot as plt import numpy as np +from scipy.signal import ShortTimeFFT from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES +from OSmOSE.data.audio_data import AudioData from OSmOSE.data.base_data import BaseData from OSmOSE.data.spectro_file import SpectroFile from OSmOSE.data.spectro_item import SpectroItem @@ -31,9 +34,11 @@ class SpectroData(BaseData[SpectroItem, SpectroFile]): def __init__( self, items: list[SpectroItem] | None = None, + audio_data: AudioData = None, begin: Timestamp | None = None, end: Timestamp | None = None, time_resolution: Timedelta | None = None, + fft : ShortTimeFFT | None = None, ) -> None: """Initialize a SpectroData from a list of SpectroItems. @@ -52,7 +57,9 @@ def __init__( """ super().__init__(items=items, begin=begin, end=end) - self._set_time_resolution(time_resolution=time_resolution) + # self._set_time_resolution(time_resolution=time_resolution) + self.audio_data = audio_data + self.fft = fft @property def shape(self) -> tuple[int, ...]: @@ -76,14 +83,56 @@ def get_value(self) -> np.ndarray: The data from the Spectro file will be resampled if necessary. """ - data = np.zeros(shape=self.shape) - idx = 0 - for item in self.items: - item_data = self._get_item_value(item) - time_bins = item_data.shape[1] - data[:, idx : idx + time_bins] = item_data - idx += time_bins - return data + if not all(item.is_empty for item in self.items): + return self._get_value_from_items(self.items) + if not self.audio_data or not self.fft: + raise ValueError("SpectroData has not been initialized") + + return self.fft.spectrogram(self.audio_data.get_value()) + + def save_spectrogram(self, folder: Path, custom_frequency_scale = "linear") -> None: + super().write(folder) + my_dpi = 100 + fact_x = 1.3 + fact_y = 1.3 + fig, ax = plt.subplots( + nrows=1, + ncols=1, + figsize=(fact_x * 1800 / my_dpi, fact_y * 512 / my_dpi), + dpi=my_dpi, + ) + + sx = self.get_value() + time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] + freq = self.fft.f + log_spectro = 10 * np.log10(abs(sx) + 1e-12) + + color_map = plt.get_cmap("viridis") + + if custom_frequency_scale == "linear": + plt.pcolormesh(time, freq, log_spectro, cmap=color_map) + elif custom_frequency_scale == "log": + plt.pcolormesh(time, freq, log_spectro, cmap=color_map) + plt.yscale("log") + plt.ylim(freq[freq > 0].min(), self.fft.fs / 2) + + # plt.clim(vmin=min(log_spectro, key=lambda s: s), vmax=max(log_spectro, key=lambda s: s)) + + fig.axes[0].get_xaxis().set_visible(False) + fig.axes[0].get_yaxis().set_visible(False) + ax.set_frame_on(False) + ax.spines["right"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["top"].set_visible(False) + plt.axis("off") + plt.subplots_adjust( + top=1, bottom=0, right=1, left=0, hspace=0, wspace=0, + ) # delete white borders + # Saving spectrogram plot to file + # plt.show(bbox_inches="tight", pad_inches=0) + plt.savefig(f"{folder / str(self)}", bbox_inches="tight", pad_inches=0) + plt.close() def write(self, folder: Path) -> None: """Write the Spectro data to file. @@ -97,6 +146,16 @@ def write(self, folder: Path) -> None: super().write(path=folder) # TODO: implement npz write + def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: + data = np.zeros(shape=self.shape) + idx = 0 + for item in self.items: + item_data = self._get_item_value(item) + time_bins = item_data.shape[1] + data[:, idx: idx + time_bins] = item_data + idx += time_bins + return data + def _get_item_value(self, item: SpectroItem) -> np.ndarray: """Return the resampled (if needed) data from the Spectro item.""" item_data = item.get_value() @@ -153,3 +212,7 @@ def from_base_data( """ return cls([SpectroItem.from_base_item(item) for item in data.items]) + + @classmethod + def from_audio_data(cls, data: AudioData, fft: ShortTimeFFT) -> SpectroData: + return cls(audio_data=data, fft=fft, begin=data.begin, end=data.end) From 416579a8a9ce0fdef8b21ed4eff9535d6d17149e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 7 Jan 2025 15:52:54 +0100 Subject: [PATCH 058/118] pyplot.Axes as SpectroData property --- src/OSmOSE/data/spectro_data.py | 74 +++++++++++++++++---------------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 4dd61e6b..2e7e66de 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -60,6 +60,38 @@ def __init__( # self._set_time_resolution(time_resolution=time_resolution) self.audio_data = audio_data self.fft = fft + self._ax = None + + @property + def ax(self) -> plt.Axes: + if self._ax is not None: + return self._ax + + # Legacy OSEkit behaviour, done in the getter so that plt figure is created on demand only. + _, ax = plt.subplots( + nrows=1, + ncols=1, + figsize=(1.3 * 1800 / 100, 1.3 * 512 / 100), + dpi=100, + ) + + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + ax.set_frame_on(False) + ax.spines["right"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["top"].set_visible(False) + plt.axis("off") + plt.subplots_adjust( + top=1, bottom=0, right=1, left=0, hspace=0, wspace=0, + ) + self.ax = ax + return ax + + @ax.setter + def ax(self, ax: plt.Axes | None) -> None: + self._ax = ax @property def shape(self) -> tuple[int, ...]: @@ -90,47 +122,17 @@ def get_value(self) -> np.ndarray: return self.fft.spectrogram(self.audio_data.get_value()) - def save_spectrogram(self, folder: Path, custom_frequency_scale = "linear") -> None: - super().write(folder) - my_dpi = 100 - fact_x = 1.3 - fact_y = 1.3 - fig, ax = plt.subplots( - nrows=1, - ncols=1, - figsize=(fact_x * 1800 / my_dpi, fact_y * 512 / my_dpi), - dpi=my_dpi, - ) - + def plot(self): sx = self.get_value() time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f log_spectro = 10 * np.log10(abs(sx) + 1e-12) + self.ax.pcolormesh(time, freq, log_spectro) - color_map = plt.get_cmap("viridis") - - if custom_frequency_scale == "linear": - plt.pcolormesh(time, freq, log_spectro, cmap=color_map) - elif custom_frequency_scale == "log": - plt.pcolormesh(time, freq, log_spectro, cmap=color_map) - plt.yscale("log") - plt.ylim(freq[freq > 0].min(), self.fft.fs / 2) - - # plt.clim(vmin=min(log_spectro, key=lambda s: s), vmax=max(log_spectro, key=lambda s: s)) - - fig.axes[0].get_xaxis().set_visible(False) - fig.axes[0].get_yaxis().set_visible(False) - ax.set_frame_on(False) - ax.spines["right"].set_visible(False) - ax.spines["left"].set_visible(False) - ax.spines["bottom"].set_visible(False) - ax.spines["top"].set_visible(False) - plt.axis("off") - plt.subplots_adjust( - top=1, bottom=0, right=1, left=0, hspace=0, wspace=0, - ) # delete white borders - # Saving spectrogram plot to file - # plt.show(bbox_inches="tight", pad_inches=0) + def save_spectrogram(self, folder: Path) -> None: + super().write(folder) + self.plot() + plt.figure(self.ax.get_figure().number) plt.savefig(f"{folder / str(self)}", bbox_inches="tight", pad_inches=0) plt.close() From 7e14991e5690eb2af0aac57254d50992b8421b46 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 7 Jan 2025 16:10:24 +0100 Subject: [PATCH 059/118] replace ax property with ax SpectroData.plot() parameter --- src/OSmOSE/data/spectro_data.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 2e7e66de..444f7edc 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -60,14 +60,11 @@ def __init__( # self._set_time_resolution(time_resolution=time_resolution) self.audio_data = audio_data self.fft = fft - self._ax = None - @property - def ax(self) -> plt.Axes: - if self._ax is not None: - return self._ax + @staticmethod + def get_default_ax() -> plt.Axes: - # Legacy OSEkit behaviour, done in the getter so that plt figure is created on demand only. + # Legacy OSEkit behaviour. _, ax = plt.subplots( nrows=1, ncols=1, @@ -86,13 +83,8 @@ def ax(self) -> plt.Axes: plt.subplots_adjust( top=1, bottom=0, right=1, left=0, hspace=0, wspace=0, ) - self.ax = ax return ax - @ax.setter - def ax(self, ax: plt.Axes | None) -> None: - self._ax = ax - @property def shape(self) -> tuple[int, ...]: """Shape of the Spectro data.""" @@ -122,17 +114,17 @@ def get_value(self) -> np.ndarray: return self.fft.spectrogram(self.audio_data.get_value()) - def plot(self): + def plot(self, ax: plt.Axes | None = None) -> None: + ax = ax if ax is not None else SpectroData.get_default_ax() sx = self.get_value() time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f log_spectro = 10 * np.log10(abs(sx) + 1e-12) - self.ax.pcolormesh(time, freq, log_spectro) + ax.pcolormesh(time, freq, log_spectro) - def save_spectrogram(self, folder: Path) -> None: + def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: super().write(folder) - self.plot() - plt.figure(self.ax.get_figure().number) + self.plot(ax) plt.savefig(f"{folder / str(self)}", bbox_inches="tight", pad_inches=0) plt.close() From 6809796737f62013438092494b93313b0f73656c Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 7 Jan 2025 16:19:52 +0100 Subject: [PATCH 060/118] move log part to get_value method --- src/OSmOSE/data/spectro_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 444f7edc..d78ffaec 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -112,15 +112,15 @@ def get_value(self) -> np.ndarray: if not self.audio_data or not self.fft: raise ValueError("SpectroData has not been initialized") - return self.fft.spectrogram(self.audio_data.get_value()) + sx = self.fft.spectrogram(self.audio_data.get_value()) + return 10 * np.log10(abs(sx) + np.nextafter(0,1)) def plot(self, ax: plt.Axes | None = None) -> None: ax = ax if ax is not None else SpectroData.get_default_ax() sx = self.get_value() time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f - log_spectro = 10 * np.log10(abs(sx) + 1e-12) - ax.pcolormesh(time, freq, log_spectro) + ax.pcolormesh(time, freq, sx) def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: super().write(folder) From 912f595448c870cb304537e33cbeea642ee83877 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 8 Jan 2025 12:13:33 +0100 Subject: [PATCH 061/118] write npz and read npz metadata --- src/OSmOSE/data/spectro_data.py | 8 +++++++- src/OSmOSE/data/spectro_file.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index d78ffaec..03f68fd5 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -138,7 +138,13 @@ def write(self, folder: Path) -> None: """ super().write(path=folder) - # TODO: implement npz write + sx = self.get_value() + time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] + freq = self.fft.f + window = self.fft.win + hop = self.fft.hop + fs = [self.fft.fs] + np.savez(file = folder / f"{self}.npz", fs = fs, time = time, freq = freq, window = window, hop = hop, sx = sx) def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: data = np.zeros(shape=self.shape) diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index 5006ac43..7c66ae1f 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -56,11 +56,18 @@ def __init__( def _read_metadata(self, path: PathLike) -> None: with np.load(path) as data: - time_resolution = float(data["time_resolution"]) - nb_points = data["Time"].shape[0] - self.time_resolution = Timedelta(seconds=time_resolution) - self.end = self.begin + self.time_resolution * nb_points - self.nb_points = nb_points + sample_rate = data["fs"][0] + time = data["time"] + freq = data["freq"] + + self.sample_rate = sample_rate + + delta_times = [Timedelta(seconds=time[i] - time[i-1]).round(freq = "ns") for i in range(1,time.shape[0])] + most_frequent_delta_time = max(((v, delta_times.count(v)) for v in set(delta_times)), key=lambda i: i[1])[0] + self.time_resolution = most_frequent_delta_time + self.end = (self.begin + Timedelta(seconds = time[-1]) + self.time_resolution).round(freq = "us") + + self.freq = freq def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the spectro data between start and stop from the file. From 24f5a38607fdb33a3c211fc698d8bbe3252b12b7 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 8 Jan 2025 14:44:40 +0100 Subject: [PATCH 062/118] data read from spectro file --- src/OSmOSE/data/spectro_data.py | 2 +- src/OSmOSE/data/spectro_file.py | 24 +++++++++++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 03f68fd5..3f516bb9 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -142,7 +142,7 @@ def write(self, folder: Path) -> None: time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f window = self.fft.win - hop = self.fft.hop + hop = [self.fft.hop] fs = [self.fft.fs] np.savez(file = folder / f"{self}.npz", fs = fs, time = time, freq = freq, window = window, hop = hop, sx = sx) diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index 7c66ae1f..f7c172d0 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING import numpy as np +import pandas as pd from pandas import Timedelta, Timestamp from OSmOSE.data.base_file import BaseFile @@ -52,13 +53,14 @@ def __init__( """ super().__init__(path=path, begin=begin, strptime_format=strptime_format) self._read_metadata(path=path) - self.end = self.begin + self.duration def _read_metadata(self, path: PathLike) -> None: with np.load(path) as data: sample_rate = data["fs"][0] time = data["time"] freq = data["freq"] + hop = data["hop"][0] + window = data["window"] self.sample_rate = sample_rate @@ -69,7 +71,10 @@ def _read_metadata(self, path: PathLike) -> None: self.freq = freq - def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: + self.window = window + self.hop = hop + + def read(self, start: Timestamp, stop: Timestamp) -> pd.DataFrame: """Return the spectro data between start and stop from the file. The data is a 2D array representing the sxx values of the spectrogram. @@ -87,10 +92,19 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: The spectrogram data between start and stop. """ - start_bin = round((start - self.begin) / self.time_resolution) - stop_bin = round((stop - self.begin) / self.time_resolution) with np.load(self.path) as data: - return data["Sxx"][:, start_bin:stop_bin] + time = data["time"] + + start_bin = next(idx for idx,t in enumerate(time) if self.begin + Timedelta(seconds = t) > start) - 1 + start_bin = max(start_bin, 0) + + stop_bin = next(idx for idx,t in list(enumerate(time))[::-1] if self.begin + Timedelta(seconds = t) < stop) + 1 + stop_bin = min(stop_bin, time.shape[0]-1) + + sx = data["sx"][:, start_bin:stop_bin+1] + time = time[start_bin:stop_bin+1] - time[start_bin] + + return pd.DataFrame({"time": time, **dict(zip(self.freq,sx))}) @classmethod def from_base_file(cls, file: BaseFile) -> SpectroFile: From 07913d21516ad1816b5d9afdb70f4493ba7acced Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 8 Jan 2025 17:11:50 +0100 Subject: [PATCH 063/118] add SpectroData from SpectroFile logic --- src/OSmOSE/data/spectro_data.py | 43 +++++++++++++++++++++------------ src/OSmOSE/data/spectro_file.py | 6 ++--- src/OSmOSE/data/spectro_item.py | 16 +++++++++++- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 3f516bb9..bff97eea 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -10,6 +10,8 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd +from pandas import DataFrame from scipy.signal import ShortTimeFFT from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES @@ -146,23 +148,34 @@ def write(self, folder: Path) -> None: fs = [self.fft.fs] np.savez(file = folder / f"{self}.npz", fs = fs, time = time, freq = freq, window = window, hop = hop, sx = sx) - def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: - data = np.zeros(shape=self.shape) - idx = 0 - for item in self.items: - item_data = self._get_item_value(item) - time_bins = item_data.shape[1] - data[:, idx: idx + time_bins] = item_data - idx += time_bins - return data - - def _get_item_value(self, item: SpectroItem) -> np.ndarray: + def _get_value_from_items(self, items: list[SpectroItem]) -> DataFrame: + if not all(np.array_equal(items[0].file.freq,i.file.freq) for i in items[1:] if not i.is_empty): + raise ValueError("Items don't have the same frequency bins.") + + if len({i.file.time_resolution for i in items if not i.is_empty}) > 1: + raise ValueError("Items don't have the same time resolution.") + + time_resolution = next(i.file.time_resolution for i in items if not i.is_empty) + freq = next(i.file.freq for i in items if not i.is_empty) + + joined_df = self._get_item_value(items[0], time_resolution, freq) + + for item in items[1:]: + time_offset = joined_df["time"].iloc[-1] + time_resolution.total_seconds() + item_data = self._get_item_value(item, time_resolution, freq) + item_data["time"] += time_offset + joined_df = pd.concat((joined_df, item_data)) + + return joined_df.iloc[:, 1:].T.to_numpy() + + + def _get_item_value(self, item: SpectroItem, time_resolution: Timedelta | None = None, freq: np.ndarray | None = None) -> DataFrame: """Return the resampled (if needed) data from the Spectro item.""" - item_data = item.get_value() + item_data = item.get_value(freq) if item.is_empty: - return item_data.repeat(round(item.duration / self.time_resolution)) - if item.time_resolution != self.time_resolution: - raise ValueError("Time resolutions don't match.") + time = np.arange(item.duration // time_resolution) * time_resolution.total_seconds() + for t in time: + item_data.loc[item_data.shape[0]] = [t, *[-120.] * (item_data.shape[1]-1)] return item_data @classmethod diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index f7c172d0..17480218 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -99,10 +99,10 @@ def read(self, start: Timestamp, stop: Timestamp) -> pd.DataFrame: start_bin = max(start_bin, 0) stop_bin = next(idx for idx,t in list(enumerate(time))[::-1] if self.begin + Timedelta(seconds = t) < stop) + 1 - stop_bin = min(stop_bin, time.shape[0]-1) + stop_bin = min(stop_bin, time.shape[0]) - sx = data["sx"][:, start_bin:stop_bin+1] - time = time[start_bin:stop_bin+1] - time[start_bin] + sx = data["sx"][:, start_bin:stop_bin] + time = time[start_bin:stop_bin] - time[start_bin] return pd.DataFrame({"time": time, **dict(zip(self.freq,sx))}) diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py index ca239ddc..1bc2aff6 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/data/spectro_item.py @@ -4,6 +4,9 @@ from typing import TYPE_CHECKING +import numpy as np +from pandas import DataFrame + from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem from OSmOSE.data.spectro_file import SpectroFile @@ -36,7 +39,7 @@ def __init__( """ super().__init__(file, begin, end) - self.shape = self.get_value().shape + #self.shape = self.get_value().shape @property def time_resolution(self) -> Timedelta: @@ -56,3 +59,14 @@ def from_base_item(cls, item: BaseItem) -> SpectroItem: end=item.end, ) raise TypeError + + def get_value(self, freq: np.ndarray | None = None) -> DataFrame: + """Get the values from the File between the begin and stop timestamps. + + If the Item is empty, return a single 0. + """ + return ( + DataFrame(columns = ["time", *freq]) + if self.is_empty + else self.file.read(start=self.begin, stop=self.end) + ) From 69a61ca0486d5857346314c675684d5951acc3e6 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 8 Jan 2025 17:17:13 +0100 Subject: [PATCH 064/118] recreate sft on npz file loading --- src/OSmOSE/data/spectro_data.py | 52 ++++++++++++++++++++++++++------- src/OSmOSE/data/spectro_file.py | 35 +++++++++++++++++----- src/OSmOSE/data/spectro_item.py | 4 +-- 3 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index bff97eea..d8c9ddd0 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -40,7 +40,7 @@ def __init__( begin: Timestamp | None = None, end: Timestamp | None = None, time_resolution: Timedelta | None = None, - fft : ShortTimeFFT | None = None, + fft: ShortTimeFFT | None = None, ) -> None: """Initialize a SpectroData from a list of SpectroItems. @@ -83,7 +83,12 @@ def get_default_ax() -> plt.Axes: ax.spines["top"].set_visible(False) plt.axis("off") plt.subplots_adjust( - top=1, bottom=0, right=1, left=0, hspace=0, wspace=0, + top=1, + bottom=0, + right=1, + left=0, + hspace=0, + wspace=0, ) return ax @@ -115,7 +120,7 @@ def get_value(self) -> np.ndarray: raise ValueError("SpectroData has not been initialized") sx = self.fft.spectrogram(self.audio_data.get_value()) - return 10 * np.log10(abs(sx) + np.nextafter(0,1)) + return 10 * np.log10(abs(sx) + np.nextafter(0, 1)) def plot(self, ax: plt.Axes | None = None) -> None: ax = ax if ax is not None else SpectroData.get_default_ax() @@ -146,10 +151,22 @@ def write(self, folder: Path) -> None: window = self.fft.win hop = [self.fft.hop] fs = [self.fft.fs] - np.savez(file = folder / f"{self}.npz", fs = fs, time = time, freq = freq, window = window, hop = hop, sx = sx) + np.savez( + file=folder / f"{self}.npz", + fs=fs, + time=time, + freq=freq, + window=window, + hop=hop, + sx=sx, + ) def _get_value_from_items(self, items: list[SpectroItem]) -> DataFrame: - if not all(np.array_equal(items[0].file.freq,i.file.freq) for i in items[1:] if not i.is_empty): + if not all( + np.array_equal(items[0].file.freq, i.file.freq) + for i in items[1:] + if not i.is_empty + ): raise ValueError("Items don't have the same frequency bins.") if len({i.file.time_resolution for i in items if not i.is_empty}) > 1: @@ -168,14 +185,24 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> DataFrame: return joined_df.iloc[:, 1:].T.to_numpy() - - def _get_item_value(self, item: SpectroItem, time_resolution: Timedelta | None = None, freq: np.ndarray | None = None) -> DataFrame: + def _get_item_value( + self, + item: SpectroItem, + time_resolution: Timedelta | None = None, + freq: np.ndarray | None = None, + ) -> DataFrame: """Return the resampled (if needed) data from the Spectro item.""" item_data = item.get_value(freq) if item.is_empty: - time = np.arange(item.duration // time_resolution) * time_resolution.total_seconds() + time = ( + np.arange(item.duration // time_resolution) + * time_resolution.total_seconds() + ) for t in time: - item_data.loc[item_data.shape[0]] = [t, *[-120.] * (item_data.shape[1]-1)] + item_data.loc[item_data.shape[0]] = [ + t, + *[-120.0] * (item_data.shape[1] - 1), + ] return item_data @classmethod @@ -204,12 +231,15 @@ def from_files( The SpectroData object. """ - return cls.from_base_data(BaseData.from_files(files, begin, end)) + f0 = files[0] + fft = ShortTimeFFT(win=f0.window, hop=f0.hop, fs=f0.sample_rate) + return cls.from_base_data(BaseData.from_files(files, begin, end), fft=fft) @classmethod def from_base_data( cls, data: BaseData, + fft: ShortTimeFFT, ) -> SpectroData: """Return an SpectroData object from a BaseData object. @@ -224,7 +254,7 @@ def from_base_data( The SpectroData object. """ - return cls([SpectroItem.from_base_item(item) for item in data.items]) + return cls([SpectroItem.from_base_item(item) for item in data.items], fft=fft) @classmethod def from_audio_data(cls, data: AudioData, fft: ShortTimeFFT) -> SpectroData: diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index 17480218..15deb282 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -59,15 +59,22 @@ def _read_metadata(self, path: PathLike) -> None: sample_rate = data["fs"][0] time = data["time"] freq = data["freq"] - hop = data["hop"][0] + hop = int(data["hop"][0]) window = data["window"] self.sample_rate = sample_rate - delta_times = [Timedelta(seconds=time[i] - time[i-1]).round(freq = "ns") for i in range(1,time.shape[0])] - most_frequent_delta_time = max(((v, delta_times.count(v)) for v in set(delta_times)), key=lambda i: i[1])[0] + delta_times = [ + Timedelta(seconds=time[i] - time[i - 1]).round(freq="ns") + for i in range(1, time.shape[0]) + ] + most_frequent_delta_time = max( + ((v, delta_times.count(v)) for v in set(delta_times)), key=lambda i: i[1] + )[0] self.time_resolution = most_frequent_delta_time - self.end = (self.begin + Timedelta(seconds = time[-1]) + self.time_resolution).round(freq = "us") + self.end = ( + self.begin + Timedelta(seconds=time[-1]) + self.time_resolution + ).round(freq="us") self.freq = freq @@ -95,16 +102,30 @@ def read(self, start: Timestamp, stop: Timestamp) -> pd.DataFrame: with np.load(self.path) as data: time = data["time"] - start_bin = next(idx for idx,t in enumerate(time) if self.begin + Timedelta(seconds = t) > start) - 1 + start_bin = ( + next( + idx + for idx, t in enumerate(time) + if self.begin + Timedelta(seconds=t) > start + ) + - 1 + ) start_bin = max(start_bin, 0) - stop_bin = next(idx for idx,t in list(enumerate(time))[::-1] if self.begin + Timedelta(seconds = t) < stop) + 1 + stop_bin = ( + next( + idx + for idx, t in list(enumerate(time))[::-1] + if self.begin + Timedelta(seconds=t) < stop + ) + + 1 + ) stop_bin = min(stop_bin, time.shape[0]) sx = data["sx"][:, start_bin:stop_bin] time = time[start_bin:stop_bin] - time[start_bin] - return pd.DataFrame({"time": time, **dict(zip(self.freq,sx))}) + return pd.DataFrame({"time": time, **dict(zip(self.freq, sx))}) @classmethod def from_base_file(cls, file: BaseFile) -> SpectroFile: diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py index 1bc2aff6..5f243657 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/data/spectro_item.py @@ -39,7 +39,7 @@ def __init__( """ super().__init__(file, begin, end) - #self.shape = self.get_value().shape + # self.shape = self.get_value().shape @property def time_resolution(self) -> Timedelta: @@ -66,7 +66,7 @@ def get_value(self, freq: np.ndarray | None = None) -> DataFrame: If the Item is empty, return a single 0. """ return ( - DataFrame(columns = ["time", *freq]) + DataFrame(columns=["time", *freq]) if self.is_empty else self.file.read(start=self.begin, stop=self.end) ) From cb90becd80c389f049703f3e2dddee8020424b22 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 9 Jan 2025 15:09:23 +0100 Subject: [PATCH 065/118] add mfft to npz files --- src/OSmOSE/data/spectro_data.py | 4 +++- src/OSmOSE/data/spectro_file.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index d8c9ddd0..0b419ca4 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -151,6 +151,7 @@ def write(self, folder: Path) -> None: window = self.fft.win hop = [self.fft.hop] fs = [self.fft.fs] + mfft = [self.fft.mfft] np.savez( file=folder / f"{self}.npz", fs=fs, @@ -159,6 +160,7 @@ def write(self, folder: Path) -> None: window=window, hop=hop, sx=sx, + mfft=mfft, ) def _get_value_from_items(self, items: list[SpectroItem]) -> DataFrame: @@ -232,7 +234,7 @@ def from_files( """ f0 = files[0] - fft = ShortTimeFFT(win=f0.window, hop=f0.hop, fs=f0.sample_rate) + fft = ShortTimeFFT(win=f0.window, hop=f0.hop, fs=f0.sample_rate, mfft=f0.mfft) return cls.from_base_data(BaseData.from_files(files, begin, end), fft=fft) @classmethod diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index 15deb282..c356cea3 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -61,8 +61,10 @@ def _read_metadata(self, path: PathLike) -> None: freq = data["freq"] hop = int(data["hop"][0]) window = data["window"] + mfft = data["mfft"][0] self.sample_rate = sample_rate + self.mfft = mfft delta_times = [ Timedelta(seconds=time[i] - time[i - 1]).round(freq="ns") From 0a0c74ee53216e7fd8ffd816c047101c6dce2ac4 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 9 Jan 2025 16:29:04 +0100 Subject: [PATCH 066/118] move spectro_item data logic to SpectroItem class --- src/OSmOSE/data/spectro_data.py | 24 ++---------------------- src/OSmOSE/data/spectro_item.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 0b419ca4..2557b79b 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -177,36 +177,16 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> DataFrame: time_resolution = next(i.file.time_resolution for i in items if not i.is_empty) freq = next(i.file.freq for i in items if not i.is_empty) - joined_df = self._get_item_value(items[0], time_resolution, freq) + joined_df = items[0].get_value(freq=freq, time_resolution=time_resolution) for item in items[1:]: time_offset = joined_df["time"].iloc[-1] + time_resolution.total_seconds() - item_data = self._get_item_value(item, time_resolution, freq) + item_data = item.get_value(freq=freq, time_resolution=time_resolution) item_data["time"] += time_offset joined_df = pd.concat((joined_df, item_data)) return joined_df.iloc[:, 1:].T.to_numpy() - def _get_item_value( - self, - item: SpectroItem, - time_resolution: Timedelta | None = None, - freq: np.ndarray | None = None, - ) -> DataFrame: - """Return the resampled (if needed) data from the Spectro item.""" - item_data = item.get_value(freq) - if item.is_empty: - time = ( - np.arange(item.duration // time_resolution) - * time_resolution.total_seconds() - ) - for t in time: - item_data.loc[item_data.shape[0]] = [ - t, - *[-120.0] * (item_data.shape[1] - 1), - ] - return item_data - @classmethod def from_files( cls, diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py index 5f243657..a8c00c23 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/data/spectro_item.py @@ -60,13 +60,22 @@ def from_base_item(cls, item: BaseItem) -> SpectroItem: ) raise TypeError - def get_value(self, freq: np.ndarray | None = None) -> DataFrame: + def get_value(self, freq: np.ndarray | None = None, time_resolution: Timedelta | None = None) -> DataFrame: """Get the values from the File between the begin and stop timestamps. If the Item is empty, return a single 0. """ - return ( - DataFrame(columns=["time", *freq]) - if self.is_empty - else self.file.read(start=self.begin, stop=self.end) + if not self.is_empty: + return self.file.read(start=self.begin, stop=self.end) + + output_df = DataFrame(columns=["time", *freq]) + time = ( + np.arange(self.duration // time_resolution) + * time_resolution.total_seconds() ) + for t in time: + output_df.loc[output_df.shape[0]] = [ + t, + *[-120.0] * (output_df.shape[1] - 1), + ] + return output_df From fa5e64746655f9e72d483d0314d93400ae64f68c Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 9 Jan 2025 17:27:21 +0100 Subject: [PATCH 067/118] move fft from npz logic to SpectroFile --- src/OSmOSE/data/spectro_data.py | 4 +--- src/OSmOSE/data/spectro_file.py | 4 ++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 2557b79b..17b6a4ad 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -213,9 +213,7 @@ def from_files( The SpectroData object. """ - f0 = files[0] - fft = ShortTimeFFT(win=f0.window, hop=f0.hop, fs=f0.sample_rate, mfft=f0.mfft) - return cls.from_base_data(BaseData.from_files(files, begin, end), fft=fft) + return cls.from_base_data(BaseData.from_files(files, begin, end), fft=files[0].get_fft()) @classmethod def from_base_data( diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index c356cea3..fa278e00 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd from pandas import Timedelta, Timestamp +from scipy.signal import ShortTimeFFT from OSmOSE.data.base_file import BaseFile @@ -129,6 +130,9 @@ def read(self, start: Timestamp, stop: Timestamp) -> pd.DataFrame: return pd.DataFrame({"time": time, **dict(zip(self.freq, sx))}) + def get_fft(self) -> ShortTimeFFT: + return ShortTimeFFT(win=self.window, hop=self.hop, fs=self.sample_rate, mfft=self.mfft) + @classmethod def from_base_file(cls, file: BaseFile) -> SpectroFile: """Return a SpectroFile object from a BaseFile object.""" From 2525c440af4401dc01310c7c4e78e7cf72857fd5 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 9 Jan 2025 18:14:07 +0100 Subject: [PATCH 068/118] spectro data objects pass np.ndarray values --- src/OSmOSE/data/spectro_data.py | 17 ++--------------- src/OSmOSE/data/spectro_file.py | 10 +++------- src/OSmOSE/data/spectro_item.py | 16 +++------------- 3 files changed, 8 insertions(+), 35 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 17b6a4ad..a7470e80 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -10,8 +10,6 @@ import matplotlib.pyplot as plt import numpy as np -import pandas as pd -from pandas import DataFrame from scipy.signal import ShortTimeFFT from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES @@ -163,7 +161,7 @@ def write(self, folder: Path) -> None: mfft=mfft, ) - def _get_value_from_items(self, items: list[SpectroItem]) -> DataFrame: + def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: if not all( np.array_equal(items[0].file.freq, i.file.freq) for i in items[1:] @@ -174,18 +172,7 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> DataFrame: if len({i.file.time_resolution for i in items if not i.is_empty}) > 1: raise ValueError("Items don't have the same time resolution.") - time_resolution = next(i.file.time_resolution for i in items if not i.is_empty) - freq = next(i.file.freq for i in items if not i.is_empty) - - joined_df = items[0].get_value(freq=freq, time_resolution=time_resolution) - - for item in items[1:]: - time_offset = joined_df["time"].iloc[-1] + time_resolution.total_seconds() - item_data = item.get_value(freq=freq, time_resolution=time_resolution) - item_data["time"] += time_offset - joined_df = pd.concat((joined_df, item_data)) - - return joined_df.iloc[:, 1:].T.to_numpy() + return np.hstack(tuple(item.get_value(self.fft) for item in items)) @classmethod def from_files( diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index fa278e00..ce984e31 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING import numpy as np -import pandas as pd from pandas import Timedelta, Timestamp from scipy.signal import ShortTimeFFT @@ -72,7 +71,7 @@ def _read_metadata(self, path: PathLike) -> None: for i in range(1, time.shape[0]) ] most_frequent_delta_time = max( - ((v, delta_times.count(v)) for v in set(delta_times)), key=lambda i: i[1] + ((v, delta_times.count(v)) for v in set(delta_times)), key=lambda i: i[1], )[0] self.time_resolution = most_frequent_delta_time self.end = ( @@ -84,7 +83,7 @@ def _read_metadata(self, path: PathLike) -> None: self.window = window self.hop = hop - def read(self, start: Timestamp, stop: Timestamp) -> pd.DataFrame: + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the spectro data between start and stop from the file. The data is a 2D array representing the sxx values of the spectrogram. @@ -125,10 +124,7 @@ def read(self, start: Timestamp, stop: Timestamp) -> pd.DataFrame: ) stop_bin = min(stop_bin, time.shape[0]) - sx = data["sx"][:, start_bin:stop_bin] - time = time[start_bin:stop_bin] - time[start_bin] - - return pd.DataFrame({"time": time, **dict(zip(self.freq, sx))}) + return data["sx"][:, start_bin:stop_bin] def get_fft(self) -> ShortTimeFFT: return ShortTimeFFT(win=self.window, hop=self.hop, fs=self.sample_rate, mfft=self.mfft) diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py index a8c00c23..5ded60bd 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/data/spectro_item.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import numpy as np -from pandas import DataFrame +from scipy.signal import ShortTimeFFT from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem @@ -60,7 +60,7 @@ def from_base_item(cls, item: BaseItem) -> SpectroItem: ) raise TypeError - def get_value(self, freq: np.ndarray | None = None, time_resolution: Timedelta | None = None) -> DataFrame: + def get_value(self, fft: ShortTimeFFT) -> np.ndarray: """Get the values from the File between the begin and stop timestamps. If the Item is empty, return a single 0. @@ -68,14 +68,4 @@ def get_value(self, freq: np.ndarray | None = None, time_resolution: Timedelta | if not self.is_empty: return self.file.read(start=self.begin, stop=self.end) - output_df = DataFrame(columns=["time", *freq]) - time = ( - np.arange(self.duration // time_resolution) - * time_resolution.total_seconds() - ) - for t in time: - output_df.loc[output_df.shape[0]] = [ - t, - *[-120.0] * (output_df.shape[1] - 1), - ] - return output_df + return np.ones((fft.f.shape[0], fft.p_num(int(self.duration.total_seconds() * fft.fs)))) * -120. From 6b622f205b74825fe09d0ac8ec6cde2798494201 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Fri, 10 Jan 2025 11:48:44 +0100 Subject: [PATCH 069/118] resolve overlaps between joined npz files --- src/OSmOSE/data/spectro_data.py | 20 ++++++++++++++++---- src/OSmOSE/data/spectro_file.py | 7 +++++-- src/OSmOSE/data/spectro_item.py | 9 +++++++-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index a7470e80..9764d5c8 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -117,7 +117,7 @@ def get_value(self) -> np.ndarray: if not self.audio_data or not self.fft: raise ValueError("SpectroData has not been initialized") - sx = self.fft.spectrogram(self.audio_data.get_value()) + sx = self.fft.spectrogram(self.audio_data.get_value(), padding="even") return 10 * np.log10(abs(sx) + np.nextafter(0, 1)) def plot(self, ax: plt.Axes | None = None) -> None: @@ -169,10 +169,20 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: ): raise ValueError("Items don't have the same frequency bins.") - if len({i.file.time_resolution for i in items if not i.is_empty}) > 1: + if len({i.file.get_fft().delta_t for i in items if not i.is_empty}) > 1: raise ValueError("Items don't have the same time resolution.") - return np.hstack(tuple(item.get_value(self.fft) for item in items)) + output = items[0].get_value(fft=self.fft) + for item in items[1:]: + p1_le = self.fft.lower_border_end[1] - self.fft.p_min - 1 + output = np.hstack( + ( + output[:, :-p1_le], + (output[:, -p1_le:] + item.get_value(fft=self.fft)[:, :p1_le]) / 2, + item.get_value(fft=self.fft)[:, p1_le:], + ) + ) + return output @classmethod def from_files( @@ -200,7 +210,9 @@ def from_files( The SpectroData object. """ - return cls.from_base_data(BaseData.from_files(files, begin, end), fft=files[0].get_fft()) + return cls.from_base_data( + BaseData.from_files(files, begin, end), fft=files[0].get_fft() + ) @classmethod def from_base_data( diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index ce984e31..29c3cc19 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -71,7 +71,8 @@ def _read_metadata(self, path: PathLike) -> None: for i in range(1, time.shape[0]) ] most_frequent_delta_time = max( - ((v, delta_times.count(v)) for v in set(delta_times)), key=lambda i: i[1], + ((v, delta_times.count(v)) for v in set(delta_times)), + key=lambda i: i[1], )[0] self.time_resolution = most_frequent_delta_time self.end = ( @@ -127,7 +128,9 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: return data["sx"][:, start_bin:stop_bin] def get_fft(self) -> ShortTimeFFT: - return ShortTimeFFT(win=self.window, hop=self.hop, fs=self.sample_rate, mfft=self.mfft) + return ShortTimeFFT( + win=self.window, hop=self.hop, fs=self.sample_rate, mfft=self.mfft + ) @classmethod def from_base_file(cls, file: BaseFile) -> SpectroFile: diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py index 5ded60bd..8e0cc7b5 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/data/spectro_item.py @@ -60,7 +60,7 @@ def from_base_item(cls, item: BaseItem) -> SpectroItem: ) raise TypeError - def get_value(self, fft: ShortTimeFFT) -> np.ndarray: + def get_value(self, fft: ShortTimeFFT | None = None) -> np.ndarray: """Get the values from the File between the begin and stop timestamps. If the Item is empty, return a single 0. @@ -68,4 +68,9 @@ def get_value(self, fft: ShortTimeFFT) -> np.ndarray: if not self.is_empty: return self.file.read(start=self.begin, stop=self.end) - return np.ones((fft.f.shape[0], fft.p_num(int(self.duration.total_seconds() * fft.fs)))) * -120. + return ( + np.ones( + (fft.f.shape[0], fft.p_num(int(self.duration.total_seconds() * fft.fs))) + ) + * -120.0 + ) From a35dbd6ca024e3bb6946618891a87e4be0a2226a Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 13 Jan 2025 11:16:50 +0100 Subject: [PATCH 070/118] fix spectro data shape --- src/OSmOSE/data/spectro_data.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 9764d5c8..afc9f827 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -93,20 +93,14 @@ def get_default_ax() -> plt.Axes: @property def shape(self) -> tuple[int, ...]: """Shape of the Spectro data.""" - return max(item.shape[0] for item in self.items), sum( - item.shape[1] for item in self.items + return self.fft.f_pts, self.fft.p_num( + int(self.fft.fs * self.duration.total_seconds()) ) def __str__(self) -> str: """Overwrite __str__.""" return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES) - def _set_time_resolution(self, time_resolution: Timedelta) -> None: - """Set the SpectroFile time resolution.""" - if len(tr := {item.time_resolution for item in self.items}) > 1: - raise ValueError("Items don't have the same time resolution") - self.time_resolution = tr.pop() if len(tr) == 1 else time_resolution - def get_value(self) -> np.ndarray: """Return the value of the Spectro data. @@ -180,7 +174,7 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: output[:, :-p1_le], (output[:, -p1_le:] + item.get_value(fft=self.fft)[:, :p1_le]) / 2, item.get_value(fft=self.fft)[:, p1_le:], - ) + ), ) return output @@ -211,7 +205,8 @@ def from_files( """ return cls.from_base_data( - BaseData.from_files(files, begin, end), fft=files[0].get_fft() + BaseData.from_files(files, begin, end), + fft=files[0].get_fft(), ) @classmethod From 0737750a572b546a740d3c29b0ff44c0de81d0ae Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 13 Jan 2025 17:43:36 +0100 Subject: [PATCH 071/118] add SpectroData docstrings --- src/OSmOSE/data/spectro_data.py | 70 +++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 11 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index afc9f827..071232d3 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -10,10 +10,8 @@ import matplotlib.pyplot as plt import numpy as np -from scipy.signal import ShortTimeFFT from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES -from OSmOSE.data.audio_data import AudioData from OSmOSE.data.base_data import BaseData from OSmOSE.data.spectro_file import SpectroFile from OSmOSE.data.spectro_item import SpectroItem @@ -21,7 +19,10 @@ if TYPE_CHECKING: from pathlib import Path - from pandas import Timedelta, Timestamp + from pandas import Timestamp + from scipy.signal import ShortTimeFFT + + from OSmOSE.data.audio_data import AudioData class SpectroData(BaseData[SpectroItem, SpectroFile]): @@ -37,7 +38,6 @@ def __init__( audio_data: AudioData = None, begin: Timestamp | None = None, end: Timestamp | None = None, - time_resolution: Timedelta | None = None, fft: ShortTimeFFT | None = None, ) -> None: """Initialize a SpectroData from a list of SpectroItems. @@ -46,24 +46,35 @@ def __init__( ---------- items: list[SpectroItem] List of the SpectroItem constituting the SpectroData. - time_resolution: Timedelta - The time resolution of the Spectro data. + audio_data: AudioData + The audio data from which to compute the spectrogram. begin: Timestamp | None Only effective if items is None. Set the begin of the empty data. end: Timestamp | None Only effective if items is None. Set the end of the empty data. + fft: ShortTimeFFT + The short time FFT used for computing the spectrogram. """ super().__init__(items=items, begin=begin, end=end) - # self._set_time_resolution(time_resolution=time_resolution) self.audio_data = audio_data self.fft = fft @staticmethod def get_default_ax() -> plt.Axes: + """Return a default-formatted Axes on a new figure. + + The default OSmOSE spectrograms are plotted on wide, borderless spectrograms. + This method set the default figure and axes parameters. + Returns + ------- + plt.Axes: + The default Axes on a new figure. + + """ # Legacy OSEkit behaviour. _, ax = plt.subplots( nrows=1, @@ -94,7 +105,7 @@ def get_default_ax() -> plt.Axes: def shape(self) -> tuple[int, ...]: """Shape of the Spectro data.""" return self.fft.f_pts, self.fft.p_num( - int(self.fft.fs * self.duration.total_seconds()) + int(self.fft.fs * self.duration.total_seconds()), ) def __str__(self) -> str: @@ -102,19 +113,28 @@ def __str__(self) -> str: return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES) def get_value(self) -> np.ndarray: - """Return the value of the Spectro data. + """Return the Sx matrix of the spectrogram. - The data from the Spectro file will be resampled if necessary. + The Sx matrix contains the absolute square of the STFT. """ if not all(item.is_empty for item in self.items): return self._get_value_from_items(self.items) if not self.audio_data or not self.fft: - raise ValueError("SpectroData has not been initialized") + raise ValueError("SpectroData should have either items or audio_data.") sx = self.fft.spectrogram(self.audio_data.get_value(), padding="even") return 10 * np.log10(abs(sx) + np.nextafter(0, 1)) def plot(self, ax: plt.Axes | None = None) -> None: + """Plot the spectrogram on a specific Axes. + + Parameters + ---------- + ax: plt.axes | None + Axes on which the spectrogram should be plotted. + Defaulted as the SpectroData.get_default_ax Axes. + + """ ax = ax if ax is not None else SpectroData.get_default_ax() sx = self.get_value() time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] @@ -122,6 +142,17 @@ def plot(self, ax: plt.Axes | None = None) -> None: ax.pcolormesh(time, freq, sx) def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: + """Export the spectrogram as a png image. + + Parameters + ---------- + folder: Path + Folder in which the spectrogram should be saved. + ax: plt.Axes | None + Axes on which the spectrogram should be plotted. + Defaulted as the SpectroData.get_default_ax Axes. + + """ super().write(folder) self.plot(ax) plt.savefig(f"{folder / str(self)}", bbox_inches="tight", pad_inches=0) @@ -221,6 +252,8 @@ def from_base_data( ---------- data: BaseData BaseData object to convert to SpectroData. + fft: ShortTimeFFT + The ShortTimeFFT used to compute the spectrogram. Returns ------- @@ -232,4 +265,19 @@ def from_base_data( @classmethod def from_audio_data(cls, data: AudioData, fft: ShortTimeFFT) -> SpectroData: + """Instantiate a SpectroData object from a AudioData object. + + Parameters + ---------- + data: AudioData + Audio data from which the SpectroData should be computed. + fft: ShortTimeFFT + The ShortTimeFFT used to compute the spectrogram. + + Returns + ------- + SpectroData: + The SpectroData object. + + """ return cls(audio_data=data, fft=fft, begin=data.begin, end=data.end) From c90058e742248b66843d2556dc451357b94b9014 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 13 Jan 2025 17:53:17 +0100 Subject: [PATCH 072/118] add docstrings to spectro items and files --- src/OSmOSE/data/spectro_file.py | 14 +++++++++++++- src/OSmOSE/data/spectro_item.py | 8 +++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index 29c3cc19..5c09c018 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -128,8 +128,20 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: return data["sx"][:, start_bin:stop_bin] def get_fft(self) -> ShortTimeFFT: + """Return the ShortTimeFFT used for computing the spectrogram. + + Returns + ------- + ShortTimeFFT: + The ShortTimeFFT used for computing the spectrogram. + It is instantiated back from the parameters stored in the npz file. + + """ return ShortTimeFFT( - win=self.window, hop=self.hop, fs=self.sample_rate, mfft=self.mfft + win=self.window, + hop=self.hop, + fs=self.sample_rate, + mfft=self.mfft, ) @classmethod diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/data/spectro_item.py index 8e0cc7b5..18c67099 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/data/spectro_item.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING import numpy as np -from scipy.signal import ShortTimeFFT from OSmOSE.data.base_file import BaseFile from OSmOSE.data.base_item import BaseItem @@ -13,6 +12,7 @@ if TYPE_CHECKING: from pandas import Timedelta, Timestamp + from scipy.signal import ShortTimeFFT class SpectroItem(BaseItem[SpectroFile]): @@ -39,7 +39,6 @@ def __init__( """ super().__init__(file, begin, end) - # self.shape = self.get_value().shape @property def time_resolution(self) -> Timedelta: @@ -70,7 +69,10 @@ def get_value(self, fft: ShortTimeFFT | None = None) -> np.ndarray: return ( np.ones( - (fft.f.shape[0], fft.p_num(int(self.duration.total_seconds() * fft.fs))) + ( + fft.f.shape[0], + fft.p_num(int(self.duration.total_seconds() * fft.fs)), + ), ) * -120.0 ) From e55409d409d4810477ab5e10585ff64b1ac3001e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 14 Jan 2025 14:04:41 +0100 Subject: [PATCH 073/118] add dataset files property --- src/OSmOSE/data/base_dataset.py | 10 ++++++++++ src/OSmOSE/data/base_file.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/src/OSmOSE/data/base_dataset.py b/src/OSmOSE/data/base_dataset.py index deb57b84..a4d89315 100644 --- a/src/OSmOSE/data/base_dataset.py +++ b/src/OSmOSE/data/base_dataset.py @@ -42,6 +42,16 @@ def end(self) -> Timestamp: """End of the last data object.""" return max(data.end for data in self.data) + @property + def files(self) -> set[TFile]: + """All files referred to by the Dataset.""" + return { + item.file + for data in self.data + for item in data.items + if item.file is not None + } + def write(self, folder: Path) -> None: """Write all data objects in the specified folder. diff --git a/src/OSmOSE/data/base_file.py b/src/OSmOSE/data/base_file.py index 8ceda5c9..41f5323c 100644 --- a/src/OSmOSE/data/base_file.py +++ b/src/OSmOSE/data/base_file.py @@ -86,3 +86,7 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: The data between start and stop. """ + + def __hash__(self) -> int: + """Overwrite hash magic method.""" + return hash(self.path) From c330f619becc6a91c7e382e42b4ac7120595eba1 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 14 Jan 2025 15:50:41 +0100 Subject: [PATCH 074/118] add spectro dataset --- src/OSmOSE/data/spectro_dataset.py | 108 +++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 src/OSmOSE/data/spectro_dataset.py diff --git a/src/OSmOSE/data/spectro_dataset.py b/src/OSmOSE/data/spectro_dataset.py new file mode 100644 index 00000000..c3bb28d6 --- /dev/null +++ b/src/OSmOSE/data/spectro_dataset.py @@ -0,0 +1,108 @@ +"""SpectroDataset is a collection of SpectroData objects. + +SpectroDataset is a collection of SpectroData, with methods +that simplify repeated operations on the spectro data. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from OSmOSE.data.base_dataset import BaseDataset +from OSmOSE.data.spectro_data import SpectroData +from OSmOSE.data.spectro_file import SpectroFile + +if TYPE_CHECKING: + from pathlib import Path + + from pandas import Timedelta, Timestamp + from scipy.signal import ShortTimeFFT + + from OSmOSE.data.audio_dataset import AudioDataset + + +class SpectroDataset(BaseDataset[SpectroData, SpectroFile]): + """SpectroDataset is a collection of SpectroData objects. + + SpectroDataset is a collection of SpectroData, with methods + that simplify repeated operations on the spectro data. + + """ + + def __init__(self, data: list[SpectroData]) -> None: + """Initialize a SpectroDataset.""" + super().__init__(data) + + @property + def fft(self) -> ShortTimeFFT: + """Return the fft of the spectro data.""" + return next(data.fft for data in self.data) + + @fft.setter + def fft(self, fft: ShortTimeFFT) -> None: + for data in self.data: + data.fft = fft + + @classmethod + def from_folder( + cls, + folder: Path, + strptime_format: str, + begin: Timestamp | None = None, + end: Timestamp | None = None, + data_duration: Timedelta | None = None, + ) -> SpectroDataset: + """Return a SpectroDataset from a folder containing the spectro files. + + Parameters + ---------- + folder: Path + The folder containing the spectro files. + strptime_format: str + The strptime format of the timestamps in the spectro file names. + begin: Timestamp | None + The begin of the spectro dataset. + Defaulted to the begin of the first file. + end: Timestamp | None + The end of the spectro dataset. + Defaulted to the end of the last file. + data_duration: Timedelta | None + Duration of the spectro data objects. + If provided, spectro data will be evenly distributed between begin and end. + Else, one data object will cover the whole time period. + + Returns + ------- + Spectrodataset: + The spectro dataset. + + """ + files = [ + SpectroFile(file, strptime_format=strptime_format) + for file in folder.glob("*.npz") + ] + base_dataset = BaseDataset.from_files(files, begin, end, data_duration) + return cls.from_base_dataset(base_dataset, files[0].get_fft()) + + @classmethod + def from_base_dataset( + cls, + base_dataset: BaseDataset, + fft: ShortTimeFFT, + ) -> SpectroDataset: + """Return a SpectroDataset object from a BaseDataset object.""" + return cls( + [SpectroData.from_base_data(data, fft) for data in base_dataset.data], + ) + + @classmethod + def from_audio_dataset( + cls, + audio_dataset: AudioDataset, + fft: ShortTimeFFT, + ) -> SpectroDataset: + """Return a SpectroDataset object from an AudioDataset object. + + The SpectroData is computed from the AudioData using the given fft. + """ + return cls([SpectroData.from_audio_data(d, fft) for d in audio_dataset.data]) From cebbc6776f2609d716f4397f569a7c735fde23d9 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 15 Jan 2025 10:13:06 +0100 Subject: [PATCH 075/118] split files property between data and dataset --- src/OSmOSE/data/base_data.py | 5 +++++ src/OSmOSE/data/base_dataset.py | 7 +------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index 21882094..ceeca68b 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -75,6 +75,11 @@ def write(self, path: Path) -> None: """ path.mkdir(parents=True, exist_ok=True, mode=DPDEFAULT) + @property + def files(self) -> set[TFile]: + """All files referred to by the Data.""" + return {item.file for item in self.items if item.file is not None} + @classmethod def from_files( cls, diff --git a/src/OSmOSE/data/base_dataset.py b/src/OSmOSE/data/base_dataset.py index a4d89315..ade08cc9 100644 --- a/src/OSmOSE/data/base_dataset.py +++ b/src/OSmOSE/data/base_dataset.py @@ -45,12 +45,7 @@ def end(self) -> Timestamp: @property def files(self) -> set[TFile]: """All files referred to by the Dataset.""" - return { - item.file - for data in self.data - for item in data.items - if item.file is not None - } + return {file for data in self.data for file in data.files} def write(self, folder: Path) -> None: """Write all data objects in the specified folder. From 572c89deab7da519c11e6f45f9f5cdfe90f87184 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 15 Jan 2025 13:53:42 +0100 Subject: [PATCH 076/118] fix empty AudioData sample_rate --- src/OSmOSE/data/audio_data.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 3904a16b..8d534ede 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -81,14 +81,16 @@ def _set_sample_rate(self, sample_rate: int | None = None) -> None: first item that has one. Else, it is set to None. """ - if sample_rate is not None or any( - sample_rate := item.sample_rate - for item in self.items - if item.sample_rate is not None - ): + if sample_rate is not None: self.sample_rate = sample_rate - else: - self.sample_rate = None + return + if sr := next( + (item.sample_rate for item in self.items if item.sample_rate is not None), + None, + ): + self.sample_rate = sr + return + self.sample_rate = None def get_value(self) -> np.ndarray: """Return the value of the audio data. @@ -120,7 +122,7 @@ def _get_item_value(self, item: AudioItem) -> np.ndarray: item_data = item.get_value() if item.is_empty: return item_data.repeat( - int(item.duration.total_seconds() * self.sample_rate) + int(item.duration.total_seconds() * self.sample_rate), ) if item.sample_rate != self.sample_rate: return resample(item_data, item.sample_rate, self.sample_rate) @@ -178,4 +180,7 @@ def from_base_data( The AudioData object. """ - return cls([AudioItem.from_base_item(item) for item in data.items], sample_rate) + return cls( + items=[AudioItem.from_base_item(item) for item in data.items], + sample_rate=sample_rate, + ) From 000480fe3d475f1f2f87f7466958284e3f3af462 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 15 Jan 2025 15:55:07 +0100 Subject: [PATCH 077/118] add data divide method --- src/OSmOSE/data/audio_data.py | 7 +++++-- src/OSmOSE/data/audio_item.py | 13 +++++++++++++ src/OSmOSE/data/base_data.py | 9 +++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 8d534ede..5868d7ef 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -66,7 +66,7 @@ def nb_channels(self) -> int: @property def shape(self) -> tuple[int, ...]: """Shape of the audio data.""" - data_length = int(self.sample_rate * self.duration.total_seconds()) + data_length = round(self.sample_rate * self.duration.total_seconds()) return data_length if self.nb_channels <= 1 else (data_length, self.nb_channels) def __str__(self) -> str: @@ -122,12 +122,15 @@ def _get_item_value(self, item: AudioItem) -> np.ndarray: item_data = item.get_value() if item.is_empty: return item_data.repeat( - int(item.duration.total_seconds() * self.sample_rate), + round(item.duration.total_seconds() * self.sample_rate), ) if item.sample_rate != self.sample_rate: return resample(item_data, item.sample_rate, self.sample_rate) return item_data + def divide(self, nb_subdata: int = 2) -> list[AudioData]: + return [AudioData.from_base_data(base_data, self.sample_rate) for base_data in super().divide(nb_subdata)] + @classmethod def from_files( cls, diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 0b4128e4..1e76226c 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -9,6 +9,7 @@ from OSmOSE.data.base_item import BaseItem if TYPE_CHECKING: + import numpy as np from pandas import Timestamp @@ -47,6 +48,18 @@ def nb_channels(self) -> int: """Number of channels in the associated AudioFile.""" return 0 if self.is_empty else self.file.metadata.channels + @property + def shape(self) -> int: + """Number of points in the audio item data.""" + return round(self.sample_rate * self.duration.total_seconds()) + + def get_value(self) -> np.ndarray: + """Get the values from the File between the begin and stop timestamps. + + If the Item is empty, return a single 0. + """ + return super().get_value()[:self.shape] + @classmethod def from_base_item(cls, item: BaseItem) -> AudioItem: """Return an AudioItem object from a BaseItem object.""" diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index ceeca68b..b2ba0636 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np +from pandas import date_range from OSmOSE.config import DPDEFAULT from OSmOSE.data.base_file import BaseFile @@ -80,6 +81,14 @@ def files(self) -> set[TFile]: """All files referred to by the Data.""" return {item.file for item in self.items if item.file is not None} + def divide(self, nb_subdata: int = 2) -> list[BaseData]: + dates = date_range(self.begin, self.end, periods=nb_subdata + 1) + subdata_dates = zip(dates, dates[1:]) + return [ + BaseData.from_files(files=list(self.files), begin=b, end=e) + for b, e in subdata_dates + ] + @classmethod def from_files( cls, From 3721d51790ec565a153984e9536475d6ca8f68e1 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 15 Jan 2025 15:55:34 +0100 Subject: [PATCH 078/118] add spectro data ltas first version --- src/OSmOSE/data/spectro_data.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 071232d3..a023bcb3 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -122,10 +122,23 @@ def get_value(self) -> np.ndarray: if not self.audio_data or not self.fft: raise ValueError("SpectroData should have either items or audio_data.") - sx = self.fft.spectrogram(self.audio_data.get_value(), padding="even") - return 10 * np.log10(abs(sx) + np.nextafter(0, 1)) - - def plot(self, ax: plt.Axes | None = None) -> None: + return self.fft.spectrogram(self.audio_data.get_value(), padding="even") + + def get_ltas_value(self, nb_windows: int = 1920) -> np.ndarray: + if self.shape[1] <= nb_windows: + return self.get_value() + sub_spectros = [ + SpectroData.from_audio_data(ad, self.fft) + for ad in self.audio_data.divide(nb_windows) + ] + return np.vstack( + [ + np.mean(sub_spectro.get_ltas_value(nb_windows), axis=1) + for sub_spectro in sub_spectros + ] + ).T + + def plot(self, ax: plt.Axes | None = None, nb_ltas_windows: int = 0) -> None: """Plot the spectrogram on a specific Axes. Parameters @@ -136,10 +149,15 @@ def plot(self, ax: plt.Axes | None = None) -> None: """ ax = ax if ax is not None else SpectroData.get_default_ax() - sx = self.get_value() + sx = ( + self.get_value() + if nb_ltas_windows == 0 + else self.get_ltas_value(nb_ltas_windows) + ) + sx = 10 * np.log10(abs(sx) + np.nextafter(0, 1)) time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f - ax.pcolormesh(time, freq, sx) + ax.pcolormesh(time, freq, sx, vmin=-120, vmax=0) def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: """Export the spectrogram as a png image. From 673df23f7bbb2fbcd594449865fe77a439458a0a Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 15 Jan 2025 18:10:20 +0100 Subject: [PATCH 079/118] add AudioFileManager --- src/OSmOSE/data/__init__.py | 3 +++ src/OSmOSE/data/audio_file.py | 13 ++++++---- src/OSmOSE/data/audio_file_manager.py | 36 +++++++++++++++++++++++++++ src/OSmOSE/data/audio_item.py | 4 +-- 4 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 src/OSmOSE/data/audio_file_manager.py diff --git a/src/OSmOSE/data/__init__.py b/src/OSmOSE/data/__init__.py index e69de29b..4d2511df 100644 --- a/src/OSmOSE/data/__init__.py +++ b/src/OSmOSE/data/__init__.py @@ -0,0 +1,3 @@ +from OSmOSE.data.audio_file_manager import AudioFileManager + +audio_file_manager = AudioFileManager() diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index c47160cb..bc7404b1 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -8,9 +8,9 @@ from os import PathLike import numpy as np -import soundfile as sf from pandas import Timedelta, Timestamp +from OSmOSE.data import audio_file_manager as afm from OSmOSE.data.base_file import BaseFile @@ -44,8 +44,11 @@ def __init__( """ super().__init__(path=path, begin=begin, strptime_format=strptime_format) - self.metadata = sf.info(path) - self.end = self.begin + Timedelta(seconds=self.metadata.duration) + sample_rate, frames, channels = afm.info(path) + duration = frames / sample_rate + self.sample_rate = sample_rate + self.channels = channels + self.end = self.begin + Timedelta(seconds=duration) def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the audio data between start and stop from the file. @@ -63,10 +66,10 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: The audio data between start and stop. """ - sample_rate = self.metadata.samplerate + sample_rate = self.sample_rate start_sample = round((start - self.begin).total_seconds() * sample_rate) stop_sample = round((stop - self.begin).total_seconds() * sample_rate) - return sf.read(self.path, start=start_sample, stop=stop_sample)[0] + return afm.read(self.path, start=start_sample, stop=stop_sample) @classmethod def from_base_file(cls, file: BaseFile) -> AudioFile: diff --git a/src/OSmOSE/data/audio_file_manager.py b/src/OSmOSE/data/audio_file_manager.py new file mode 100644 index 00000000..bb552ded --- /dev/null +++ b/src/OSmOSE/data/audio_file_manager.py @@ -0,0 +1,36 @@ +import soundfile as sf + + +class AudioFileManager: + def __init__(self): + self.opened_file = None + self.calls = 0 + self.opens = 0 + + def close(self): + if self.opened_file is None: + return + self.opened_file.close() + self.opened_file = None + + def open(self, path): + self.opened_file = sf.SoundFile(path, "r") + + def switch(self, path): + self.calls += 1 + if self.opened_file is None: + self.open(path) + if self.opened_file.name == str(path): + return + self.close() + self.open(path) + self.opens += 1 + + def read(self, path, start: int, stop: int): + self.switch(path) + self.opened_file.seek(start) + return self.opened_file.read(stop-start) + + def info(self, path): + self.switch(path) + return self.opened_file.samplerate, self.opened_file.frames, self.opened_file.channels diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 1e76226c..48179b85 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -41,12 +41,12 @@ def __init__( @property def sample_rate(self) -> float: """Sample rate of the associated AudioFile.""" - return None if self.is_empty else self.file.metadata.samplerate + return None if self.is_empty else self.file.sample_rate @property def nb_channels(self) -> int: """Number of channels in the associated AudioFile.""" - return 0 if self.is_empty else self.file.metadata.channels + return 0 if self.is_empty else self.file.channels @property def shape(self) -> int: From 33726bc8940d68eae75e5b21bd9bded7686dd762 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 16 Jan 2025 09:49:14 +0100 Subject: [PATCH 080/118] add docstrings for the AudioFileManager --- src/OSmOSE/data/audio_data.py | 5 +- src/OSmOSE/data/audio_file_manager.py | 84 +++++++++++++++++++++------ src/OSmOSE/data/audio_item.py | 2 +- 3 files changed, 72 insertions(+), 19 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 5868d7ef..e5f4b274 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -129,7 +129,10 @@ def _get_item_value(self, item: AudioItem) -> np.ndarray: return item_data def divide(self, nb_subdata: int = 2) -> list[AudioData]: - return [AudioData.from_base_data(base_data, self.sample_rate) for base_data in super().divide(nb_subdata)] + return [ + AudioData.from_base_data(base_data, self.sample_rate) + for base_data in super().divide(nb_subdata) + ] @classmethod def from_files( diff --git a/src/OSmOSE/data/audio_file_manager.py b/src/OSmOSE/data/audio_file_manager.py index bb552ded..96a5e2d3 100644 --- a/src/OSmOSE/data/audio_file_manager.py +++ b/src/OSmOSE/data/audio_file_manager.py @@ -1,36 +1,86 @@ +"""Audio File Manager which keeps an audio file open until a request in another file is made. + +This workflow avoids closing/opening a same file repeatedly. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + import soundfile as sf +if TYPE_CHECKING: + from os import PathLike + + import numpy as np + class AudioFileManager: - def __init__(self): + """Audio File Manager which keeps an audio file open until a request in another file is made.""" + + def __init__(self) -> None: + """Initialize an audio file manager.""" self.opened_file = None - self.calls = 0 - self.opens = 0 - def close(self): + def _close(self) -> None: if self.opened_file is None: return self.opened_file.close() self.opened_file = None - def open(self, path): + def _open(self, path: PathLike | str) -> None: self.opened_file = sf.SoundFile(path, "r") - def switch(self, path): - self.calls += 1 + def _switch(self, path: PathLike | str) -> None: if self.opened_file is None: - self.open(path) + self._open(path) if self.opened_file.name == str(path): return - self.close() - self.open(path) - self.opens += 1 + self._close() + self._open(path) - def read(self, path, start: int, stop: int): - self.switch(path) + def read(self, path: PathLike | str, start: int = 0, stop: int = -1) -> np.ndarray: + """Read the content of an audio file. + + If the audio file is not the current opened file, + the current opened file is switched. + + Parameters + ---------- + path: PathLike | str + Path to the audio file. + start: int + First frame to read. + stop: int + Last frame to read. + + Returns + ------- + np.ndarray: + A (channel * frames) array containing the audio data. + + """ + self._switch(path) self.opened_file.seek(start) - return self.opened_file.read(stop-start) + return self.opened_file.read(stop - start) + + def info(self, path: PathLike | str) -> tuple[int, int, int]: + """Return the sample rate, number of frames and channels of the audio file. + + Parameters + ---------- + path: PathLike | str + Path to the audio file. + + Returns + ------- + tuple[int,int,int]: + Sample rate, number of frames and channels of the audio file. - def info(self, path): - self.switch(path) - return self.opened_file.samplerate, self.opened_file.frames, self.opened_file.channels + """ + self._switch(path) + return ( + self.opened_file.samplerate, + self.opened_file.frames, + self.opened_file.channels, + ) diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 48179b85..11307b36 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -58,7 +58,7 @@ def get_value(self) -> np.ndarray: If the Item is empty, return a single 0. """ - return super().get_value()[:self.shape] + return super().get_value()[: self.shape] @classmethod def from_base_item(cls, item: BaseItem) -> AudioItem: From 0e6886ea27efa9fbdd04a16baa02fcf01f6e449e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 16 Jan 2025 15:17:54 +0100 Subject: [PATCH 081/118] truncate item values at data level --- src/OSmOSE/data/audio_data.py | 1 + src/OSmOSE/data/audio_item.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index e5f4b274..26e17c8b 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -101,6 +101,7 @@ def get_value(self) -> np.ndarray: idx = 0 for item in self.items: item_data = self._get_item_value(item) + item_data = item_data[:min(item_data.shape[0], data.shape[0] - idx)] data[idx : idx + len(item_data)] = item_data idx += len(item_data) return data diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index 11307b36..dc243236 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -58,7 +58,7 @@ def get_value(self) -> np.ndarray: If the Item is empty, return a single 0. """ - return super().get_value()[: self.shape] + return super().get_value() @classmethod def from_base_item(cls, item: BaseItem) -> AudioItem: From d82e8ca41df57d6322b4b4e1d47e05a7156fc9ce Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 20 Jan 2025 11:16:31 +0100 Subject: [PATCH 082/118] add subtype parameter to AudioData write method --- src/OSmOSE/data/audio_data.py | 13 +++++++++---- src/OSmOSE/data/audio_dataset.py | 16 ++++++++++++++++ src/OSmOSE/data/base_data.py | 6 +++++- src/OSmOSE/utils/audio_utils.py | 9 ++++++--- 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 26e17c8b..5a9e5c95 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -101,22 +101,27 @@ def get_value(self) -> np.ndarray: idx = 0 for item in self.items: item_data = self._get_item_value(item) - item_data = item_data[:min(item_data.shape[0], data.shape[0] - idx)] + item_data = item_data[: min(item_data.shape[0], data.shape[0] - idx)] data[idx : idx + len(item_data)] = item_data idx += len(item_data) return data - def write(self, folder: Path) -> None: + def write(self, folder: Path, subtype: str | None = None) -> None: """Write the audio data to file. Parameters ---------- folder: pathlib.Path Folder in which to write the audio file. + subtype: str | None + Subtype as provided by the soundfile module. + Defaulted as the default 16-bit PCM for WAV audio files. """ - super().write(path=folder) - sf.write(folder / f"{self}.wav", self.get_value(), self.sample_rate) + super().create_directories(path=folder) + sf.write( + folder / f"{self}.wav", self.get_value(), self.sample_rate, subtype=subtype + ) def _get_item_value(self, item: AudioItem) -> np.ndarray: """Return the resampled (if needed) data from the audio item.""" diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index 2e2680bd..ac22fd3b 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -53,6 +53,22 @@ def sample_rate(self, sample_rate: float) -> None: for data in self.data: data.sample_rate = sample_rate + def write(self, folder: Path, subtype: str | None = None) -> None: + """Write all data objects in the specified folder. + + Parameters + ---------- + folder: Path + Folder in which to write the data. + subtype: str | None + Subtype as provided by the soundfile module. + Defaulted as the default 16-bit PCM for WAV audio files. + + + """ + for data in self.data: + data.write(folder, subtype) + @classmethod def from_folder( cls, diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index b2ba0636..7e208701 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -69,13 +69,17 @@ def get_value(self) -> np.ndarray: """Get the concatenated values from all Items.""" return np.concatenate([item.get_value() for item in self.items]) - def write(self, path: Path) -> None: + @staticmethod + def create_directories(path: Path) -> None: """Create the directory in which the data will be written. The actual data writing is left to the specified classes. """ path.mkdir(parents=True, exist_ok=True, mode=DPDEFAULT) + def write(self, folder: Path) -> None: + """Abstract method for writing data to file.""" + @property def files(self) -> set[TFile]: """All files referred to by the Data.""" diff --git a/src/OSmOSE/utils/audio_utils.py b/src/OSmOSE/utils/audio_utils.py index a8f53784..a8ae03ea 100644 --- a/src/OSmOSE/utils/audio_utils.py +++ b/src/OSmOSE/utils/audio_utils.py @@ -149,6 +149,7 @@ def generate_sample_audio( series_type: Literal["repeat", "increase"] = "repeat", min_value: float = 0.0, max_value: float = 1.0, + dtype: np.dtype = np.float64, ) -> list[np.ndarray]: """Generate sample audio data. @@ -175,15 +176,17 @@ def generate_sample_audio( """ if series_type == "repeat": return np.split( - np.tile(np.linspace(min_value, max_value, nb_samples), nb_files), + np.tile( + np.linspace(min_value, max_value, nb_samples, dtype=dtype), nb_files + ), nb_files, ) if series_type == "increase": return np.split( - np.linspace(min_value, max_value, nb_samples * nb_files), + np.linspace(min_value, max_value, nb_samples * nb_files, dtype=dtype), nb_files, ) - return np.split(np.empty(nb_samples * nb_files), nb_files) + return np.split(np.empty(nb_samples * nb_files, dtype=dtype), nb_files) def resample(data: np.ndarray, origin_sr: float, target_sr: float) -> np.ndarray: From d2afc8482f3f91115df95300cf9ff0fbef6ea5c1 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 20 Jan 2025 11:38:09 +0100 Subject: [PATCH 083/118] add audio write tests --- tests/test_audio.py | 67 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/tests/test_audio.py b/tests/test_audio.py index 01b1c9ea..00cce9e9 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +import soundfile as sf from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES from OSmOSE.data.audio_data import AudioData @@ -431,7 +432,9 @@ def test_audio_resample_sample_count( None, pd.Timedelta(seconds=1), generate_sample_audio( - nb_files=3, nb_samples=48_000, series_type="increase" + nb_files=3, + nb_samples=48_000, + series_type="increase", ), id="multiple_consecutive_files", ), @@ -450,7 +453,10 @@ def test_audio_resample_sample_count( [ generate_sample_audio(nb_files=1, nb_samples=96_000)[0][0:48_000], generate_sample_audio( - nb_files=1, nb_samples=48_000, min_value=0.0, max_value=0.0 + nb_files=1, + nb_samples=48_000, + min_value=0.0, + max_value=0.0, )[0], generate_sample_audio(nb_files=1, nb_samples=96_000)[0][48_000:], ], @@ -493,3 +499,60 @@ def test_audio_dataset_from_folder( np.array_equal(data.get_value(), expected) for (data, expected) in zip(dataset.data, expected_audio_data) ) + + +@pytest.mark.parametrize( + ("audio_files", "subtype", "expected_audio_data"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + "DOUBLE", + generate_sample_audio(1, 48_000, dtype=np.float64), + id="float64_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + "FLOAT", + generate_sample_audio(1, 48_000, dtype=np.float32), + id="float32_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 2, + "inter_file_duration": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + "DOUBLE", + generate_sample_audio(1, 48_000, dtype=np.float64), + id="padded_file", + ), + ], + indirect=["audio_files"], +) +def test_write_files( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + subtype: str, + expected_audio_data: list[tuple[int, bool]], +) -> None: + dataset = AudioDataset.from_folder( + tmp_path, + strptime_format=TIMESTAMP_FORMAT_TEST_FILES, + ) + output_path = tmp_path / "output" + dataset.write(output_path, subtype=subtype) + for data in dataset.data: + assert f"{data}.wav" in [f.name for f in output_path.glob("*.wav")] + assert np.allclose(data.get_value(), sf.read(output_path / f"{data}.wav")[0]) From de1cf5eebed2dbe67f94db6265ec2753dac37fe6 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 20 Jan 2025 13:56:32 +0100 Subject: [PATCH 084/118] fix audio data read boundary frames --- src/OSmOSE/data/audio_data.py | 24 ++++++++++++++++++++---- src/OSmOSE/data/audio_file.py | 8 ++++++-- src/OSmOSE/data/audio_file_manager.py | 2 +- src/OSmOSE/utils/audio_utils.py | 5 ++++- 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 5a9e5c95..0be3c7a9 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -64,7 +64,7 @@ def nb_channels(self) -> int: ) @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> tuple[int, ...] | int: """Shape of the audio data.""" data_length = round(self.sample_rate * self.duration.total_seconds()) return data_length if self.nb_channels <= 1 else (data_length, self.nb_channels) @@ -120,7 +120,10 @@ def write(self, folder: Path, subtype: str | None = None) -> None: """ super().create_directories(path=folder) sf.write( - folder / f"{self}.wav", self.get_value(), self.sample_rate, subtype=subtype + folder / f"{self}.wav", + self.get_value(), + self.sample_rate, + subtype=subtype, ) def _get_item_value(self, item: AudioItem) -> np.ndarray: @@ -134,10 +137,23 @@ def _get_item_value(self, item: AudioItem) -> np.ndarray: return resample(item_data, item.sample_rate, self.sample_rate) return item_data - def divide(self, nb_subdata: int = 2) -> list[AudioData]: + def split(self, nb_subdata: int = 2) -> list[AudioData]: + """Split the audio data object in the specified number of audio subdata. + + Parameters + ---------- + nb_subdata: int + Number of subdata in which to split the data. + + Returns + ------- + list[AudioData] + The list of AudioData subdata objects. + + """ return [ AudioData.from_base_data(base_data, self.sample_rate) - for base_data in super().divide(nb_subdata) + for base_data in super().split(nb_subdata) ] @classmethod diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index bc7404b1..635f07c7 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -8,6 +8,8 @@ from os import PathLike import numpy as np +from math import floor + from pandas import Timedelta, Timestamp from OSmOSE.data import audio_file_manager as afm @@ -64,11 +66,13 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: ------- numpy.ndarray: The audio data between start and stop. + The first frame of the data is the first frame that ends after start. + The last frame of the data is the last frame that starts before stop. """ sample_rate = self.sample_rate - start_sample = round((start - self.begin).total_seconds() * sample_rate) - stop_sample = round((stop - self.begin).total_seconds() * sample_rate) + start_sample = floor((start - self.begin).total_seconds() * sample_rate) + stop_sample = floor((stop - self.begin).total_seconds() * sample_rate) return afm.read(self.path, start=start_sample, stop=stop_sample) @classmethod diff --git a/src/OSmOSE/data/audio_file_manager.py b/src/OSmOSE/data/audio_file_manager.py index 96a5e2d3..b04da563 100644 --- a/src/OSmOSE/data/audio_file_manager.py +++ b/src/OSmOSE/data/audio_file_manager.py @@ -52,7 +52,7 @@ def read(self, path: PathLike | str, start: int = 0, stop: int = -1) -> np.ndarr start: int First frame to read. stop: int - Last frame to read. + Frame after the last frame to read. Returns ------- diff --git a/src/OSmOSE/utils/audio_utils.py b/src/OSmOSE/utils/audio_utils.py index a8ae03ea..89ca00e3 100644 --- a/src/OSmOSE/utils/audio_utils.py +++ b/src/OSmOSE/utils/audio_utils.py @@ -167,6 +167,8 @@ def generate_sample_audio( Minimum value of the audio data. max_value: float Maximum value of the audio data. + dtype: np.dtype + The type of the output array. Returns ------- @@ -177,7 +179,8 @@ def generate_sample_audio( if series_type == "repeat": return np.split( np.tile( - np.linspace(min_value, max_value, nb_samples, dtype=dtype), nb_files + np.linspace(min_value, max_value, nb_samples, dtype=dtype), + nb_files, ), nb_files, ) From 227af9b679a96f837d75482c6ecc7c06ec2403aa Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 20 Jan 2025 14:31:53 +0100 Subject: [PATCH 085/118] add tests for AudioData split method --- src/OSmOSE/data/base_data.py | 15 +++++- src/OSmOSE/data/spectro_data.py | 20 ++++++-- tests/test_audio.py | 87 ++++++++++++++++++++++++++++++++- 3 files changed, 114 insertions(+), 8 deletions(-) diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/data/base_data.py index 7e208701..f06fe304 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/data/base_data.py @@ -85,7 +85,20 @@ def files(self) -> set[TFile]: """All files referred to by the Data.""" return {item.file for item in self.items if item.file is not None} - def divide(self, nb_subdata: int = 2) -> list[BaseData]: + def split(self, nb_subdata: int = 2) -> list[BaseData]: + """Split the data object in the specified number of subdata. + + Parameters + ---------- + nb_subdata: int + Number of subdata in which to split the data. + + Returns + ------- + list[BaseData] + The list of BaseData subdata objects. + + """ dates = date_range(self.begin, self.end, periods=nb_subdata + 1) subdata_dates = zip(dates, dates[1:]) return [ diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index a023bcb3..5e43c529 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -124,18 +124,28 @@ def get_value(self) -> np.ndarray: return self.fft.spectrogram(self.audio_data.get_value(), padding="even") - def get_ltas_value(self, nb_windows: int = 1920) -> np.ndarray: + def get_ltas_value(self, nb_windows: int = 1920, depth: int = 0) -> np.ndarray: if self.shape[1] <= nb_windows: return self.get_value() sub_spectros = [ SpectroData.from_audio_data(ad, self.fft) - for ad in self.audio_data.divide(nb_windows) + for ad in self.audio_data.split(nb_windows) ] + from tqdm import tqdm + + if depth == 0: + m = [] + for sub_spectro in tqdm(sub_spectros): + m.append( + np.mean(sub_spectro.get_ltas_value(nb_windows, depth + 1), axis=1), + ) + return np.vstack(m).T + return np.vstack( [ - np.mean(sub_spectro.get_ltas_value(nb_windows), axis=1) + np.mean(sub_spectro.get_ltas_value(nb_windows, depth + 1), axis=1) for sub_spectro in sub_spectros - ] + ], ).T def plot(self, ax: plt.Axes | None = None, nb_ltas_windows: int = 0) -> None: @@ -185,7 +195,7 @@ def write(self, folder: Path) -> None: Folder in which to write the Spectro file. """ - super().write(path=folder) + super().create_directories(path=folder) sx = self.get_value() time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f diff --git a/tests/test_audio.py b/tests/test_audio.py index 00cce9e9..60f28ccd 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -486,7 +486,7 @@ def test_audio_dataset_from_folder( begin: pd.Timestamp | None, end: pd.Timestamp | None, duration: pd.Timedelta | None, - expected_audio_data: list[tuple[int, bool]], + expected_audio_data: list[np.ndarray], ) -> None: dataset = AudioDataset.from_folder( tmp_path, @@ -545,7 +545,7 @@ def test_write_files( tmp_path: Path, audio_files: tuple[list[Path], pytest.fixtures.Subrequest], subtype: str, - expected_audio_data: list[tuple[int, bool]], + expected_audio_data: list[np.ndarray], ) -> None: dataset = AudioDataset.from_folder( tmp_path, @@ -556,3 +556,86 @@ def test_write_files( for data in dataset.data: assert f"{data}.wav" in [f.name for f in output_path.glob("*.wav")] assert np.allclose(data.get_value(), sf.read(output_path / f"{data}.wav")[0]) + + +@pytest.mark.parametrize( + ("audio_files", "nb_subdata", "original_audio_data"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 2, + generate_sample_audio(1, 48_000, dtype=np.float64), + id="even_samples_split_in_two", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 4, + generate_sample_audio(1, 48_000, dtype=np.float64), + id="even_samples_split_in_four", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_001, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 2, + generate_sample_audio(1, 48_000, dtype=np.float64), + id="odd_samples_split_in_two", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_001, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 4, + generate_sample_audio(1, 48_001, dtype=np.float64), + id="odd_samples_split_in_four", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 3, + generate_sample_audio(1, 10, dtype=np.float64), + id="infinite_decimal_points", + ), + ], + indirect=["audio_files"], +) +def test_split_data( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + nb_subdata: int, + original_audio_data: list[np.ndarray], +) -> None: + dataset = AudioDataset.from_folder( + tmp_path, + strptime_format=TIMESTAMP_FORMAT_TEST_FILES, + ) + for data in dataset.data: + subdata_shape = data.shape // nb_subdata + for subdata, data_range in zip( + data.split(nb_subdata), + range(0, data.shape, subdata_shape), + ): + assert np.array_equal( + subdata.get_value(), + data.get_value()[data_range : data_range + subdata_shape], + ) From 67bde9903541c8d295e2006ba3470eaf93f5ba85 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 20 Jan 2025 15:13:20 +0100 Subject: [PATCH 086/118] compute item shape from file frame_indexes method --- src/OSmOSE/data/audio_file.py | 27 +++++++++-- src/OSmOSE/data/audio_item.py | 3 +- tests/test_audio.py | 84 +++++++++++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index 635f07c7..b1fc0ce2 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -70,11 +70,32 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: The last frame of the data is the last frame that starts before stop. """ - sample_rate = self.sample_rate - start_sample = floor((start - self.begin).total_seconds() * sample_rate) - stop_sample = floor((stop - self.begin).total_seconds() * sample_rate) + start_sample, stop_sample = self.frames_indexes(start, stop) return afm.read(self.path, start=start_sample, stop=stop_sample) + def frames_indexes(self, start: Timestamp, stop: Timestamp) -> tuple[int, int]: + """Return the indexes of the frames between the start and stop timestamps. + + The start index is that of the first sample that ends after the start timestamp. + The stop index is that of the last sample that starts before the stop timestamp. + + Parameters + ---------- + start: pandas.Timestamp + Timestamp corresponding to the first data point to read. + stop: pandas.Timestamp + Timestamp after the last data point to read. + + Returns + ------- + tuple[int,int] + First and last frames of the data. + + """ + start_sample = floor((start - self.begin).total_seconds() * self.sample_rate) + stop_sample = floor((stop - self.begin).total_seconds() * self.sample_rate) + return start_sample, stop_sample + @classmethod def from_base_file(cls, file: BaseFile) -> AudioFile: """Return an AudioFile object from a BaseFile object.""" diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/data/audio_item.py index dc243236..66eedeb1 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/data/audio_item.py @@ -51,7 +51,8 @@ def nb_channels(self) -> int: @property def shape(self) -> int: """Number of points in the audio item data.""" - return round(self.sample_rate * self.duration.total_seconds()) + start_sample, end_sample = self.file.frames_indexes(self.begin, self.end) + return end_sample - start_sample def get_value(self) -> np.ndarray: """Get the values from the File between the begin and stop timestamps. diff --git a/tests/test_audio.py b/tests/test_audio.py index 60f28ccd..581099c2 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -205,6 +205,90 @@ def test_audio_file_read( generate_sample_audio(nb_files=1, nb_samples=48_000)[0][24_000:28_800], id="mid_file", ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=320_000, + ), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=600_000, + ), + generate_sample_audio(nb_files=1, nb_samples=10)[0][3:6], + id="start_between_frames", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=300_000, + ), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=620_000, + ), + generate_sample_audio(nb_files=1, nb_samples=10)[0][3:6], + id="stop_between_frames", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=290_000, + ), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=790_000, + ), + generate_sample_audio(nb_files=1, nb_samples=10)[0][2:7], + id="first_frame_included_last_frame_discarded", + ), ], indirect=["audio_files"], ) From 16cd48412a2b6473b17238cfe0a21e9575a26a05 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 20 Jan 2025 17:01:27 +0100 Subject: [PATCH 087/118] return a float AudioDataset.sample_rate if only 1 sr in its data --- src/OSmOSE/data/audio_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index ac22fd3b..346d775a 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -44,9 +44,10 @@ def __init__(self, data: list[AudioData]) -> None: super().__init__(data) @property - def sample_rate(self) -> set[float]: + def sample_rate(self) -> set[float] | float: """Return the sample rate of the audio data.""" - return {data.sample_rate for data in self.data} + sample_rates = {data.sample_rate for data in self.data} + return sample_rates if len(list(sample_rates)) > 1 else next(iter(sample_rates)) @sample_rate.setter def sample_rate(self, sample_rate: float) -> None: From ce6fc22e16a0076be13d7e9ec09bad9da228c42f Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 21 Jan 2025 10:36:11 +0100 Subject: [PATCH 088/118] add read test --- tests/test_audio_file_manager.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/test_audio_file_manager.py diff --git a/tests/test_audio_file_manager.py b/tests/test_audio_file_manager.py new file mode 100644 index 00000000..0ae7fc8a --- /dev/null +++ b/tests/test_audio_file_manager.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from OSmOSE.data.audio_file_manager import AudioFileManager +from OSmOSE.utils.audio_utils import generate_sample_audio +from pathlib import Path +import pytest +import numpy as np + + +@pytest.mark.parametrize( + ("audio_files", "frames", "expected"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (None, None), + generate_sample_audio(nb_files=1, nb_samples=48_000)[0], + id="full_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (None, 10), + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][:10], + id="begin_of_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (10, None), + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][10:], + id="end_of_file", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (10, 10), + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][10:10], + id="mid_of_file", + ), + ], + indirect=["audio_files"], +) +def test_read( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + frames: tuple[int, int], + expected: np.ndarray, +) -> None: + + audio_file_path = tmp_path / "audio_000101000000000000.wav" + afm = AudioFileManager() + params = {"start": frames[0], "stop": frames[1]} + params = {k: v for k, v in params.items() if v is not None} + assert np.array_equal(afm.read(path=audio_file_path, **params), expected) From 3ed8e48f40bf0ed57d6674f2ae1e53accdeb6945 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 21 Jan 2025 11:34:11 +0100 Subject: [PATCH 089/118] add afm read error tests --- src/OSmOSE/data/audio_file_manager.py | 19 ++++++- tests/test_audio_file_manager.py | 73 +++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/data/audio_file_manager.py b/src/OSmOSE/data/audio_file_manager.py index b04da563..047bf859 100644 --- a/src/OSmOSE/data/audio_file_manager.py +++ b/src/OSmOSE/data/audio_file_manager.py @@ -39,7 +39,9 @@ def _switch(self, path: PathLike | str) -> None: self._close() self._open(path) - def read(self, path: PathLike | str, start: int = 0, stop: int = -1) -> np.ndarray: + def read( + self, path: PathLike | str, start: int = 0, stop: int | None = None + ) -> np.ndarray: """Read the content of an audio file. If the audio file is not the current opened file, @@ -61,6 +63,21 @@ def read(self, path: PathLike | str, start: int = 0, stop: int = -1) -> np.ndarr """ self._switch(path) + _, frames, _ = self.info(path) + if stop is None: + stop = frames + + if not 0 <= start < frames: + raise ValueError( + "Start should be between 0 and the last frame of the audio file." + ) + if not 0 <= stop <= frames: + raise ValueError( + "Stop should be between 0 and the last frame of the audio file." + ) + if start > stop: + raise ValueError("Start should be inferior to Stop.") + self.opened_file.seek(start) return self.opened_file.read(stop - start) diff --git a/tests/test_audio_file_manager.py b/tests/test_audio_file_manager.py index 0ae7fc8a..a3b7752c 100644 --- a/tests/test_audio_file_manager.py +++ b/tests/test_audio_file_manager.py @@ -5,6 +5,7 @@ from pathlib import Path import pytest import numpy as np +from soundfile import LibsndfileError @pytest.mark.parametrize( @@ -65,3 +66,75 @@ def test_read( params = {"start": frames[0], "stop": frames[1]} params = {k: v for k, v in params.items() if v is not None} assert np.array_equal(afm.read(path=audio_file_path, **params), expected) + + +@pytest.mark.parametrize( + ("audio_files", "frames", "expected"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (-10, None), + pytest.raises( + ValueError, + match="Start should be between 0 and the last frame of the audio file.", + ), + id="negative_start_raises_error", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (50_000, None), + pytest.raises( + ValueError, + match="Start should be between 0 and the last frame of the audio file.", + ), + id="out_of_bounds_start_raises_error", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (None, -10), + pytest.raises( + ValueError, + match="Stop should be between 0 and the last frame of the audio file.", + ), + id="negative_stop_raises_error", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + (None, 50_000), + pytest.raises( + ValueError, + match="Stop should be between 0 and the last frame of the audio file.", + ), + id="out_of_bounds_stop_raises_error", + ), + ], + indirect=["audio_files"], +) +def test_read_errors( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + frames: tuple[int, int], + expected: np.ndarray, +) -> None: + audio_file_path = tmp_path / "audio_000101000000000000.wav" + afm = AudioFileManager() + params = {"start": frames[0], "stop": frames[1]} + params = {k: v for k, v in params.items() if v is not None} + with expected as e: + assert afm.read(path=audio_file_path, **params) == e From f0a068fffe8493ebe539f199ba93b64161e9672e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 21 Jan 2025 15:15:12 +0100 Subject: [PATCH 090/118] add AudioFileManager tests --- src/OSmOSE/data/audio_file_manager.py | 9 +- tests/conftest.py | 16 ++++ tests/test_audio_file_manager.py | 130 ++++++++++++++++++++++++-- 3 files changed, 142 insertions(+), 13 deletions(-) diff --git a/src/OSmOSE/data/audio_file_manager.py b/src/OSmOSE/data/audio_file_manager.py index 047bf859..5683a9cd 100644 --- a/src/OSmOSE/data/audio_file_manager.py +++ b/src/OSmOSE/data/audio_file_manager.py @@ -40,7 +40,10 @@ def _switch(self, path: PathLike | str) -> None: self._open(path) def read( - self, path: PathLike | str, start: int = 0, stop: int | None = None + self, + path: PathLike | str, + start: int = 0, + stop: int | None = None, ) -> np.ndarray: """Read the content of an audio file. @@ -69,11 +72,11 @@ def read( if not 0 <= start < frames: raise ValueError( - "Start should be between 0 and the last frame of the audio file." + "Start should be between 0 and the last frame of the audio file.", ) if not 0 <= stop <= frames: raise ValueError( - "Stop should be between 0 and the last frame of the audio file." + "Stop should be between 0 and the last frame of the audio file.", ) if start > stop: raise ValueError("Start should be inferior to Stop.") diff --git a/tests/conftest.py b/tests/conftest.py index 3fbb7dda..b70a1a33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ from scipy.signal import chirp from OSmOSE.config import OSMOSE_PATH, TIMESTAMP_FORMAT_TEST_FILES +from OSmOSE.data import AudioFileManager from OSmOSE.utils.audio_utils import generate_sample_audio @@ -100,6 +101,21 @@ def mock_chown(path: Path, uid: int, gid: int) -> None: return mocked_grp_module +@pytest.fixture +def patch_afm_open(monkeypatch) -> list[Path]: + """Mock the AudioFileManager._open method in order to track the file openings.""" + + opened_files = [] + open_func = AudioFileManager._open + + def mock_open(self, path: Path): + opened_files.append(path) + open_func(self, path) + + monkeypatch.setattr(AudioFileManager, "_open", mock_open) + return opened_files + + @pytest.fixture def input_dataset(tmp_path: Path): """Fixture to create an input dataset. diff --git a/tests/test_audio_file_manager.py b/tests/test_audio_file_manager.py index a3b7752c..878c607c 100644 --- a/tests/test_audio_file_manager.py +++ b/tests/test_audio_file_manager.py @@ -1,11 +1,12 @@ from __future__ import annotations -from OSmOSE.data.audio_file_manager import AudioFileManager -from OSmOSE.utils.audio_utils import generate_sample_audio from pathlib import Path -import pytest + import numpy as np -from soundfile import LibsndfileError +import pytest + +from OSmOSE.data.audio_file_manager import AudioFileManager +from OSmOSE.utils.audio_utils import generate_sample_audio @pytest.mark.parametrize( @@ -55,17 +56,16 @@ indirect=["audio_files"], ) def test_read( - tmp_path: Path, audio_files: tuple[list[Path], pytest.fixtures.Subrequest], frames: tuple[int, int], expected: np.ndarray, ) -> None: - audio_file_path = tmp_path / "audio_000101000000000000.wav" + audio_files, _ = audio_files afm = AudioFileManager() params = {"start": frames[0], "stop": frames[1]} params = {k: v for k, v in params.items() if v is not None} - assert np.array_equal(afm.read(path=audio_file_path, **params), expected) + assert np.array_equal(afm.read(path=audio_files[0], **params), expected) @pytest.mark.parametrize( @@ -127,14 +127,124 @@ def test_read( indirect=["audio_files"], ) def test_read_errors( - tmp_path: Path, audio_files: tuple[list[Path], pytest.fixtures.Subrequest], frames: tuple[int, int], expected: np.ndarray, ) -> None: - audio_file_path = tmp_path / "audio_000101000000000000.wav" + audio_files, _ = audio_files afm = AudioFileManager() params = {"start": frames[0], "stop": frames[1]} params = {k: v for k, v in params.items() if v is not None} with expected as e: - assert afm.read(path=audio_file_path, **params) == e + assert afm.read(path=audio_files[0], **params) == e + + +@pytest.mark.parametrize( + ("audio_files", "file_openings", "expected_opened_files"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + }, + [0], + [0], + id="one_single_file_opening", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + }, + [0, 0, 0, 0, 0], + [0], + id="repeated_file_openings", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 5, + }, + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + id="different_file_openings", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 5, + }, + [0, 0, 0, 1, 1, 1, 1, 2, 3, 3, 4, 4, 4, 2, 2, 1, 1], + [0, 1, 2, 3, 4, 2, 1], + id="multiple_repeated_file_openings", + ), + ], + indirect=["audio_files"], +) +def test_switch( + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + file_openings: list[int], + patch_afm_open: list[Path], + expected_opened_files: list[int], +) -> None: + + afm = AudioFileManager() + audio_files, _ = audio_files + for file in file_openings: + afm.read(path=audio_files[file]) + assert [audio_files.index(f) for f in patch_afm_open] == expected_opened_files + assert audio_files.index(Path(afm.opened_file.name)) == file_openings[-1] + + +@pytest.mark.parametrize( + "audio_files", + [ + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + }, + id="one_single_file_opening", + ), + ], + indirect=True, +) +def test_close(audio_files: tuple[list[Path], pytest.fixtures.Subrequest]) -> None: + afm = AudioFileManager() + assert afm.opened_file is None + audio_files, _ = audio_files + afm.read(audio_files[0]) + assert afm.opened_file is not None + assert Path(afm.opened_file.name) == audio_files[0] + afm._close() + assert afm.opened_file is None + + +@pytest.mark.parametrize( + "audio_files", + [ + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 5, + }, + id="multiple_files", + ), + ], + indirect=True, +) +def test_info(audio_files: tuple[list[Path], pytest.fixtures.Subrequest]) -> None: + afm = AudioFileManager() + audio_files, request = audio_files + for file in audio_files: + assert afm.opened_file is None or Path(afm.opened_file.name) != file.name + sample_rate, frames, channels = afm.info(file) + assert request.param["sample_rate"] == sample_rate + assert request.param["duration"] * request.param["sample_rate"] == frames + assert channels == 1 From 0b391cd1c68f4f9b04304549327fe10ba6e11afa Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 21 Jan 2025 17:17:06 +0100 Subject: [PATCH 091/118] linting --- tests/conftest.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b70a1a33..c6fc48fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,10 +38,12 @@ def audio_files( ) files = [] for index, begin_time in enumerate( - pd.date_range( - date_begin, - periods=nb_files, - freq=pd.Timedelta(seconds=duration + inter_file_duration), + list( + pd.date_range( + date_begin, + periods=nb_files, + freq=pd.Timedelta(seconds=duration + inter_file_duration), + ) ), ): time_str = begin_time.strftime(format=TIMESTAMP_FORMAT_TEST_FILES) @@ -102,13 +104,13 @@ def mock_chown(path: Path, uid: int, gid: int) -> None: @pytest.fixture -def patch_afm_open(monkeypatch) -> list[Path]: +def patch_afm_open(monkeypatch: pytest.MonkeyPatch) -> list[Path]: """Mock the AudioFileManager._open method in order to track the file openings.""" opened_files = [] open_func = AudioFileManager._open - def mock_open(self, path: Path): + def mock_open(self: AudioFileManager, path: Path) -> None: opened_files.append(path) open_func(self, path) From ff81e6be3bb2dac5e5fda0a5d5e977f224cf7666 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 21 Jan 2025 18:00:55 +0100 Subject: [PATCH 092/118] fix issue rejecting last frame of audio file --- src/OSmOSE/data/audio_file.py | 2 +- tests/test_audio.py | 75 ++++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index b1fc0ce2..405b5824 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -93,7 +93,7 @@ def frames_indexes(self, start: Timestamp, stop: Timestamp) -> tuple[int, int]: """ start_sample = floor((start - self.begin).total_seconds() * self.sample_rate) - stop_sample = floor((stop - self.begin).total_seconds() * self.sample_rate) + stop_sample = round((stop - self.begin).total_seconds() * self.sample_rate) return start_sample, stop_sample @classmethod diff --git a/tests/test_audio.py b/tests/test_audio.py index 581099c2..eb6c8bd1 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -286,8 +286,36 @@ def test_audio_file_read( second=0, microsecond=790_000, ), + generate_sample_audio(nb_files=1, nb_samples=10)[0][2:8], + id="first_frame_included_last_frame_rounding_up", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 10, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=290_000, + ), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=720_000, + ), generate_sample_audio(nb_files=1, nb_samples=10)[0][2:7], - id="first_frame_included_last_frame_discarded", + id="first_frame_included_last_frame_rounding_down", ), ], indirect=["audio_files"], @@ -408,6 +436,51 @@ def test_audio_data( assert np.array_equal(data.get_value(), expected) +@pytest.mark.parametrize( + "audio_files", + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + id="simple_audio", + ), + pytest.param( + { + "duration": 14.303492063, + "sample_rate": 44_100, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + id="uneven_boundaries_rounding_up", + ), + pytest.param( + { + "duration": 14.303471655328797, + "sample_rate": 44_100, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + id="uneven_boundaries_rounding_down", + ), + ], + indirect=True, +) +def test_read_vs_soundfile( + audio_files: tuple[list[Path], pytest.fixtures.Subrequest] +) -> None: + audio_files, _ = audio_files + af = AudioFile(audio_files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) + ad = AudioData.from_files([af]) + assert np.array_equal(sf.read(audio_files[0])[0], ad.get_value()) + + @pytest.mark.parametrize( ("audio_files", "start", "stop", "sample_rate", "expected_nb_samples"), [ From 5c42879746230f31c5eb1fd173efc55bb3e1c535 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 22 Jan 2025 14:31:50 +0100 Subject: [PATCH 093/118] add nb_bytes SpectroData property --- src/OSmOSE/data/spectro_data.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 5e43c529..0346523e 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -108,6 +108,11 @@ def shape(self) -> tuple[int, ...]: int(self.fft.fs * self.duration.total_seconds()), ) + @property + def nb_bytes(self) -> int: + """Total bytes consumed by the spectro values.""" + return self.shape[0] * self.shape[1] * 8 + def __str__(self) -> str: """Overwrite __str__.""" return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES) From a989b66f419602bf260c085a692666a9c10a0879 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 22 Jan 2025 16:29:18 +0100 Subject: [PATCH 094/118] ltas to specific class --- src/OSmOSE/data/ltas_data.py | 135 ++++++++++++++++++++++++++++++++ src/OSmOSE/data/spectro_data.py | 32 +------- 2 files changed, 137 insertions(+), 30 deletions(-) create mode 100644 src/OSmOSE/data/ltas_data.py diff --git a/src/OSmOSE/data/ltas_data.py b/src/OSmOSE/data/ltas_data.py new file mode 100644 index 00000000..2b5a00bc --- /dev/null +++ b/src/OSmOSE/data/ltas_data.py @@ -0,0 +1,135 @@ +"""LTASData is a special form of SpectroData. + +The Sx values from a LTASData object are computed recursively. +LTAS should be preferred to classic spectrograms in cases where the audio is really long. +In that case, the corresponding number of time bins (scipy.ShortTimeFTT.p_nums) is +too long for the whole Sx matrix to be computed once. + +The LTAS are rather computed recursively. If the number of temporal bins is higher than +a target p_num value, the audio is split in p_num parts. A separate sft is computed +on each of these bits and averaged so that the end Sx presents p_num temporal windows. + +This averaging is performed recursively: if the audio data is such that after a first split, +the p_nums for each part still is higher than p_num, the parts are further split and +each part is replaced with an average of the stft performed within it. + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from scipy.signal import ShortTimeFFT +from tqdm import tqdm + +from OSmOSE.data.spectro_data import SpectroData +from OSmOSE.data.spectro_item import SpectroItem + +if TYPE_CHECKING: + + from pandas import Timestamp + + from OSmOSE.data.audio_data import AudioData + + +class LTASData(SpectroData): + """LTASData is a special form of SpectroData. + + The Sx values from a LTASData object are computed recursively. + LTAS should be preferred to classic spectrograms in cases where the audio is really long. + In that case, the corresponding number of time bins (scipy.ShortTimeFTT.p_nums) is + too long for the whole Sx matrix to be computed once. + + The LTAS are rather computed recursively. If the number of temporal bins is higher than + a target p_num value, the audio is split in p_num parts. A separate sft is computed + on each of these bits and averaged so that the end Sx presents p_num temporal windows. + + This averaging is performed recursively: if the audio data is such that after a first split, + the p_nums for each part still is higher than p_num, the parts are further split and + each part is replaced with an average of the stft performed within it. + + """ + + def __init__( + self, + items: list[SpectroItem] | None = None, + audio_data: AudioData = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + fft: ShortTimeFFT | None = None, + nb_time_bins: int = 1920, + ) -> None: + """Initialize a SpectroData from a list of SpectroItems. + + Parameters + ---------- + items: list[SpectroItem] + List of the SpectroItem constituting the SpectroData. + audio_data: AudioData + The audio data from which to compute the spectrogram. + begin: Timestamp | None + Only effective if items is None. + Set the begin of the empty data. + end: Timestamp | None + Only effective if items is None. + Set the end of the empty data. + fft: ShortTimeFFT + The short time FFT used for computing the spectrogram. + + """ + ltas_fft = LTASData.get_ltas_fft(fft) + super().__init__( + items=items, audio_data=audio_data, begin=begin, end=end, fft=ltas_fft + ) + self.nb_time_bins = nb_time_bins + + def get_value(self, depth: int = 0) -> np.ndarray: + if self.shape[1] <= self.nb_time_bins: + return super().get_value() + sub_spectros = [ + LTASData.from_spectro_data( + SpectroData.from_audio_data(ad, self.fft), + nb_time_bins=self.nb_time_bins, + ) + for ad in self.audio_data.split(self.nb_time_bins) + ] + + if depth == 0: + m = [] + for sub_spectro in tqdm(sub_spectros): + m.append( + np.mean(sub_spectro.get_value(depth + 1), axis=1), + ) + return np.vstack(m).T + + return np.vstack( + [ + np.mean(sub_spectro.get_value(depth + 1), axis=1) + for sub_spectro in sub_spectros + ], + ).T + + @classmethod + def from_spectro_data(cls, spectro_data: SpectroData, nb_time_bins: int): + items = spectro_data.items + audio_data = spectro_data.audio_data + begin = spectro_data.begin + end = spectro_data.end + fft = spectro_data.fft + return cls( + items=items, + audio_data=audio_data, + begin=begin, + end=end, + fft=fft, + nb_time_bins=nb_time_bins, + ) + + @staticmethod + def get_ltas_fft(fft: ShortTimeFFT): + win = fft.win + fs = fft.fs + mfft = fft.mfft + hop = win.shape[0] + return ShortTimeFFT(win=win, hop=hop, fs=fs, mfft=mfft) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 0346523e..28340ac0 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -129,31 +129,7 @@ def get_value(self) -> np.ndarray: return self.fft.spectrogram(self.audio_data.get_value(), padding="even") - def get_ltas_value(self, nb_windows: int = 1920, depth: int = 0) -> np.ndarray: - if self.shape[1] <= nb_windows: - return self.get_value() - sub_spectros = [ - SpectroData.from_audio_data(ad, self.fft) - for ad in self.audio_data.split(nb_windows) - ] - from tqdm import tqdm - - if depth == 0: - m = [] - for sub_spectro in tqdm(sub_spectros): - m.append( - np.mean(sub_spectro.get_ltas_value(nb_windows, depth + 1), axis=1), - ) - return np.vstack(m).T - - return np.vstack( - [ - np.mean(sub_spectro.get_ltas_value(nb_windows, depth + 1), axis=1) - for sub_spectro in sub_spectros - ], - ).T - - def plot(self, ax: plt.Axes | None = None, nb_ltas_windows: int = 0) -> None: + def plot(self, ax: plt.Axes | None = None) -> None: """Plot the spectrogram on a specific Axes. Parameters @@ -164,11 +140,7 @@ def plot(self, ax: plt.Axes | None = None, nb_ltas_windows: int = 0) -> None: """ ax = ax if ax is not None else SpectroData.get_default_ax() - sx = ( - self.get_value() - if nb_ltas_windows == 0 - else self.get_ltas_value(nb_ltas_windows) - ) + sx = self.get_value() sx = 10 * np.log10(abs(sx) + np.nextafter(0, 1)) time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f From d6b4fe53cd7e878b31fdc26cef0843eaa44fd45e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 22 Jan 2025 17:23:02 +0100 Subject: [PATCH 095/118] move tqdm in list comprehension --- src/OSmOSE/data/ltas_data.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/OSmOSE/data/ltas_data.py b/src/OSmOSE/data/ltas_data.py index 2b5a00bc..235ffe9b 100644 --- a/src/OSmOSE/data/ltas_data.py +++ b/src/OSmOSE/data/ltas_data.py @@ -80,7 +80,11 @@ def __init__( """ ltas_fft = LTASData.get_ltas_fft(fft) super().__init__( - items=items, audio_data=audio_data, begin=begin, end=end, fft=ltas_fft + items=items, + audio_data=audio_data, + begin=begin, + end=end, + fft=ltas_fft, ) self.nb_time_bins = nb_time_bins @@ -95,18 +99,10 @@ def get_value(self, depth: int = 0) -> np.ndarray: for ad in self.audio_data.split(self.nb_time_bins) ] - if depth == 0: - m = [] - for sub_spectro in tqdm(sub_spectros): - m.append( - np.mean(sub_spectro.get_value(depth + 1), axis=1), - ) - return np.vstack(m).T - return np.vstack( [ np.mean(sub_spectro.get_value(depth + 1), axis=1) - for sub_spectro in sub_spectros + for sub_spectro in (sub_spectros if depth != 0 else tqdm(sub_spectros)) ], ).T From bda9f3f73d591c686b8fec08ae704b5c415aad6d Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 22 Jan 2025 17:55:04 +0100 Subject: [PATCH 096/118] add docstrings for LTASData --- src/OSmOSE/data/ltas_data.py | 51 ++++++++++++++++++++++++++++++++++-- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/data/ltas_data.py b/src/OSmOSE/data/ltas_data.py index 235ffe9b..b2979734 100644 --- a/src/OSmOSE/data/ltas_data.py +++ b/src/OSmOSE/data/ltas_data.py @@ -76,6 +76,15 @@ def __init__( Set the end of the empty data. fft: ShortTimeFFT The short time FFT used for computing the spectrogram. + nb_time_bins: int + The maximum number of time bins of the LTAS. + Given the audio data and the fft parameters, + if the resulting spectrogram has a number of windows p_num + <= nb_time_bins, the LTAS is computed like a classic spectrogram. + Otherwise, the audio data is split in nb_time_bins equal-duration + audio data, and each bin of the LTAS consist in an average of the + fft values obtained on each of these bins. The audio is split recursively + until p_num <= nb_time_bins. """ ltas_fft = LTASData.get_ltas_fft(fft) @@ -89,6 +98,10 @@ def __init__( self.nb_time_bins = nb_time_bins def get_value(self, depth: int = 0) -> np.ndarray: + """Return the Sx matrix of the LTAS. + + The Sx matrix contains the absolute square of the STFT. + """ if self.shape[1] <= self.nb_time_bins: return super().get_value() sub_spectros = [ @@ -107,7 +120,25 @@ def get_value(self, depth: int = 0) -> np.ndarray: ).T @classmethod - def from_spectro_data(cls, spectro_data: SpectroData, nb_time_bins: int): + def from_spectro_data( + cls, spectro_data: SpectroData, nb_time_bins: int + ) -> LTASData: + """Initialize a LTASData from a SpectroData. + + Parameters + ---------- + spectro_data: SpectroData + The spectrogram to turn in a LTAS. + nb_time_bins: int + The maximum number of windows over which the audio will be split to perform + a LTAS. + + Returns + ------- + LTASData: + The LTASData instance. + + """ items = spectro_data.items audio_data = spectro_data.audio_data begin = spectro_data.begin @@ -123,7 +154,23 @@ def from_spectro_data(cls, spectro_data: SpectroData, nb_time_bins: int): ) @staticmethod - def get_ltas_fft(fft: ShortTimeFFT): + def get_ltas_fft(fft: ShortTimeFFT) -> ShortTimeFFT: + """Return a ShortTimeFFT object optimized for computing LTAS. + + The overlap of the fft is forced set to 0, as the value of consecutive + windows will in the end be averaged. + + Parameters + ---------- + fft: ShortTimeFFT + The fft to optimize for LTAS computation. + + Returns + ------- + ShortTimeFFT + The optimized fft. + + """ win = fft.win fs = fft.fs mfft = fft.mfft From 09c2842bfd7f872ed04c0e103d1977f989ae0d02 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 23 Jan 2025 13:59:59 +0100 Subject: [PATCH 097/118] add sx as an optional parameter to SpectroData plot and save methods --- src/OSmOSE/data/spectro_data.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 28340ac0..3bb5f0a2 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -129,7 +129,7 @@ def get_value(self) -> np.ndarray: return self.fft.spectrogram(self.audio_data.get_value(), padding="even") - def plot(self, ax: plt.Axes | None = None) -> None: + def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None: """Plot the spectrogram on a specific Axes. Parameters @@ -137,10 +137,12 @@ def plot(self, ax: plt.Axes | None = None) -> None: ax: plt.axes | None Axes on which the spectrogram should be plotted. Defaulted as the SpectroData.get_default_ax Axes. + sx: np.ndarray | None + Spectrogram sx values. Will be computed if not provided. """ ax = ax if ax is not None else SpectroData.get_default_ax() - sx = self.get_value() + sx = self.get_value() if sx is None else sx sx = 10 * np.log10(abs(sx) + np.nextafter(0, 1)) time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f @@ -163,17 +165,19 @@ def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: plt.savefig(f"{folder / str(self)}", bbox_inches="tight", pad_inches=0) plt.close() - def write(self, folder: Path) -> None: + def write(self, folder: Path, sx: np.ndarray | None = None) -> None: """Write the Spectro data to file. Parameters ---------- folder: pathlib.Path Folder in which to write the Spectro file. + sx: np.ndarray | None + Spectrogram sx values. Will be computed if not provided. """ super().create_directories(path=folder) - sx = self.get_value() + sx = self.get_value() if sx is None else sx time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f window = self.fft.win From 0b4d77ba51f24289c664280c674e6e4f8f0d2685 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 3 Feb 2025 16:56:15 +0100 Subject: [PATCH 098/118] consider both flac and wav audio files --- src/OSmOSE/data/audio_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index 346d775a..8390932a 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -106,7 +106,8 @@ def from_folder( """ files = [ AudioFile(file, strptime_format=strptime_format) - for file in folder.glob("*.wav") + for file in folder.iterdir() + if file.suffix.lower() in (".wav", ".flac") ] base_dataset = BaseDataset.from_files(files, begin, end, data_duration) return cls.from_base_dataset(base_dataset) From 7fa97b55484388aa9729ffb5dcf87d4db823477c Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Mon, 3 Feb 2025 17:24:23 +0100 Subject: [PATCH 099/118] add flac loading test --- tests/conftest.py | 18 ++++++++++-------- tests/test_audio.py | 25 +++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c6fc48fc..8a65c760 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,6 +29,7 @@ def audio_files( date_begin = request.param.get("date_begin", pd.Timestamp("2000-01-01 00:00:00")) inter_file_duration = request.param.get("inter_file_duration", 0) series_type = request.param.get("series_type", "repeat") + format = request.param.get("format", "wav") nb_samples = int(round(duration * sample_rate)) data = generate_sample_audio( @@ -43,18 +44,19 @@ def audio_files( date_begin, periods=nb_files, freq=pd.Timedelta(seconds=duration + inter_file_duration), - ) + ), ), ): time_str = begin_time.strftime(format=TIMESTAMP_FORMAT_TEST_FILES) - file = tmp_path / f"audio_{time_str}.wav" + file = tmp_path / f"audio_{time_str}.{format}" files.append(file) - sf.write( - file=file, - data=data[index], - samplerate=sample_rate, - subtype="DOUBLE", - ) + kwargs = { + "file": file, + "data": data[index], + "samplerate": sample_rate, + "subtype": "DOUBLE" if format.lower() == "wav" else "PCM_24", + } + sf.write(**kwargs) return files, request diff --git a/tests/test_audio.py b/tests/test_audio.py index eb6c8bd1..9d35fd18 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -148,6 +148,27 @@ def test_audio_file_timestamps( generate_sample_audio(nb_files=1, nb_samples=48_000)[0][43_200:], id="read_end_of_file", ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "format": "flac", + }, + pd.Timestamp("2024-01-01 12:00:00"), + pd.Timestamp( + year=2024, + month=1, + day=1, + hour=12, + minute=0, + second=0, + microsecond=100_000, + ), + generate_sample_audio(nb_files=1, nb_samples=48_000)[0][:4_800], + id="flac_file", + ), ], indirect=["audio_files"], ) @@ -159,7 +180,7 @@ def test_audio_file_read( ) -> None: files, request = audio_files file = AudioFile(files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) - assert np.array_equal(file.read(start, stop), expected) + assert np.allclose(file.read(start, stop), expected, atol=1e-7) @pytest.mark.parametrize( @@ -473,7 +494,7 @@ def test_audio_data( indirect=True, ) def test_read_vs_soundfile( - audio_files: tuple[list[Path], pytest.fixtures.Subrequest] + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], ) -> None: audio_files, _ = audio_files af = AudioFile(audio_files[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) From f56156404c76be5564e37aee84e93b0a17154514 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 4 Feb 2025 11:36:53 +0100 Subject: [PATCH 100/118] reject non parsable files on audiodataset instantiation --- src/OSmOSE/data/audio_dataset.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index 8390932a..39306d73 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -9,6 +9,9 @@ import logging from typing import TYPE_CHECKING +from soundfile import LibsndfileError + +from OSmOSE.config import global_logging_context as glc from OSmOSE.data.audio_data import AudioData from OSmOSE.data.audio_file import AudioFile from OSmOSE.data.base_dataset import BaseDataset @@ -104,12 +107,24 @@ def from_folder( The audio dataset. """ - files = [ - AudioFile(file, strptime_format=strptime_format) - for file in folder.iterdir() - if file.suffix.lower() in (".wav", ".flac") - ] - base_dataset = BaseDataset.from_files(files, begin, end, data_duration) + audio_files = [] + rejected_files = [] + for file in folder.iterdir(): + if file.suffix.lower() not in (".wav", ".flac"): + continue + try: + af = AudioFile(file, strptime_format=strptime_format) + audio_files.append(af) + except (ValueError, LibsndfileError): + rejected_files.append(file) + + if rejected_files: + rejected_files = "\n\t".join(f.name for f in rejected_files) + glc.logger.warn( + f"The following files couldn't be parsed:\n{rejected_files}" + ) + + base_dataset = BaseDataset.from_files(audio_files, begin, end, data_duration) return cls.from_base_dataset(base_dataset) @classmethod From 991ec8067947207c6f3237feafb234652991fa1e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 4 Feb 2025 14:29:51 +0100 Subject: [PATCH 101/118] add tests for audio file check in AudioDataset.from_folder --- src/OSmOSE/data/audio_dataset.py | 5 +- tests/conftest.py | 4 + tests/test_audio.py | 186 +++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/data/audio_dataset.py index 39306d73..1c167d55 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/data/audio_dataset.py @@ -121,9 +121,12 @@ def from_folder( if rejected_files: rejected_files = "\n\t".join(f.name for f in rejected_files) glc.logger.warn( - f"The following files couldn't be parsed:\n{rejected_files}" + f"The following files couldn't be parsed:\n\t{rejected_files}", ) + if not audio_files: + raise FileNotFoundError(f"No valid audio file found in {folder}.") + base_dataset = BaseDataset.from_files(audio_files, begin, end, data_duration) return cls.from_base_dataset(base_dataset) diff --git a/tests/conftest.py b/tests/conftest.py index 8a65c760..8adcfced 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,10 @@ def audio_files( request: pytest.fixtures.Subrequest, ) -> tuple[list[Path], pytest.fixtures.Subrequest]: nb_files = request.param.get("nb_files", 1) + + if nb_files == 0: + return [], request + sample_rate = request.param.get("sample_rate", 48_000) duration = request.param.get("duration", 1.0) date_begin = request.param.get("date_begin", pd.Timestamp("2000-01-01 00:00:00")) diff --git a/tests/test_audio.py b/tests/test_audio.py index 9d35fd18..f0ba76b4 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING import numpy as np @@ -679,6 +680,191 @@ def test_audio_dataset_from_folder( ) +@pytest.mark.parametrize( + ( + "audio_files", + "expected_audio_data", + "corrupted_audio_files", + "non_audio_files", + "error", + ), + [ + pytest.param( + {"nb_files": 0}, + [], + [], + [], + pytest.raises( + FileNotFoundError, + match="No valid audio file found in ", + ), + id="no_file", + ), + pytest.param( + {"nb_files": 0}, + [], + [ + pd.Timestamp("2000-01-01 00:00:00").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".wav", + pd.Timestamp("2000-01-01 00:00:10").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".flac", + ], + [], + pytest.raises( + FileNotFoundError, + match="No valid audio file found in ", + ), + id="corrupted_audio_files", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 3, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + generate_sample_audio( + nb_files=1, + nb_samples=144_000, + series_type="increase", + ), + [ + pd.Timestamp("2000-01-01 00:00:00").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".wav", + pd.Timestamp("2000-01-01 00:00:10").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".flac", + ], + [], + None, + id="mixed_audio_files", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 3, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + generate_sample_audio( + nb_files=1, + nb_samples=144_000, + series_type="increase", + ), + [], + [ + pd.Timestamp("2000-01-01 00:00:00").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".csv" + ], + None, + id="non_audio_files_are_not_logged", + ), + pytest.param( + {"nb_files": 0}, + [], + [ + pd.Timestamp("2000-01-01 00:00:00").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".wav", + pd.Timestamp("2000-01-01 00:00:10").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".flac", + ], + [ + pd.Timestamp("2000-01-01 00:00:00").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".csv" + ], + pytest.raises( + FileNotFoundError, + match="No valid audio file found in ", + ), + id="all_but_ok_audio", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 3, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + "series_type": "increase", + }, + generate_sample_audio( + nb_files=1, + nb_samples=144_000, + series_type="increase", + ), + [ + pd.Timestamp("2000-01-01 00:00:00").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".wav", + pd.Timestamp("2000-01-01 00:00:10").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".flac", + ], + [ + pd.Timestamp("2000-01-01 00:00:00").strftime( + format=TIMESTAMP_FORMAT_TEST_FILES, + ) + + ".csv" + ], + None, + id="full_mix", + ), + ], + indirect=["audio_files"], +) +def test_audio_dataset_from_folder_errors_warnings( + tmp_path: Path, + caplog, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + expected_audio_data: list[np.ndarray], + corrupted_audio_files: list[str], + non_audio_files: list[str], + error, +) -> None: + + for corrupted_file in corrupted_audio_files: + (tmp_path / corrupted_file).open("a").close() # Write empty audio files. + + with caplog.at_level(logging.WARNING): + if error is not None: + with error as e: + assert ( + AudioDataset.from_folder( + tmp_path, + strptime_format=TIMESTAMP_FORMAT_TEST_FILES, + ) + == e + ) + else: + dataset = AudioDataset.from_folder( + tmp_path, + strptime_format=TIMESTAMP_FORMAT_TEST_FILES, + ) + assert all( + np.array_equal(data.get_value(), expected) + for (data, expected) in zip(dataset.data, expected_audio_data) + ) + assert all(f in caplog.text for f in corrupted_audio_files) + + @pytest.mark.parametrize( ("audio_files", "subtype", "expected_audio_data"), [ From 4686cc36021b97457202c098ef37808a65df25cc Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 6 Feb 2025 16:39:45 +0100 Subject: [PATCH 102/118] fix call to BaseData.create_directory in spectrogram export --- src/OSmOSE/data/spectro_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 3bb5f0a2..0584f5d2 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -160,7 +160,7 @@ def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: Defaulted as the SpectroData.get_default_ax Axes. """ - super().write(folder) + super().create_directories(path=folder) self.plot(ax) plt.savefig(f"{folder / str(self)}", bbox_inches="tight", pad_inches=0) plt.close() From dd1075f7df47ed6158b9494a22a0210f418cdce5 Mon Sep 17 00:00:00 2001 From: Gautzilla <72027971+Gautzilla@users.noreply.github.com> Date: Tue, 11 Feb 2025 15:21:22 +0100 Subject: [PATCH 103/118] Npz file concat (#3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR fixes the concatenation of multiple NPZ files to form a larger spectrogram as LTAS were computed in the legacy OSEkit. It introduces the split() method for SpectroData objects, which deviates from the behaviour of AudioData.split() as it forces the chunks to be made on frames on which a window of the SFT is centered. This is required to reconstruct a sft from sft parts, as shown in the "It is possible to calculate the SFT of signal parts:" section. This PR also improves the accuracy of the SpectroFile objects begin/end timestamps by writting the timestamps in the npz file, rather than computing it in a pretty questionable way as it was done before 🥸. --- src/OSmOSE/config.py | 2 +- src/OSmOSE/data/audio_data.py | 37 ++++ src/OSmOSE/data/audio_file.py | 4 +- src/OSmOSE/data/spectro_data.py | 38 +++- src/OSmOSE/data/spectro_dataset.py | 12 ++ src/OSmOSE/data/spectro_file.py | 17 +- src/OSmOSE/utils/audio_utils.py | 2 +- tests/test_audio.py | 85 ++++++++- tests/test_spectro.py | 280 +++++++++++++++++++++++++++++ 9 files changed, 454 insertions(+), 23 deletions(-) create mode 100644 tests/test_spectro.py diff --git a/src/OSmOSE/config.py b/src/OSmOSE/config.py index 759d470c..840af0c2 100755 --- a/src/OSmOSE/config.py +++ b/src/OSmOSE/config.py @@ -33,7 +33,7 @@ TIMESTAMP_FORMAT_AUDIO_FILE = "%Y-%m-%dT%H:%M:%S.%f%z" TIMESTAMP_FORMAT_TEST_FILES = "%y%m%d%H%M%S%f" -TIMESTAMP_FORMAT_EXPORTED_FILES = "%Y_%m_%d_%H_%M_%S" +TIMESTAMP_FORMAT_EXPORTED_FILES = "%Y_%m_%d_%H_%M_%S_%f" FPDEFAULT = 0o664 # Default file permissions DPDEFAULT = stat.S_ISGID | 0o775 # Default directory permissions diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 0be3c7a9..4bd51ba2 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -10,6 +10,7 @@ import numpy as np import soundfile as sf +from pandas import Timedelta from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES from OSmOSE.data.audio_file import AudioFile @@ -156,6 +157,42 @@ def split(self, nb_subdata: int = 2) -> list[AudioData]: for base_data in super().split(nb_subdata) ] + def split_frames(self, start_frame: int = 0, stop_frame: int = -1) -> AudioData: + """Return a new AudioData from a subpart of this AudioData's data. + + Parameters + ---------- + start_frame: int + First frame included in the new AudioData. + stop_frame: int + First frame after the last frame included in the new AudioData. + + Returns + ------- + AudioData + A new AudioData which data is included between start_frame and stop_frame. + + """ + if start_frame < 0: + raise ValueError("Start_frame must be greater than or equal to 0.") + if stop_frame < -1 or stop_frame > self.shape: + raise ValueError("Stop_frame must be lower than the length of the data.") + + start_timestamp = self.begin + Timedelta( + seconds=round(start_frame / self.sample_rate, 9) + ) + stop_timestamp = ( + self.end + if stop_frame == -1 + else self.begin + Timedelta(seconds=stop_frame / self.sample_rate) + ) + return AudioData.from_files( + list(self.files), + start_timestamp, + stop_timestamp, + sample_rate=self.sample_rate, + ) + @classmethod def from_files( cls, diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/data/audio_file.py index 405b5824..7cf7414b 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/data/audio_file.py @@ -92,8 +92,8 @@ def frames_indexes(self, start: Timestamp, stop: Timestamp) -> tuple[int, int]: First and last frames of the data. """ - start_sample = floor((start - self.begin).total_seconds() * self.sample_rate) - stop_sample = round((stop - self.begin).total_seconds() * self.sample_rate) + start_sample = floor(((start - self.begin) * self.sample_rate).total_seconds()) + stop_sample = round(((stop - self.begin) * self.sample_rate).total_seconds()) return start_sample, stop_sample @classmethod diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index 0584f5d2..fd1c5d13 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -127,7 +127,7 @@ def get_value(self) -> np.ndarray: if not self.audio_data or not self.fft: raise ValueError("SpectroData should have either items or audio_data.") - return self.fft.spectrogram(self.audio_data.get_value(), padding="even") + return self.fft.stft(self.audio_data.get_value(), padding="zeros") def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None: """Plot the spectrogram on a specific Axes. @@ -143,7 +143,7 @@ def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None """ ax = ax if ax is not None else SpectroData.get_default_ax() sx = self.get_value() if sx is None else sx - sx = 10 * np.log10(abs(sx) + np.nextafter(0, 1)) + sx = 10 * np.log10(abs(sx) ** 2 + np.nextafter(0, 1)) time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f ax.pcolormesh(time, freq, sx, vmin=-120, vmax=0) @@ -184,6 +184,7 @@ def write(self, folder: Path, sx: np.ndarray | None = None) -> None: hop = [self.fft.hop] fs = [self.fft.fs] mfft = [self.fft.mfft] + timestamps = (str(t) for t in (self.begin, self.end)) np.savez( file=folder / f"{self}.npz", fs=fs, @@ -193,8 +194,37 @@ def write(self, folder: Path, sx: np.ndarray | None = None) -> None: hop=hop, sx=sx, mfft=mfft, + timestamps="_".join(timestamps), ) + def split(self, nb_subdata: int = 2) -> list[SpectroData]: + """Split the spectro data object in the specified number of spectro subdata. + + Parameters + ---------- + nb_subdata: int + Number of subdata in which to split the data. + + Returns + ------- + list[SpectroData] + The list of SpectroData subdata objects. + + """ + split_frames = list( + np.linspace(0, self.audio_data.shape, nb_subdata + 1, dtype=int), + ) + split_frames = [ + self.fft.nearest_k_p(frame) if idx < (len(split_frames) - 1) else frame + for idx, frame in enumerate(split_frames) + ] + + ad_split = [ + self.audio_data.split_frames(start_frame=a, stop_frame=b) + for a, b in zip(split_frames, split_frames[1:]) + ] + return [SpectroData.from_audio_data(ad, self.fft) for ad in ad_split] + def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: if not all( np.array_equal(items[0].file.freq, i.file.freq) @@ -208,11 +238,11 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: output = items[0].get_value(fft=self.fft) for item in items[1:]: - p1_le = self.fft.lower_border_end[1] - self.fft.p_min - 1 + p1_le = self.fft.lower_border_end[1] - self.fft.p_min output = np.hstack( ( output[:, :-p1_le], - (output[:, -p1_le:] + item.get_value(fft=self.fft)[:, :p1_le]) / 2, + (output[:, -p1_le:] + item.get_value(fft=self.fft)[:, :p1_le]), item.get_value(fft=self.fft)[:, p1_le:], ), ) diff --git a/src/OSmOSE/data/spectro_dataset.py b/src/OSmOSE/data/spectro_dataset.py index c3bb28d6..ec67e71a 100644 --- a/src/OSmOSE/data/spectro_dataset.py +++ b/src/OSmOSE/data/spectro_dataset.py @@ -43,6 +43,18 @@ def fft(self, fft: ShortTimeFFT) -> None: for data in self.data: data.fft = fft + def save_spectrogram(self, folder: Path) -> None: + """Export all spectrogram data as png images in the specified folder. + + Parameters + ---------- + folder: Path + Folder in which the spectrograms should be saved. + + """ + for data in self.data: + data.save_spectrogram(folder) + @classmethod def from_folder( cls, diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/data/spectro_file.py index 5c09c018..d3bc0a5b 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/data/spectro_file.py @@ -62,22 +62,15 @@ def _read_metadata(self, path: PathLike) -> None: hop = int(data["hop"][0]) window = data["window"] mfft = data["mfft"][0] + timestamps = str(data["timestamps"]) self.sample_rate = sample_rate self.mfft = mfft - delta_times = [ - Timedelta(seconds=time[i] - time[i - 1]).round(freq="ns") - for i in range(1, time.shape[0]) - ] - most_frequent_delta_time = max( - ((v, delta_times.count(v)) for v in set(delta_times)), - key=lambda i: i[1], - )[0] - self.time_resolution = most_frequent_delta_time - self.end = ( - self.begin + Timedelta(seconds=time[-1]) + self.time_resolution - ).round(freq="us") + self.begin, self.end = (Timestamp(t) for t in timestamps.split("_")) + + self.time = time + self.time_resolution = (self.end - self.begin) / len(self.time) self.freq = freq diff --git a/src/OSmOSE/utils/audio_utils.py b/src/OSmOSE/utils/audio_utils.py index 41ec70c9..fa88791e 100644 --- a/src/OSmOSE/utils/audio_utils.py +++ b/src/OSmOSE/utils/audio_utils.py @@ -203,4 +203,4 @@ def resample(data: np.ndarray, origin_sr: float, target_sr: float) -> np.ndarray The resampled audio data. """ - return soxr.resample(data, origin_sr, target_sr) + return soxr.resample(data, origin_sr, target_sr, quality="QQ") diff --git a/tests/test_audio.py b/tests/test_audio.py index f0ba76b4..64192e9d 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -765,7 +765,7 @@ def test_audio_dataset_from_folder( pd.Timestamp("2000-01-01 00:00:00").strftime( format=TIMESTAMP_FORMAT_TEST_FILES, ) - + ".csv" + + ".csv", ], None, id="non_audio_files_are_not_logged", @@ -787,7 +787,7 @@ def test_audio_dataset_from_folder( pd.Timestamp("2000-01-01 00:00:00").strftime( format=TIMESTAMP_FORMAT_TEST_FILES, ) - + ".csv" + + ".csv", ], pytest.raises( FileNotFoundError, @@ -822,7 +822,7 @@ def test_audio_dataset_from_folder( pd.Timestamp("2000-01-01 00:00:00").strftime( format=TIMESTAMP_FORMAT_TEST_FILES, ) - + ".csv" + + ".csv", ], None, id="full_mix", @@ -1003,3 +1003,82 @@ def test_split_data( subdata.get_value(), data.get_value()[data_range : data_range + subdata_shape], ) + + +@pytest.mark.parametrize( + ("audio_files", "start_frame", "stop_frame", "expected_begin", "expected_data"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 0, + -1, + pd.Timestamp("2024-01-01 12:00:00"), + generate_sample_audio(1, 48_000, dtype=np.float64)[0], + id="whole_data", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 0, + 1, + pd.Timestamp("2024-01-01 12:00:00"), + generate_sample_audio(1, 48_000, dtype=np.float64)[0][:1], + id="first_frame", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 47_999, + -1, + pd.Timestamp("2024-01-01 12:00:00") + + pd.Timedelta(seconds=round(47_999 / 48_000, 9)), + generate_sample_audio(1, 48_000, dtype=np.float64)[0][-1:], + id="last_frame", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 4_800 * 3, + 4_800 * 4, + pd.Timestamp("2024-01-01 12:00:00.3"), + generate_sample_audio(1, 48_000, dtype=np.float64)[0][ + 4_800 * 3 : 4_800 * 4 + ], + id="subpart", + ), + ], + indirect=["audio_files"], +) +def test_split_data_frames( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + start_frame: int, + stop_frame: int, + expected_begin: pd.Timestamp, + expected_data: np.ndarray, +) -> None: + dataset = AudioDataset.from_folder( + tmp_path, + strptime_format=TIMESTAMP_FORMAT_TEST_FILES, + ) + ad = dataset.data[0].split_frames(start_frame, stop_frame) + + assert ad.begin == expected_begin + assert np.array_equal(ad.get_value(), expected_data) diff --git a/tests/test_spectro.py b/tests/test_spectro.py new file mode 100644 index 00000000..ce5dd5c7 --- /dev/null +++ b/tests/test_spectro.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import pytest +from scipy.signal import ShortTimeFFT +from scipy.signal.windows import hamming + +from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES, TIMESTAMP_FORMAT_TEST_FILES +from OSmOSE.data.audio_data import AudioData +from OSmOSE.data.audio_dataset import AudioDataset +from OSmOSE.data.audio_file import AudioFile +from OSmOSE.data.spectro_data import SpectroData +from OSmOSE.data.spectro_dataset import SpectroDataset +from OSmOSE.data.spectro_file import SpectroFile +from OSmOSE.utils.audio_utils import generate_sample_audio + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.mark.parametrize( + ("audio_files", "original_audio_data", "sft"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + generate_sample_audio(1, 48_000, dtype=np.float64), + ShortTimeFFT(hamming(1_024), 1024, 48_000), + id="short_spectrogram", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + generate_sample_audio(1, 1_024, dtype=np.float64), + ShortTimeFFT(hamming(1_024), 1024, 1_024), + id="data_is_one_window_long", + ), + ], + indirect=["audio_files"], +) +def test_spectrogram_shape( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + original_audio_data: list[np.ndarray], + sft: ShortTimeFFT, +) -> None: + dataset = AudioDataset.from_folder( + tmp_path, + strptime_format=TIMESTAMP_FORMAT_TEST_FILES, + ) + spectro_dataset = SpectroDataset.from_audio_dataset(dataset, sft) + for audio, spectro in zip(dataset.data, spectro_dataset.data): + assert spectro.shape == spectro.get_value().shape + assert spectro.shape == (sft.f.shape[0], sft.p_num(audio.shape)) + + +@pytest.mark.parametrize( + ("audio_files", "date_begin", "sft"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + pd.Timestamp("2024-01-01 12:00:00"), + ShortTimeFFT(hamming(1_024), 1024, 48_000), + id="second_precision", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 1_024, + "nb_files": 1, + }, + pd.Timestamp("2024-01-01 12:00:00.123"), + ShortTimeFFT(hamming(512), 512, 1_024), + id="millisecond_precision", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + pd.Timestamp("2024-01-01 12:00:00.123456"), + ShortTimeFFT(hamming(1_024), 1_024, 48_000), + id="microsecond_precision", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 48_000, + "nb_files": 1, + }, + pd.Timestamp("2024-01-01 12:00:00.123456789"), + ShortTimeFFT(hamming(1_024), 1_024, 48_000), + id="nanosecond_precision", + ), + pytest.param( + { + "duration": 1.123456789, + "sample_rate": 48_000, + "nb_files": 1, + }, + pd.Timestamp("2024-01-01 12:00:00.123456789"), + ShortTimeFFT(hamming(1_024), 1_024, 48_000), + id="nanosecond_precision_end", + ), + ], + indirect=["audio_files"], +) +def test_spectro_parameters_in_npz_files( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + date_begin: pd.Timestamp, + sft: ShortTimeFFT, +) -> None: + + af = next(AudioFile(f, begin=date_begin) for f in tmp_path.glob("*.wav")) + + ad = AudioData.from_files([af]) + sd = SpectroData.from_audio_data(ad, sft) + sd.write(tmp_path / "npz") + file = tmp_path / "npz" / f"{sd}.npz" + sf = SpectroFile(file, strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES) + + assert sf.begin == ad.begin + assert sf.end == ad.end + assert np.array_equal(sf.freq, sft.f) + assert sf.hop == sft.hop + assert sf.mfft == sft.mfft + assert sf.sample_rate == sft.fs + nb_time_bins = sft.t(ad.shape).shape[0] + assert np.array_equal( + sf.time, np.arange(nb_time_bins) * ad.duration.total_seconds() / nb_time_bins + ) + + +@pytest.mark.parametrize( + ("audio_files", "nb_chunks", "sft"), + [ + pytest.param( + { + "duration": 6, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 3, + ShortTimeFFT(hamming(1_024), 1_024, 1_024), + id="6_seconds_split_in_3", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 1, + ShortTimeFFT(hamming(1_024), 100, 1_024), + id="1_npz_file", + ), + pytest.param( + { + "duration": 6, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 3, + ShortTimeFFT(hamming(1_024), 100, 1_024), + id="6_seconds_split_in_3_with_overlap", + ), + pytest.param( + { + "duration": 8, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 4, + ShortTimeFFT(hamming(1_024), 100, 1_024), + id="8_seconds_split_in_4", + ), + pytest.param( + { + "duration": 4, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 4, + ShortTimeFFT(hamming(12_000), 12_000, 48_000), + id="high_sr_no_overlap", + ), + pytest.param( + { + "duration": 2, + "sample_rate": 48_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 3, + ShortTimeFFT(hamming(12_000), 10_000, 48_000), + id="high_sr_overlap", + ), + pytest.param( + { + "duration": 6, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 6, + ShortTimeFFT(hamming(1_024), 1_024, 1_024), + id="6_seconds_split_in_6", + ), + pytest.param( + { + "duration": 6, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 6, + ShortTimeFFT(hamming(1_024), 100, 1_024), + id="6_seconds_split_in_6_with_overlap", + ), + ], + indirect=["audio_files"], +) +def test_spectrogram_from_npz_files( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + nb_chunks: int, + sft: ShortTimeFFT, +) -> None: + afs = [ + AudioFile(f, strptime_format=TIMESTAMP_FORMAT_TEST_FILES) + for f in tmp_path.glob("*.wav") + ] + + ad = AudioData.from_files(afs) + sd = SpectroData.from_audio_data(ad, sft) + + sd_split = sd.split(nb_chunks) + + for spectro in sd_split: + spectro.write(tmp_path / "output") + assert len(list((tmp_path / "output").glob("*.npz"))) == nb_chunks + + sds = SpectroDataset.from_folder( + tmp_path / "output", + TIMESTAMP_FORMAT_EXPORTED_FILES, + ) + + assert sds.begin == ad.begin + assert sds.duration == ad.duration + assert len(sds.data) == 1 + assert sds.data[0].shape == sds.data[0].get_value().shape + + assert sds.data[0].shape == ( + sft.f.shape[0], + sft.p_num(int(ad.duration.total_seconds() * ad.sample_rate)), + ) + + assert np.allclose(sd.get_value(), sds.data[0].get_value()) From b1c1d3c7fe83144e3671e027c62e87a5e7f0d313 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 11 Feb 2025 15:53:58 +0100 Subject: [PATCH 104/118] fix AudioData.split_frames lower frame computation --- src/OSmOSE/data/audio_data.py | 3 ++- tests/test_audio.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index 4bd51ba2..b705a4b5 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -6,6 +6,7 @@ from __future__ import annotations +from math import ceil from typing import TYPE_CHECKING import numpy as np @@ -179,7 +180,7 @@ def split_frames(self, start_frame: int = 0, stop_frame: int = -1) -> AudioData: raise ValueError("Stop_frame must be lower than the length of the data.") start_timestamp = self.begin + Timedelta( - seconds=round(start_frame / self.sample_rate, 9) + seconds=ceil(start_frame / self.sample_rate * 1e9) / 1e9, ) stop_timestamp = ( self.end diff --git a/tests/test_audio.py b/tests/test_audio.py index 64192e9d..81db0fa8 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -1063,6 +1063,19 @@ def test_split_data( ], id="subpart", ), + pytest.param( + { + "duration": 1, + "sample_rate": 144_000, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + 30, + 60, + pd.Timestamp("2024-01-01 12:00:00.000208334"), + generate_sample_audio(1, 144_000, dtype=np.float64)[0][30:60], + id="higher_fs", + ), ], indirect=["audio_files"], ) From b643217c40d7ff496c42f2482d453f1deb0a4341 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 11 Feb 2025 18:00:01 +0100 Subject: [PATCH 105/118] reject DC components before computing sx values --- src/OSmOSE/data/audio_data.py | 15 ++++++++++++++- src/OSmOSE/data/spectro_data.py | 2 +- tests/test_spectro.py | 34 +++++++++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/data/audio_data.py index b705a4b5..0e58fdcf 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/data/audio_data.py @@ -94,10 +94,21 @@ def _set_sample_rate(self, sample_rate: int | None = None) -> None: return self.sample_rate = None - def get_value(self) -> np.ndarray: + def get_value(self, reject_dc: bool = False) -> np.ndarray: """Return the value of the audio data. The data from the audio file will be resampled if necessary. + + Parameters + ---------- + reject_dc: bool + If True, the values will be centered on 0. + + Returns + ------- + np.ndarray: + The value of the audio data. + """ data = np.empty(shape=self.shape) idx = 0 @@ -106,6 +117,8 @@ def get_value(self) -> np.ndarray: item_data = item_data[: min(item_data.shape[0], data.shape[0] - idx)] data[idx : idx + len(item_data)] = item_data idx += len(item_data) + if reject_dc: + data -= data.mean() return data def write(self, folder: Path, subtype: str | None = None) -> None: diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/data/spectro_data.py index fd1c5d13..b520e62f 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/data/spectro_data.py @@ -127,7 +127,7 @@ def get_value(self) -> np.ndarray: if not self.audio_data or not self.fft: raise ValueError("SpectroData should have either items or audio_data.") - return self.fft.stft(self.audio_data.get_value(), padding="zeros") + return self.fft.stft(self.audio_data.get_value(reject_dc=True), padding="zeros") def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None: """Plot the spectrogram on a specific Axes. diff --git a/tests/test_spectro.py b/tests/test_spectro.py index ce5dd5c7..526d4d6e 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +from pandas import Timedelta from scipy.signal import ShortTimeFFT from scipy.signal.windows import hamming @@ -144,7 +145,8 @@ def test_spectro_parameters_in_npz_files( assert sf.sample_rate == sft.fs nb_time_bins = sft.t(ad.shape).shape[0] assert np.array_equal( - sf.time, np.arange(nb_time_bins) * ad.duration.total_seconds() / nb_time_bins + sf.time, + np.arange(nb_time_bins) * ad.duration.total_seconds() / nb_time_bins, ) @@ -258,17 +260,45 @@ def test_spectrogram_from_npz_files( sd_split = sd.split(nb_chunks) + import soundfile as sf + for spectro in sd_split: spectro.write(tmp_path / "output") + centered_data = spectro.audio_data.get_value(reject_dc=True) + (tmp_path / "audio").mkdir(exist_ok=True) + sf.write( + file=tmp_path / "audio" / f"{spectro.audio_data}.wav", + data=centered_data, + samplerate=spectro.audio_data.sample_rate, + subtype="DOUBLE", + ) + assert len(list((tmp_path / "output").glob("*.npz"))) == nb_chunks + # Since we reject the DC of audio data before computing Sx values of each chunk, + # we must compare the concatenated chunks with an AudioData made from the + # DC-free parts. + + afs = [ + AudioFile(f, strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES) + for f in (tmp_path / "audio").glob("*.wav") + ] + ad = AudioData.from_files(afs) + sd = SpectroData.from_audio_data(ad, sft) + sds = SpectroDataset.from_folder( tmp_path / "output", TIMESTAMP_FORMAT_EXPORTED_FILES, ) assert sds.begin == ad.begin - assert sds.duration == ad.duration + + # Beats me, but Timedelta.round() raises a DivideByZeroException if done + # directly on the duration properties. + + dt1, dt2 = (Timedelta(str(dt)) for dt in (sds.duration, ad.duration)) + + assert dt1.round(freq="ms") == dt2.round(freq="ms") assert len(sds.data) == 1 assert sds.data[0].shape == sds.data[0].get_value().shape From a6fc1351cdd43eef1bd900dad182ec6a26261263 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 12 Feb 2025 09:41:33 +0100 Subject: [PATCH 106/118] rename data module to core_api module --- src/OSmOSE/core_api/__init__.py | 3 + src/OSmOSE/{data => core_api}/audio_data.py | 6 +- .../{data => core_api}/audio_dataset.py | 6 +- src/OSmOSE/{data => core_api}/audio_file.py | 4 +- .../{data => core_api}/audio_file_manager.py | 0 src/OSmOSE/{data => core_api}/audio_item.py | 6 +- src/OSmOSE/{data => core_api}/base_data.py | 6 +- src/OSmOSE/{data => core_api}/base_dataset.py | 6 +- src/OSmOSE/{data => core_api}/base_file.py | 2 +- src/OSmOSE/{data => core_api}/base_item.py | 4 +- src/OSmOSE/{data => core_api}/event.py | 0 src/OSmOSE/{data => core_api}/ltas_data.py | 356 +++++++++--------- src/OSmOSE/{data => core_api}/spectro_data.py | 8 +- .../{data => core_api}/spectro_dataset.py | 8 +- src/OSmOSE/{data => core_api}/spectro_file.py | 2 +- src/OSmOSE/{data => core_api}/spectro_item.py | 6 +- src/OSmOSE/data/__init__.py | 3 - tests/conftest.py | 2 +- tests/test_audio.py | 8 +- tests/test_audio_file_manager.py | 2 +- tests/test_event.py | 2 +- tests/test_spectro.py | 12 +- 22 files changed, 226 insertions(+), 226 deletions(-) create mode 100644 src/OSmOSE/core_api/__init__.py rename src/OSmOSE/{data => core_api}/audio_data.py (98%) rename src/OSmOSE/{data => core_api}/audio_dataset.py (96%) rename src/OSmOSE/{data => core_api}/audio_file.py (97%) rename src/OSmOSE/{data => core_api}/audio_file_manager.py (100%) rename src/OSmOSE/{data => core_api}/audio_item.py (94%) rename src/OSmOSE/{data => core_api}/base_data.py (97%) rename src/OSmOSE/{data => core_api}/base_dataset.py (95%) rename src/OSmOSE/{data => core_api}/base_file.py (98%) rename src/OSmOSE/{data => core_api}/base_item.py (96%) rename src/OSmOSE/{data => core_api}/event.py (100%) rename src/OSmOSE/{data => core_api}/ltas_data.py (93%) rename src/OSmOSE/{data => core_api}/spectro_data.py (97%) rename src/OSmOSE/{data => core_api}/spectro_dataset.py (94%) rename src/OSmOSE/{data => core_api}/spectro_file.py (98%) rename src/OSmOSE/{data => core_api}/spectro_item.py (94%) delete mode 100644 src/OSmOSE/data/__init__.py diff --git a/src/OSmOSE/core_api/__init__.py b/src/OSmOSE/core_api/__init__.py new file mode 100644 index 00000000..e73cf86a --- /dev/null +++ b/src/OSmOSE/core_api/__init__.py @@ -0,0 +1,3 @@ +from OSmOSE.core_api.audio_file_manager import AudioFileManager + +audio_file_manager = AudioFileManager() diff --git a/src/OSmOSE/data/audio_data.py b/src/OSmOSE/core_api/audio_data.py similarity index 98% rename from src/OSmOSE/data/audio_data.py rename to src/OSmOSE/core_api/audio_data.py index 0e58fdcf..b4b4db6e 100644 --- a/src/OSmOSE/data/audio_data.py +++ b/src/OSmOSE/core_api/audio_data.py @@ -14,9 +14,9 @@ from pandas import Timedelta from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES -from OSmOSE.data.audio_file import AudioFile -from OSmOSE.data.audio_item import AudioItem -from OSmOSE.data.base_data import BaseData +from OSmOSE.core_api.audio_file import AudioFile +from OSmOSE.core_api.audio_item import AudioItem +from OSmOSE.core_api.base_data import BaseData from OSmOSE.utils.audio_utils import resample if TYPE_CHECKING: diff --git a/src/OSmOSE/data/audio_dataset.py b/src/OSmOSE/core_api/audio_dataset.py similarity index 96% rename from src/OSmOSE/data/audio_dataset.py rename to src/OSmOSE/core_api/audio_dataset.py index 1c167d55..1fa9ba81 100644 --- a/src/OSmOSE/data/audio_dataset.py +++ b/src/OSmOSE/core_api/audio_dataset.py @@ -12,9 +12,9 @@ from soundfile import LibsndfileError from OSmOSE.config import global_logging_context as glc -from OSmOSE.data.audio_data import AudioData -from OSmOSE.data.audio_file import AudioFile -from OSmOSE.data.base_dataset import BaseDataset +from OSmOSE.core_api.audio_data import AudioData +from OSmOSE.core_api.audio_file import AudioFile +from OSmOSE.core_api.base_dataset import BaseDataset if TYPE_CHECKING: from pathlib import Path diff --git a/src/OSmOSE/data/audio_file.py b/src/OSmOSE/core_api/audio_file.py similarity index 97% rename from src/OSmOSE/data/audio_file.py rename to src/OSmOSE/core_api/audio_file.py index 7cf7414b..6456145c 100644 --- a/src/OSmOSE/data/audio_file.py +++ b/src/OSmOSE/core_api/audio_file.py @@ -12,8 +12,8 @@ from pandas import Timedelta, Timestamp -from OSmOSE.data import audio_file_manager as afm -from OSmOSE.data.base_file import BaseFile +from OSmOSE.core_api import audio_file_manager as afm +from OSmOSE.core_api.base_file import BaseFile class AudioFile(BaseFile): diff --git a/src/OSmOSE/data/audio_file_manager.py b/src/OSmOSE/core_api/audio_file_manager.py similarity index 100% rename from src/OSmOSE/data/audio_file_manager.py rename to src/OSmOSE/core_api/audio_file_manager.py diff --git a/src/OSmOSE/data/audio_item.py b/src/OSmOSE/core_api/audio_item.py similarity index 94% rename from src/OSmOSE/data/audio_item.py rename to src/OSmOSE/core_api/audio_item.py index 66eedeb1..f8ce60f3 100644 --- a/src/OSmOSE/data/audio_item.py +++ b/src/OSmOSE/core_api/audio_item.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING -from OSmOSE.data.audio_file import AudioFile -from OSmOSE.data.base_file import BaseFile -from OSmOSE.data.base_item import BaseItem +from OSmOSE.core_api.audio_file import AudioFile +from OSmOSE.core_api.base_file import BaseFile +from OSmOSE.core_api.base_item import BaseItem if TYPE_CHECKING: import numpy as np diff --git a/src/OSmOSE/data/base_data.py b/src/OSmOSE/core_api/base_data.py similarity index 97% rename from src/OSmOSE/data/base_data.py rename to src/OSmOSE/core_api/base_data.py index f06fe304..62688351 100644 --- a/src/OSmOSE/data/base_data.py +++ b/src/OSmOSE/core_api/base_data.py @@ -12,9 +12,9 @@ from pandas import date_range from OSmOSE.config import DPDEFAULT -from OSmOSE.data.base_file import BaseFile -from OSmOSE.data.base_item import BaseItem -from OSmOSE.data.event import Event +from OSmOSE.core_api.base_file import BaseFile +from OSmOSE.core_api.base_item import BaseItem +from OSmOSE.core_api.event import Event if TYPE_CHECKING: from pathlib import Path diff --git a/src/OSmOSE/data/base_dataset.py b/src/OSmOSE/core_api/base_dataset.py similarity index 95% rename from src/OSmOSE/data/base_dataset.py rename to src/OSmOSE/core_api/base_dataset.py index ade08cc9..0148315a 100644 --- a/src/OSmOSE/data/base_dataset.py +++ b/src/OSmOSE/core_api/base_dataset.py @@ -10,9 +10,9 @@ from pandas import Timedelta, Timestamp, date_range -from OSmOSE.data.base_data import BaseData -from OSmOSE.data.base_file import BaseFile -from OSmOSE.data.event import Event +from OSmOSE.core_api.base_data import BaseData +from OSmOSE.core_api.base_file import BaseFile +from OSmOSE.core_api.event import Event if TYPE_CHECKING: from pathlib import Path diff --git a/src/OSmOSE/data/base_file.py b/src/OSmOSE/core_api/base_file.py similarity index 98% rename from src/OSmOSE/data/base_file.py rename to src/OSmOSE/core_api/base_file.py index 41f5323c..8d182b2a 100644 --- a/src/OSmOSE/data/base_file.py +++ b/src/OSmOSE/core_api/base_file.py @@ -15,7 +15,7 @@ from pathlib import Path -from OSmOSE.data.event import Event +from OSmOSE.core_api.event import Event from OSmOSE.utils.timestamp_utils import strptime_from_text diff --git a/src/OSmOSE/data/base_item.py b/src/OSmOSE/core_api/base_item.py similarity index 96% rename from src/OSmOSE/data/base_item.py rename to src/OSmOSE/core_api/base_item.py index c329fdc9..655d3310 100644 --- a/src/OSmOSE/data/base_item.py +++ b/src/OSmOSE/core_api/base_item.py @@ -9,8 +9,8 @@ import numpy as np -from OSmOSE.data.base_file import BaseFile -from OSmOSE.data.event import Event +from OSmOSE.core_api.base_file import BaseFile +from OSmOSE.core_api.event import Event if TYPE_CHECKING: from pandas import Timestamp diff --git a/src/OSmOSE/data/event.py b/src/OSmOSE/core_api/event.py similarity index 100% rename from src/OSmOSE/data/event.py rename to src/OSmOSE/core_api/event.py diff --git a/src/OSmOSE/data/ltas_data.py b/src/OSmOSE/core_api/ltas_data.py similarity index 93% rename from src/OSmOSE/data/ltas_data.py rename to src/OSmOSE/core_api/ltas_data.py index b2979734..51863af6 100644 --- a/src/OSmOSE/data/ltas_data.py +++ b/src/OSmOSE/core_api/ltas_data.py @@ -1,178 +1,178 @@ -"""LTASData is a special form of SpectroData. - -The Sx values from a LTASData object are computed recursively. -LTAS should be preferred to classic spectrograms in cases where the audio is really long. -In that case, the corresponding number of time bins (scipy.ShortTimeFTT.p_nums) is -too long for the whole Sx matrix to be computed once. - -The LTAS are rather computed recursively. If the number of temporal bins is higher than -a target p_num value, the audio is split in p_num parts. A separate sft is computed -on each of these bits and averaged so that the end Sx presents p_num temporal windows. - -This averaging is performed recursively: if the audio data is such that after a first split, -the p_nums for each part still is higher than p_num, the parts are further split and -each part is replaced with an average of the stft performed within it. - -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np -from scipy.signal import ShortTimeFFT -from tqdm import tqdm - -from OSmOSE.data.spectro_data import SpectroData -from OSmOSE.data.spectro_item import SpectroItem - -if TYPE_CHECKING: - - from pandas import Timestamp - - from OSmOSE.data.audio_data import AudioData - - -class LTASData(SpectroData): - """LTASData is a special form of SpectroData. - - The Sx values from a LTASData object are computed recursively. - LTAS should be preferred to classic spectrograms in cases where the audio is really long. - In that case, the corresponding number of time bins (scipy.ShortTimeFTT.p_nums) is - too long for the whole Sx matrix to be computed once. - - The LTAS are rather computed recursively. If the number of temporal bins is higher than - a target p_num value, the audio is split in p_num parts. A separate sft is computed - on each of these bits and averaged so that the end Sx presents p_num temporal windows. - - This averaging is performed recursively: if the audio data is such that after a first split, - the p_nums for each part still is higher than p_num, the parts are further split and - each part is replaced with an average of the stft performed within it. - - """ - - def __init__( - self, - items: list[SpectroItem] | None = None, - audio_data: AudioData = None, - begin: Timestamp | None = None, - end: Timestamp | None = None, - fft: ShortTimeFFT | None = None, - nb_time_bins: int = 1920, - ) -> None: - """Initialize a SpectroData from a list of SpectroItems. - - Parameters - ---------- - items: list[SpectroItem] - List of the SpectroItem constituting the SpectroData. - audio_data: AudioData - The audio data from which to compute the spectrogram. - begin: Timestamp | None - Only effective if items is None. - Set the begin of the empty data. - end: Timestamp | None - Only effective if items is None. - Set the end of the empty data. - fft: ShortTimeFFT - The short time FFT used for computing the spectrogram. - nb_time_bins: int - The maximum number of time bins of the LTAS. - Given the audio data and the fft parameters, - if the resulting spectrogram has a number of windows p_num - <= nb_time_bins, the LTAS is computed like a classic spectrogram. - Otherwise, the audio data is split in nb_time_bins equal-duration - audio data, and each bin of the LTAS consist in an average of the - fft values obtained on each of these bins. The audio is split recursively - until p_num <= nb_time_bins. - - """ - ltas_fft = LTASData.get_ltas_fft(fft) - super().__init__( - items=items, - audio_data=audio_data, - begin=begin, - end=end, - fft=ltas_fft, - ) - self.nb_time_bins = nb_time_bins - - def get_value(self, depth: int = 0) -> np.ndarray: - """Return the Sx matrix of the LTAS. - - The Sx matrix contains the absolute square of the STFT. - """ - if self.shape[1] <= self.nb_time_bins: - return super().get_value() - sub_spectros = [ - LTASData.from_spectro_data( - SpectroData.from_audio_data(ad, self.fft), - nb_time_bins=self.nb_time_bins, - ) - for ad in self.audio_data.split(self.nb_time_bins) - ] - - return np.vstack( - [ - np.mean(sub_spectro.get_value(depth + 1), axis=1) - for sub_spectro in (sub_spectros if depth != 0 else tqdm(sub_spectros)) - ], - ).T - - @classmethod - def from_spectro_data( - cls, spectro_data: SpectroData, nb_time_bins: int - ) -> LTASData: - """Initialize a LTASData from a SpectroData. - - Parameters - ---------- - spectro_data: SpectroData - The spectrogram to turn in a LTAS. - nb_time_bins: int - The maximum number of windows over which the audio will be split to perform - a LTAS. - - Returns - ------- - LTASData: - The LTASData instance. - - """ - items = spectro_data.items - audio_data = spectro_data.audio_data - begin = spectro_data.begin - end = spectro_data.end - fft = spectro_data.fft - return cls( - items=items, - audio_data=audio_data, - begin=begin, - end=end, - fft=fft, - nb_time_bins=nb_time_bins, - ) - - @staticmethod - def get_ltas_fft(fft: ShortTimeFFT) -> ShortTimeFFT: - """Return a ShortTimeFFT object optimized for computing LTAS. - - The overlap of the fft is forced set to 0, as the value of consecutive - windows will in the end be averaged. - - Parameters - ---------- - fft: ShortTimeFFT - The fft to optimize for LTAS computation. - - Returns - ------- - ShortTimeFFT - The optimized fft. - - """ - win = fft.win - fs = fft.fs - mfft = fft.mfft - hop = win.shape[0] - return ShortTimeFFT(win=win, hop=hop, fs=fs, mfft=mfft) +"""LTASData is a special form of SpectroData. + +The Sx values from a LTASData object are computed recursively. +LTAS should be preferred to classic spectrograms in cases where the audio is really long. +In that case, the corresponding number of time bins (scipy.ShortTimeFTT.p_nums) is +too long for the whole Sx matrix to be computed once. + +The LTAS are rather computed recursively. If the number of temporal bins is higher than +a target p_num value, the audio is split in p_num parts. A separate sft is computed +on each of these bits and averaged so that the end Sx presents p_num temporal windows. + +This averaging is performed recursively: if the audio data is such that after a first split, +the p_nums for each part still is higher than p_num, the parts are further split and +each part is replaced with an average of the stft performed within it. + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from scipy.signal import ShortTimeFFT +from tqdm import tqdm + +from OSmOSE.core_api.spectro_data import SpectroData +from OSmOSE.core_api.spectro_item import SpectroItem + +if TYPE_CHECKING: + + from pandas import Timestamp + + from OSmOSE.core_api.audio_data import AudioData + + +class LTASData(SpectroData): + """LTASData is a special form of SpectroData. + + The Sx values from a LTASData object are computed recursively. + LTAS should be preferred to classic spectrograms in cases where the audio is really long. + In that case, the corresponding number of time bins (scipy.ShortTimeFTT.p_nums) is + too long for the whole Sx matrix to be computed once. + + The LTAS are rather computed recursively. If the number of temporal bins is higher than + a target p_num value, the audio is split in p_num parts. A separate sft is computed + on each of these bits and averaged so that the end Sx presents p_num temporal windows. + + This averaging is performed recursively: if the audio data is such that after a first split, + the p_nums for each part still is higher than p_num, the parts are further split and + each part is replaced with an average of the stft performed within it. + + """ + + def __init__( + self, + items: list[SpectroItem] | None = None, + audio_data: AudioData = None, + begin: Timestamp | None = None, + end: Timestamp | None = None, + fft: ShortTimeFFT | None = None, + nb_time_bins: int = 1920, + ) -> None: + """Initialize a SpectroData from a list of SpectroItems. + + Parameters + ---------- + items: list[SpectroItem] + List of the SpectroItem constituting the SpectroData. + audio_data: AudioData + The audio data from which to compute the spectrogram. + begin: Timestamp | None + Only effective if items is None. + Set the begin of the empty data. + end: Timestamp | None + Only effective if items is None. + Set the end of the empty data. + fft: ShortTimeFFT + The short time FFT used for computing the spectrogram. + nb_time_bins: int + The maximum number of time bins of the LTAS. + Given the audio data and the fft parameters, + if the resulting spectrogram has a number of windows p_num + <= nb_time_bins, the LTAS is computed like a classic spectrogram. + Otherwise, the audio data is split in nb_time_bins equal-duration + audio data, and each bin of the LTAS consist in an average of the + fft values obtained on each of these bins. The audio is split recursively + until p_num <= nb_time_bins. + + """ + ltas_fft = LTASData.get_ltas_fft(fft) + super().__init__( + items=items, + audio_data=audio_data, + begin=begin, + end=end, + fft=ltas_fft, + ) + self.nb_time_bins = nb_time_bins + + def get_value(self, depth: int = 0) -> np.ndarray: + """Return the Sx matrix of the LTAS. + + The Sx matrix contains the absolute square of the STFT. + """ + if self.shape[1] <= self.nb_time_bins: + return super().get_value() + sub_spectros = [ + LTASData.from_spectro_data( + SpectroData.from_audio_data(ad, self.fft), + nb_time_bins=self.nb_time_bins, + ) + for ad in self.audio_data.split(self.nb_time_bins) + ] + + return np.vstack( + [ + np.mean(sub_spectro.get_value(depth + 1), axis=1) + for sub_spectro in (sub_spectros if depth != 0 else tqdm(sub_spectros)) + ], + ).T + + @classmethod + def from_spectro_data( + cls, spectro_data: SpectroData, nb_time_bins: int, + ) -> LTASData: + """Initialize a LTASData from a SpectroData. + + Parameters + ---------- + spectro_data: SpectroData + The spectrogram to turn in a LTAS. + nb_time_bins: int + The maximum number of windows over which the audio will be split to perform + a LTAS. + + Returns + ------- + LTASData: + The LTASData instance. + + """ + items = spectro_data.items + audio_data = spectro_data.audio_data + begin = spectro_data.begin + end = spectro_data.end + fft = spectro_data.fft + return cls( + items=items, + audio_data=audio_data, + begin=begin, + end=end, + fft=fft, + nb_time_bins=nb_time_bins, + ) + + @staticmethod + def get_ltas_fft(fft: ShortTimeFFT) -> ShortTimeFFT: + """Return a ShortTimeFFT object optimized for computing LTAS. + + The overlap of the fft is forced set to 0, as the value of consecutive + windows will in the end be averaged. + + Parameters + ---------- + fft: ShortTimeFFT + The fft to optimize for LTAS computation. + + Returns + ------- + ShortTimeFFT + The optimized fft. + + """ + win = fft.win + fs = fft.fs + mfft = fft.mfft + hop = win.shape[0] + return ShortTimeFFT(win=win, hop=hop, fs=fs, mfft=mfft) diff --git a/src/OSmOSE/data/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py similarity index 97% rename from src/OSmOSE/data/spectro_data.py rename to src/OSmOSE/core_api/spectro_data.py index b520e62f..cb2cdf45 100644 --- a/src/OSmOSE/data/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -12,9 +12,9 @@ import numpy as np from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES -from OSmOSE.data.base_data import BaseData -from OSmOSE.data.spectro_file import SpectroFile -from OSmOSE.data.spectro_item import SpectroItem +from OSmOSE.core_api.base_data import BaseData +from OSmOSE.core_api.spectro_file import SpectroFile +from OSmOSE.core_api.spectro_item import SpectroItem if TYPE_CHECKING: from pathlib import Path @@ -22,7 +22,7 @@ from pandas import Timestamp from scipy.signal import ShortTimeFFT - from OSmOSE.data.audio_data import AudioData + from OSmOSE.core_api.audio_data import AudioData class SpectroData(BaseData[SpectroItem, SpectroFile]): diff --git a/src/OSmOSE/data/spectro_dataset.py b/src/OSmOSE/core_api/spectro_dataset.py similarity index 94% rename from src/OSmOSE/data/spectro_dataset.py rename to src/OSmOSE/core_api/spectro_dataset.py index ec67e71a..1fe12e6b 100644 --- a/src/OSmOSE/data/spectro_dataset.py +++ b/src/OSmOSE/core_api/spectro_dataset.py @@ -8,9 +8,9 @@ from typing import TYPE_CHECKING -from OSmOSE.data.base_dataset import BaseDataset -from OSmOSE.data.spectro_data import SpectroData -from OSmOSE.data.spectro_file import SpectroFile +from OSmOSE.core_api.base_dataset import BaseDataset +from OSmOSE.core_api.spectro_data import SpectroData +from OSmOSE.core_api.spectro_file import SpectroFile if TYPE_CHECKING: from pathlib import Path @@ -18,7 +18,7 @@ from pandas import Timedelta, Timestamp from scipy.signal import ShortTimeFFT - from OSmOSE.data.audio_dataset import AudioDataset + from OSmOSE.core_api.audio_dataset import AudioDataset class SpectroDataset(BaseDataset[SpectroData, SpectroFile]): diff --git a/src/OSmOSE/data/spectro_file.py b/src/OSmOSE/core_api/spectro_file.py similarity index 98% rename from src/OSmOSE/data/spectro_file.py rename to src/OSmOSE/core_api/spectro_file.py index d3bc0a5b..d3206217 100644 --- a/src/OSmOSE/data/spectro_file.py +++ b/src/OSmOSE/core_api/spectro_file.py @@ -12,7 +12,7 @@ from pandas import Timedelta, Timestamp from scipy.signal import ShortTimeFFT -from OSmOSE.data.base_file import BaseFile +from OSmOSE.core_api.base_file import BaseFile if TYPE_CHECKING: from os import PathLike diff --git a/src/OSmOSE/data/spectro_item.py b/src/OSmOSE/core_api/spectro_item.py similarity index 94% rename from src/OSmOSE/data/spectro_item.py rename to src/OSmOSE/core_api/spectro_item.py index 18c67099..1f94723f 100644 --- a/src/OSmOSE/data/spectro_item.py +++ b/src/OSmOSE/core_api/spectro_item.py @@ -6,9 +6,9 @@ import numpy as np -from OSmOSE.data.base_file import BaseFile -from OSmOSE.data.base_item import BaseItem -from OSmOSE.data.spectro_file import SpectroFile +from OSmOSE.core_api.base_file import BaseFile +from OSmOSE.core_api.base_item import BaseItem +from OSmOSE.core_api.spectro_file import SpectroFile if TYPE_CHECKING: from pandas import Timedelta, Timestamp diff --git a/src/OSmOSE/data/__init__.py b/src/OSmOSE/data/__init__.py deleted file mode 100644 index 4d2511df..00000000 --- a/src/OSmOSE/data/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from OSmOSE.data.audio_file_manager import AudioFileManager - -audio_file_manager = AudioFileManager() diff --git a/tests/conftest.py b/tests/conftest.py index 8adcfced..1f97c522 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from scipy.signal import chirp from OSmOSE.config import OSMOSE_PATH, TIMESTAMP_FORMAT_TEST_FILES -from OSmOSE.data import AudioFileManager +from OSmOSE.core_api import AudioFileManager from OSmOSE.utils.audio_utils import generate_sample_audio diff --git a/tests/test_audio.py b/tests/test_audio.py index 81db0fa8..e510eb47 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -9,10 +9,10 @@ import soundfile as sf from OSmOSE.config import TIMESTAMP_FORMAT_TEST_FILES -from OSmOSE.data.audio_data import AudioData -from OSmOSE.data.audio_dataset import AudioDataset -from OSmOSE.data.audio_file import AudioFile -from OSmOSE.data.audio_item import AudioItem +from OSmOSE.core_api.audio_data import AudioData +from OSmOSE.core_api.audio_dataset import AudioDataset +from OSmOSE.core_api.audio_file import AudioFile +from OSmOSE.core_api.audio_item import AudioItem from OSmOSE.utils.audio_utils import generate_sample_audio if TYPE_CHECKING: diff --git a/tests/test_audio_file_manager.py b/tests/test_audio_file_manager.py index 878c607c..862aaeed 100644 --- a/tests/test_audio_file_manager.py +++ b/tests/test_audio_file_manager.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from OSmOSE.data.audio_file_manager import AudioFileManager +from OSmOSE.core_api.audio_file_manager import AudioFileManager from OSmOSE.utils.audio_utils import generate_sample_audio diff --git a/tests/test_event.py b/tests/test_event.py index 330543c8..c4bee806 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -3,7 +3,7 @@ import pytest from pandas import Timestamp -from OSmOSE.data.event import Event +from OSmOSE.core_api.event import Event @pytest.mark.parametrize( diff --git a/tests/test_spectro.py b/tests/test_spectro.py index 526d4d6e..8e32e95b 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -10,12 +10,12 @@ from scipy.signal.windows import hamming from OSmOSE.config import TIMESTAMP_FORMAT_EXPORTED_FILES, TIMESTAMP_FORMAT_TEST_FILES -from OSmOSE.data.audio_data import AudioData -from OSmOSE.data.audio_dataset import AudioDataset -from OSmOSE.data.audio_file import AudioFile -from OSmOSE.data.spectro_data import SpectroData -from OSmOSE.data.spectro_dataset import SpectroDataset -from OSmOSE.data.spectro_file import SpectroFile +from OSmOSE.core_api.audio_data import AudioData +from OSmOSE.core_api.audio_dataset import AudioDataset +from OSmOSE.core_api.audio_file import AudioFile +from OSmOSE.core_api.spectro_data import SpectroData +from OSmOSE.core_api.spectro_dataset import SpectroDataset +from OSmOSE.core_api.spectro_file import SpectroFile from OSmOSE.utils.audio_utils import generate_sample_audio if TYPE_CHECKING: From f9f693aa024d1ec13557644c6413a3ef145d9bd7 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 12 Feb 2025 10:39:21 +0100 Subject: [PATCH 107/118] fill empty spectro items with complex zeros --- src/OSmOSE/core_api/ltas_data.py | 4 +++- src/OSmOSE/core_api/spectro_item.py | 14 ++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/OSmOSE/core_api/ltas_data.py b/src/OSmOSE/core_api/ltas_data.py index 51863af6..5ea58935 100644 --- a/src/OSmOSE/core_api/ltas_data.py +++ b/src/OSmOSE/core_api/ltas_data.py @@ -121,7 +121,9 @@ def get_value(self, depth: int = 0) -> np.ndarray: @classmethod def from_spectro_data( - cls, spectro_data: SpectroData, nb_time_bins: int, + cls, + spectro_data: SpectroData, + nb_time_bins: int, ) -> LTASData: """Initialize a LTASData from a SpectroData. diff --git a/src/OSmOSE/core_api/spectro_item.py b/src/OSmOSE/core_api/spectro_item.py index 1f94723f..e69cbb4d 100644 --- a/src/OSmOSE/core_api/spectro_item.py +++ b/src/OSmOSE/core_api/spectro_item.py @@ -67,12 +67,10 @@ def get_value(self, fft: ShortTimeFFT | None = None) -> np.ndarray: if not self.is_empty: return self.file.read(start=self.begin, stop=self.end) - return ( - np.ones( - ( - fft.f.shape[0], - fft.p_num(int(self.duration.total_seconds() * fft.fs)), - ), - ) - * -120.0 + return np.zeros( + ( + fft.f.shape[0], + fft.p_num(int(self.duration.total_seconds() * fft.fs)), + ), + dtype=complex, ) From b9715511426460f4c2b3765819d5f213d66f382e Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 12 Feb 2025 10:42:27 +0100 Subject: [PATCH 108/118] fix SpectroData.nb_bytes property to match complex values --- src/OSmOSE/core_api/spectro_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index cb2cdf45..4f8ef920 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -111,7 +111,7 @@ def shape(self) -> tuple[int, ...]: @property def nb_bytes(self) -> int: """Total bytes consumed by the spectro values.""" - return self.shape[0] * self.shape[1] * 8 + return self.shape[0] * self.shape[1] * 16 def __str__(self) -> str: """Overwrite __str__.""" From a6217180574b31cd7a0b3fc7249d352d2d8ec731 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Wed, 12 Feb 2025 16:05:24 +0100 Subject: [PATCH 109/118] add sx parameter to SpectroData.save_spectrogram --- src/OSmOSE/core_api/spectro_data.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index 4f8ef920..b1741020 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -148,7 +148,9 @@ def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None freq = self.fft.f ax.pcolormesh(time, freq, sx, vmin=-120, vmax=0) - def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: + def save_spectrogram( + self, folder: Path, ax: plt.Axes | None = None, sx: np.ndarray | None = None + ) -> None: """Export the spectrogram as a png image. Parameters @@ -158,10 +160,12 @@ def save_spectrogram(self, folder: Path, ax: plt.Axes | None = None) -> None: ax: plt.Axes | None Axes on which the spectrogram should be plotted. Defaulted as the SpectroData.get_default_ax Axes. + sx: np.ndarray | None + Spectrogram sx values. Will be computed if not provided. """ super().create_directories(path=folder) - self.plot(ax) + self.plot(ax, sx) plt.savefig(f"{folder / str(self)}", bbox_inches="tight", pad_inches=0) plt.close() From c3a87829f3f87470319159cf495340f03d31e250 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 10:54:20 +0100 Subject: [PATCH 110/118] add from_audio_data LTASData classmethod --- src/OSmOSE/core_api/ltas_data.py | 30 +++++++++++++++++++++++++++++ src/OSmOSE/core_api/spectro_data.py | 5 ++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/core_api/ltas_data.py b/src/OSmOSE/core_api/ltas_data.py index 5ea58935..a754909d 100644 --- a/src/OSmOSE/core_api/ltas_data.py +++ b/src/OSmOSE/core_api/ltas_data.py @@ -155,6 +155,36 @@ def from_spectro_data( nb_time_bins=nb_time_bins, ) + @classmethod + def from_audio_data( + cls, data: AudioData, fft: ShortTimeFFT, nb_time_bins: int = 1920, + ) -> SpectroData: + """Instantiate a SpectroData object from a AudioData object. + + Parameters + ---------- + data: AudioData + Audio data from which the SpectroData should be computed. + fft: ShortTimeFFT + The ShortTimeFFT used to compute the spectrogram. + nb_time_bins: int + The maximum number of windows over which the audio will be split to perform + Defaulted to 1920. + + Returns + ------- + LTASData: + The SpectroData object. + + """ + return cls( + audio_data=data, + fft=fft, + begin=data.begin, + end=data.end, + nb_time_bins=nb_time_bins, + ) + @staticmethod def get_ltas_fft(fft: ShortTimeFFT) -> ShortTimeFFT: """Return a ShortTimeFFT object optimized for computing LTAS. diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index b1741020..24771ae4 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -149,7 +149,10 @@ def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None ax.pcolormesh(time, freq, sx, vmin=-120, vmax=0) def save_spectrogram( - self, folder: Path, ax: plt.Axes | None = None, sx: np.ndarray | None = None + self, + folder: Path, + ax: plt.Axes | None = None, + sx: np.ndarray | None = None, ) -> None: """Export the spectrogram as a png image. From 14c80489f71d24868d65cad442cbd961ed272ce9 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 11:54:50 +0100 Subject: [PATCH 111/118] add matrix_dtype SpectroData attribute --- src/OSmOSE/core_api/ltas_data.py | 8 ++++++-- src/OSmOSE/core_api/spectro_data.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/OSmOSE/core_api/ltas_data.py b/src/OSmOSE/core_api/ltas_data.py index a754909d..2d9b89c3 100644 --- a/src/OSmOSE/core_api/ltas_data.py +++ b/src/OSmOSE/core_api/ltas_data.py @@ -23,7 +23,7 @@ from scipy.signal import ShortTimeFFT from tqdm import tqdm -from OSmOSE.core_api.spectro_data import SpectroData +from OSmOSE.core_api.spectro_data import MatrixDtype, SpectroData from OSmOSE.core_api.spectro_item import SpectroItem if TYPE_CHECKING: @@ -96,6 +96,7 @@ def __init__( fft=ltas_fft, ) self.nb_time_bins = nb_time_bins + self.matrix_dtype = MatrixDtype.Absolute def get_value(self, depth: int = 0) -> np.ndarray: """Return the Sx matrix of the LTAS. @@ -157,7 +158,10 @@ def from_spectro_data( @classmethod def from_audio_data( - cls, data: AudioData, fft: ShortTimeFFT, nb_time_bins: int = 1920, + cls, + data: AudioData, + fft: ShortTimeFFT, + nb_time_bins: int = 1920, ) -> SpectroData: """Instantiate a SpectroData object from a AudioData object. diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index 24771ae4..d31bab67 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -6,6 +6,7 @@ from __future__ import annotations +from enum import Enum from typing import TYPE_CHECKING import matplotlib.pyplot as plt @@ -25,6 +26,18 @@ from OSmOSE.core_api.audio_data import AudioData +class MatrixDtype(Enum): + """Represent the dtype of the spectrum values. + + Complex will keep the phase info. + Absolute is more suited for LTAS-like spectrograms. + + """ + + Complex = "complex" + Absolute = "absolute" + + class SpectroData(BaseData[SpectroItem, SpectroFile]): """SpectroData represent Spectro data scattered through different SpectroFiles. @@ -61,6 +74,7 @@ def __init__( super().__init__(items=items, begin=begin, end=end) self.audio_data = audio_data self.fft = fft + self.matrix_dtype = MatrixDtype.Complex @staticmethod def get_default_ax() -> plt.Axes: @@ -127,7 +141,12 @@ def get_value(self) -> np.ndarray: if not self.audio_data or not self.fft: raise ValueError("SpectroData should have either items or audio_data.") - return self.fft.stft(self.audio_data.get_value(reject_dc=True), padding="zeros") + sx = self.fft.stft(self.audio_data.get_value(reject_dc=True), padding="zeros") + + if self.matrix_dtype == MatrixDtype.Absolute: + sx = abs(sx) ** 2 + + return sx def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None: """Plot the spectrogram on a specific Axes. @@ -143,9 +162,14 @@ def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None """ ax = ax if ax is not None else SpectroData.get_default_ax() sx = self.get_value() if sx is None else sx - sx = 10 * np.log10(abs(sx) ** 2 + np.nextafter(0, 1)) + + if self.matrix_dtype == MatrixDtype.Complex: + sx = abs(sx) ** 2 + + sx = 10 * np.log10(sx + np.nextafter(0, 1)) time = np.arange(sx.shape[1]) * self.duration.total_seconds() / sx.shape[1] freq = self.fft.f + ax.pcolormesh(time, freq, sx, vmin=-120, vmax=0) def save_spectrogram( From ec978c24c7be98a85233c369c463ebc91e506cd9 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 11:56:36 +0100 Subject: [PATCH 112/118] update SpectroData.nb_bytes depending on matrix_dtype --- src/OSmOSE/core_api/spectro_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index d31bab67..6dbe5c11 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -125,7 +125,8 @@ def shape(self) -> tuple[int, ...]: @property def nb_bytes(self) -> int: """Total bytes consumed by the spectro values.""" - return self.shape[0] * self.shape[1] * 16 + nb_bytes_per_cell = 8 if self.matrix_dtype == MatrixDtype.Absolute else 16 + return self.shape[0] * self.shape[1] * nb_bytes_per_cell def __str__(self) -> str: """Overwrite __str__.""" From e7bbef910b743de5016ba93890dcc5d017b11acc Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 12:33:33 +0100 Subject: [PATCH 113/118] convert sx dtype when read from items --- src/OSmOSE/core_api/spectro_data.py | 10 +++++++--- src/OSmOSE/core_api/spectro_item.py | 20 +++++++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index 6dbe5c11..f0bc7092 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -268,14 +268,18 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: if len({i.file.get_fft().delta_t for i in items if not i.is_empty}) > 1: raise ValueError("Items don't have the same time resolution.") - output = items[0].get_value(fft=self.fft) + sx_dtype = complex if self.matrix_dtype == MatrixDtype.Complex else float + output = items[0].get_value(fft=self.fft, sx_dtype=sx_dtype) for item in items[1:]: p1_le = self.fft.lower_border_end[1] - self.fft.p_min output = np.hstack( ( output[:, :-p1_le], - (output[:, -p1_le:] + item.get_value(fft=self.fft)[:, :p1_le]), - item.get_value(fft=self.fft)[:, p1_le:], + ( + output[:, -p1_le:] + + item.get_value(fft=self.fft, sx_dtype=sx_dtype)[:, :p1_le] + ), + item.get_value(fft=self.fft, sx_dtype=sx_dtype)[:, p1_le:], ), ) return output diff --git a/src/OSmOSE/core_api/spectro_item.py b/src/OSmOSE/core_api/spectro_item.py index e69cbb4d..707f160a 100644 --- a/src/OSmOSE/core_api/spectro_item.py +++ b/src/OSmOSE/core_api/spectro_item.py @@ -59,18 +59,32 @@ def from_base_item(cls, item: BaseItem) -> SpectroItem: ) raise TypeError - def get_value(self, fft: ShortTimeFFT | None = None) -> np.ndarray: + def get_value( + self, + fft: ShortTimeFFT | None = None, + sx_dtype: type[complex] = complex, + ) -> np.ndarray: """Get the values from the File between the begin and stop timestamps. If the Item is empty, return a single 0. """ if not self.is_empty: - return self.file.read(start=self.begin, stop=self.end) + sx = self.file.read(start=self.begin, stop=self.end) + + if np.iscomplexobj(sx) and sx_dtype is float: + sx = abs(sx) ** 2 + if not np.iscomplexobj(sx) and sx_dtype is complex: + raise TypeError( + "Cannot convert absolute npz values to complex sx values." + "Change the SpectroData dtype to absolute.", + ) + + return sx return np.zeros( ( fft.f.shape[0], fft.p_num(int(self.duration.total_seconds() * fft.fs)), ), - dtype=complex, + dtype=sx_dtype, ) From cf7f728647bf3a78720d109838cee6fc6557eb73 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 12:42:00 +0100 Subject: [PATCH 114/118] set SpectroData matrix_dtype based on SpectroFile sx_dtype --- src/OSmOSE/core_api/spectro_data.py | 5 ++++- src/OSmOSE/core_api/spectro_file.py | 3 +++ src/OSmOSE/core_api/spectro_item.py | 4 ++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index f0bc7092..6d2edf92 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -310,10 +310,13 @@ def from_files( The SpectroData object. """ - return cls.from_base_data( + instance = cls.from_base_data( BaseData.from_files(files, begin, end), fft=files[0].get_fft(), ) + if all(not file.is_complex for file in files): + instance.matrix_dtype = MatrixDtype.Absolute + return instance @classmethod def from_base_data( diff --git a/src/OSmOSE/core_api/spectro_file.py b/src/OSmOSE/core_api/spectro_file.py index d3206217..4eb2f32e 100644 --- a/src/OSmOSE/core_api/spectro_file.py +++ b/src/OSmOSE/core_api/spectro_file.py @@ -63,6 +63,7 @@ def _read_metadata(self, path: PathLike) -> None: window = data["window"] mfft = data["mfft"][0] timestamps = str(data["timestamps"]) + is_complex = np.iscomplexobj(data["sx"]) self.sample_rate = sample_rate self.mfft = mfft @@ -77,6 +78,8 @@ def _read_metadata(self, path: PathLike) -> None: self.window = window self.hop = hop + self.is_complex = is_complex + def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the spectro data between start and stop from the file. diff --git a/src/OSmOSE/core_api/spectro_item.py b/src/OSmOSE/core_api/spectro_item.py index 707f160a..e453e83a 100644 --- a/src/OSmOSE/core_api/spectro_item.py +++ b/src/OSmOSE/core_api/spectro_item.py @@ -71,9 +71,9 @@ def get_value( if not self.is_empty: sx = self.file.read(start=self.begin, stop=self.end) - if np.iscomplexobj(sx) and sx_dtype is float: + if self.file.is_complex and sx_dtype is float: sx = abs(sx) ** 2 - if not np.iscomplexobj(sx) and sx_dtype is complex: + if self.file.is_complex and sx_dtype is complex: raise TypeError( "Cannot convert absolute npz values to complex sx values." "Change the SpectroData dtype to absolute.", From 5f42db5e345f1a791ebc2cbdead23a82c5f24a01 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 16:16:43 +0100 Subject: [PATCH 115/118] replace Enum with simpler class attribute --- src/OSmOSE/core_api/ltas_data.py | 6 ++--- src/OSmOSE/core_api/spectro_data.py | 35 ++++++++++------------------- src/OSmOSE/core_api/spectro_file.py | 2 +- src/OSmOSE/core_api/spectro_item.py | 15 +++++++------ 4 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/OSmOSE/core_api/ltas_data.py b/src/OSmOSE/core_api/ltas_data.py index 2d9b89c3..e9d75c24 100644 --- a/src/OSmOSE/core_api/ltas_data.py +++ b/src/OSmOSE/core_api/ltas_data.py @@ -23,14 +23,14 @@ from scipy.signal import ShortTimeFFT from tqdm import tqdm -from OSmOSE.core_api.spectro_data import MatrixDtype, SpectroData -from OSmOSE.core_api.spectro_item import SpectroItem +from OSmOSE.core_api.spectro_data import SpectroData if TYPE_CHECKING: from pandas import Timestamp from OSmOSE.core_api.audio_data import AudioData + from OSmOSE.core_api.spectro_item import SpectroItem class LTASData(SpectroData): @@ -96,7 +96,7 @@ def __init__( fft=ltas_fft, ) self.nb_time_bins = nb_time_bins - self.matrix_dtype = MatrixDtype.Absolute + self.sx_dtype = float def get_value(self, depth: int = 0) -> np.ndarray: """Return the Sx matrix of the LTAS. diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index 6d2edf92..6fe74aaa 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -6,7 +6,6 @@ from __future__ import annotations -from enum import Enum from typing import TYPE_CHECKING import matplotlib.pyplot as plt @@ -26,18 +25,6 @@ from OSmOSE.core_api.audio_data import AudioData -class MatrixDtype(Enum): - """Represent the dtype of the spectrum values. - - Complex will keep the phase info. - Absolute is more suited for LTAS-like spectrograms. - - """ - - Complex = "complex" - Absolute = "absolute" - - class SpectroData(BaseData[SpectroItem, SpectroFile]): """SpectroData represent Spectro data scattered through different SpectroFiles. @@ -74,7 +61,7 @@ def __init__( super().__init__(items=items, begin=begin, end=end) self.audio_data = audio_data self.fft = fft - self.matrix_dtype = MatrixDtype.Complex + self.sx_dtype = complex @staticmethod def get_default_ax() -> plt.Axes: @@ -125,7 +112,7 @@ def shape(self) -> tuple[int, ...]: @property def nb_bytes(self) -> int: """Total bytes consumed by the spectro values.""" - nb_bytes_per_cell = 8 if self.matrix_dtype == MatrixDtype.Absolute else 16 + nb_bytes_per_cell = 16 if self.sx_dtype is complex else 8 return self.shape[0] * self.shape[1] * nb_bytes_per_cell def __str__(self) -> str: @@ -144,7 +131,7 @@ def get_value(self) -> np.ndarray: sx = self.fft.stft(self.audio_data.get_value(reject_dc=True), padding="zeros") - if self.matrix_dtype == MatrixDtype.Absolute: + if self.sx_dtype is float: sx = abs(sx) ** 2 return sx @@ -164,7 +151,7 @@ def plot(self, ax: plt.Axes | None = None, sx: np.ndarray | None = None) -> None ax = ax if ax is not None else SpectroData.get_default_ax() sx = self.get_value() if sx is None else sx - if self.matrix_dtype == MatrixDtype.Complex: + if self.sx_dtype is complex: sx = abs(sx) ** 2 sx = 10 * np.log10(sx + np.nextafter(0, 1)) @@ -268,8 +255,7 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: if len({i.file.get_fft().delta_t for i in items if not i.is_empty}) > 1: raise ValueError("Items don't have the same time resolution.") - sx_dtype = complex if self.matrix_dtype == MatrixDtype.Complex else float - output = items[0].get_value(fft=self.fft, sx_dtype=sx_dtype) + output = items[0].get_value(fft=self.fft, sx_dtype=self.sx_dtype) for item in items[1:]: p1_le = self.fft.lower_border_end[1] - self.fft.p_min output = np.hstack( @@ -277,9 +263,12 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: output[:, :-p1_le], ( output[:, -p1_le:] - + item.get_value(fft=self.fft, sx_dtype=sx_dtype)[:, :p1_le] + + item.get_value(fft=self.fft, sx_dtype=self.sx_dtype)[ + :, + :p1_le, + ] ), - item.get_value(fft=self.fft, sx_dtype=sx_dtype)[:, p1_le:], + item.get_value(fft=self.fft, sx_dtype=self.sx_dtype)[:, p1_le:], ), ) return output @@ -314,8 +303,8 @@ def from_files( BaseData.from_files(files, begin, end), fft=files[0].get_fft(), ) - if all(not file.is_complex for file in files): - instance.matrix_dtype = MatrixDtype.Absolute + if not any(file.sx_dtype is complex for file in files): + instance.sx_dtype = float return instance @classmethod diff --git a/src/OSmOSE/core_api/spectro_file.py b/src/OSmOSE/core_api/spectro_file.py index 4eb2f32e..7866659b 100644 --- a/src/OSmOSE/core_api/spectro_file.py +++ b/src/OSmOSE/core_api/spectro_file.py @@ -78,7 +78,7 @@ def _read_metadata(self, path: PathLike) -> None: self.window = window self.hop = hop - self.is_complex = is_complex + self.sx_dtype = complex if is_complex else float def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: """Return the spectro data between start and stop from the file. diff --git a/src/OSmOSE/core_api/spectro_item.py b/src/OSmOSE/core_api/spectro_item.py index e453e83a..8c652a2f 100644 --- a/src/OSmOSE/core_api/spectro_item.py +++ b/src/OSmOSE/core_api/spectro_item.py @@ -71,13 +71,14 @@ def get_value( if not self.is_empty: sx = self.file.read(start=self.begin, stop=self.end) - if self.file.is_complex and sx_dtype is float: - sx = abs(sx) ** 2 - if self.file.is_complex and sx_dtype is complex: - raise TypeError( - "Cannot convert absolute npz values to complex sx values." - "Change the SpectroData dtype to absolute.", - ) + if self.file.sx_dtype is not sx_dtype: + if sx_dtype is float: + sx = abs(sx) ** 2 + if sx_dtype is complex: + raise TypeError( + "Cannot convert absolute npz values to complex sx values." + "Change the SpectroData dtype to absolute.", + ) return sx From 637b73a24977f6bb81dfe5f7e0fdb2d5132a438b Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 16:37:55 +0100 Subject: [PATCH 116/118] add property for SpectroData.sx_dtype --- src/OSmOSE/core_api/spectro_data.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/OSmOSE/core_api/spectro_data.py b/src/OSmOSE/core_api/spectro_data.py index 6fe74aaa..238eabb6 100644 --- a/src/OSmOSE/core_api/spectro_data.py +++ b/src/OSmOSE/core_api/spectro_data.py @@ -61,7 +61,7 @@ def __init__( super().__init__(items=items, begin=begin, end=end) self.audio_data = audio_data self.fft = fft - self.sx_dtype = complex + self._sx_dtype = complex @staticmethod def get_default_ax() -> plt.Axes: @@ -115,6 +115,22 @@ def nb_bytes(self) -> int: nb_bytes_per_cell = 16 if self.sx_dtype is complex else 8 return self.shape[0] * self.shape[1] * nb_bytes_per_cell + @property + def sx_dtype(self) -> type[complex]: + """Data type used to represent the sx values. Should either be float or complex. + + If complex, the phase info will be included in the computed spectrum. + If float, only the absolute value of the spectrum will be kept. + + """ + return self._sx_dtype + + @sx_dtype.setter + def sx_dtype(self, dtype: type[complex]) -> [complex, float]: + if dtype not in (complex, float): + raise ValueError("dtype must be complex or float.") + self._sx_dtype = dtype + def __str__(self) -> str: """Overwrite __str__.""" return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES) From a2abaa2bf883df07646db3586736e7f8b9884988 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 17:41:49 +0100 Subject: [PATCH 117/118] add tests for spectro and ltas dtypes --- tests/test_spectro.py | 100 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/tests/test_spectro.py b/tests/test_spectro.py index 8e32e95b..916e67ad 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -13,6 +13,7 @@ from OSmOSE.core_api.audio_data import AudioData from OSmOSE.core_api.audio_dataset import AudioDataset from OSmOSE.core_api.audio_file import AudioFile +from OSmOSE.core_api.ltas_data import LTASData from OSmOSE.core_api.spectro_data import SpectroData from OSmOSE.core_api.spectro_dataset import SpectroDataset from OSmOSE.core_api.spectro_file import SpectroFile @@ -308,3 +309,102 @@ def test_spectrogram_from_npz_files( ) assert np.allclose(sd.get_value(), sds.data[0].get_value()) + + +@pytest.mark.parametrize( + ("audio_files", "origin_dtype", "target_dtype", "expected_value_dtype"), + [ + pytest.param( + { + "duration": 1, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + complex, + complex, + complex, + id="complex_to_complex", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + complex, + float, + float, + id="complex_to_float", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + complex, + float, + float, + id="float_to_float", + ), + pytest.param( + { + "duration": 1, + "sample_rate": 1_024, + "nb_files": 1, + "date_begin": pd.Timestamp("2024-01-01 12:00:00"), + }, + float, + complex, + pytest.raises( + TypeError, + match="Cannot convert absolute npz values to complex sx values.", + ), + id="float_to_complex_raises_exception", + ), + ], + indirect=["audio_files"], +) +def test_spectrogram_sx_dtype( + tmp_path: Path, + audio_files: tuple[list[Path], pytest.fixtures.Subrequest], + origin_dtype: type[complex], + target_dtype: type[complex], + expected_value_dtype: type[complex], +) -> None: + audio_file, request = audio_files + af = AudioFile(audio_file[0], strptime_format=TIMESTAMP_FORMAT_TEST_FILES) + ad = AudioData.from_files([af]) + sft = ShortTimeFFT(hamming(128), 128, 1_024) + sd = SpectroData.from_audio_data(ad, sft) + + sd.sx_dtype = origin_dtype + ltas = LTASData.from_spectro_data(sd, 4) + assert ltas.sx_dtype is float # Default LTASData behaviour + + assert sd.get_value().dtype == origin_dtype + + sd.write(tmp_path / "npz") + + sd2 = SpectroDataset.from_folder( + (tmp_path / "npz"), strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES + ).data[0] + + assert sd2.sx_dtype is complex # Default SpectroData behaviour + + assert ltas.get_value().dtype == float + + sd2.sx_dtype = target_dtype + + if type(expected_value_dtype) is type: + assert sd2.get_value().dtype == expected_value_dtype + else: + with expected_value_dtype as e: + assert sd2.get_value().dtype == expected_value_dtype + + sd2.sx_dtype = origin_dtype + + assert sd2.get_value().dtype == origin_dtype From 0d357c3b3180297c378692bc91b11284fc687b8c Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Thu, 13 Feb 2025 17:47:15 +0100 Subject: [PATCH 118/118] test spectro dtype parsing from npz file --- tests/test_spectro.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_spectro.py b/tests/test_spectro.py index 916e67ad..dd50a500 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -389,11 +389,15 @@ def test_spectrogram_sx_dtype( sd.write(tmp_path / "npz") - sd2 = SpectroDataset.from_folder( - (tmp_path / "npz"), strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES - ).data[0] + sf = SpectroFile( + tmp_path / "npz" / f"{sd}.npz", strptime_format=TIMESTAMP_FORMAT_EXPORTED_FILES + ) + + assert sf.sx_dtype is origin_dtype + + sd2 = SpectroData.from_files([sf]) - assert sd2.sx_dtype is complex # Default SpectroData behaviour + assert sd2.sx_dtype is origin_dtype # Default SpectroData behaviour assert ltas.get_value().dtype == float