diff --git a/src/pyaro/timeseries/Data.py b/src/pyaro/timeseries/Data.py index 7c65797..68e5ea9 100644 --- a/src/pyaro/timeseries/Data.py +++ b/src/pyaro/timeseries/Data.py @@ -72,6 +72,9 @@ def slice(self, index): # -> Self: for 3.11 """ pass + def __getitem__(self, key): + return self.slice(key) + @property def variable(self) -> str: """Variable name for all the data diff --git a/src/pyaro/timeseries/Filter.py b/src/pyaro/timeseries/Filter.py index 3065bef..90251a2 100644 --- a/src/pyaro/timeseries/Filter.py +++ b/src/pyaro/timeseries/Filter.py @@ -8,8 +8,10 @@ import re import sys import types +from typing import Any import numpy as np +import numpy.typing as npt from .Data import Data, Flag from .Station import Station @@ -34,7 +36,7 @@ def __init__(self, **kwargs): for an empty filter object""" return - def args(self) -> list: + def args(self) -> dict[str, Any]: """retrieve the kwargs possible to retrieve a new object of this filter with filter restrictions :return: a dictionary possible to use as kwargs for the new method @@ -57,7 +59,7 @@ def name(self) -> str: """ def filter_data( - self, data: Data, stations: list[Station], variables: list[str] + self, data: Data, stations: dict[str, Station], variables: list[str] ) -> Data: """Filtering of data @@ -93,14 +95,18 @@ class DataIndexFilter(Filter): filter_data_idx""" @abc.abstractmethod - def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data_idx( + self, data: Data, stations: dict[str, Station], variables: list[str] + ): """Filter data to an index which can be applied to Data.slice(idx) later :return: a index for Data.slice(idx) """ pass - def filter_data(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data( + self, data: Data, stations: dict[str, Station], variables: list[str] + ) -> Data: idx = self.filter_data_idx(data, stations, variables) return data.slice(idx) @@ -268,7 +274,7 @@ def new_varname(self, reader_variable: str) -> 
str: """ return self._reader_to_new.get(reader_variable, reader_variable) - def filter_data(self, data, stations, variables): + def filter_data(self, data, stations, variables) -> Data: """Translate data's variable""" data._set_variable(self._reader_to_new.get(data.variable, data.variable)) return data @@ -320,7 +326,9 @@ class StationReductionFilter(DataIndexFilter): def filter_stations(self, stations: dict[str, Station]) -> dict[str, Station]: pass - def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data_idx( + self, data: Data, stations: dict[str, Station], variables: list[str] + ): stat_names = self.filter_stations(stations).keys() dstations = data.stations stat_names = np.fromiter(stat_names, dtype=dstations.dtype) @@ -402,8 +410,8 @@ class BoundingBoxFilter(StationReductionFilter): def __init__( self, - include: list[(float, float, float, float)] = [], - exclude: list[(float, float, float, float)] = [], + include: list[tuple[float, float, float, float]] = [], + exclude: list[tuple[float, float, float, float]] = [], ): for tup in include: self._test_bounding_box(tup) @@ -507,12 +515,20 @@ def init_kwargs(self): def usable_flags(self): return self._valid - def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data_idx( + self, data: Data, stations: dict[str, Station], variables: list[str] + ): validflags = np.fromiter(self._valid, dtype=data.flags.dtype) index = np.isin(data.flags, validflags) return index +# Upper and lower bound inclusive +TimeBound = tuple[str | np.datetime64 | datetime, str | np.datetime64 | datetime] +# Internal representation +_TimeBound = tuple[np.datetime64, np.datetime64] + + class TimeBoundsException(Exception): pass @@ -520,7 +536,7 @@ class TimeBoundsException(Exception): @registered_filter class TimeBoundsFilter(DataIndexFilter): """Filter data by start and/or end-times of the measurements. 
Each timebound consists - of a bound-start and bound-end (both included). Timestamps are given as YYYY-MM-DD HH:MM:SS + of a bound-start and bound-end (both included). Timestamps are given as YYYY-MM-DD HH:MM:SS in UTC :param start_include: list of tuples of start-times, defaults to [], meaning all :param start_exclude: list of tuples of start-times, defaults to [] @@ -529,33 +545,56 @@ class TimeBoundsFilter(DataIndexFilter): :param end_include: list of tuples of end-times, defaults to [], meaning all :param end_exclude: list of tuples of end-times, defaults to [] :raises TimeBoundsException: on any errors with the time-bounds + + Examples: + + end_include: [("2023-01-01 10:00:00", "2024-01-01 07:00:00")] + will only include observations where the end time of each observation + is within the interval specified + (i.e. "end" >= "2023-01-01 10:00:00" and "end" <= "2024-01-01 07:00:00") + + Including multiple bounds will act as an OR, allowing multiple selections. + If we want every observation in January for 2021, 2022, 2023, and 2024 this + could be made as the following filter: + startend_include: [ + ("2021-01-01 00:00:00", "2021-02-01 00:00:00"), + ("2022-01-01 00:00:00", "2022-02-01 00:00:00"), + ("2023-01-01 00:00:00", "2023-02-01 00:00:00"), + ("2024-01-01 00:00:00", "2024-02-01 00:00:00"), + ] """ def __init__( self, - start_include: list[(str, str)] = [], - start_exclude: list[(str, str)] = [], - startend_include: list[(str, str)] = [], - startend_exclude: list[(str, str)] = [], - end_include: list[(str, str)] = [], - end_exclude: list[(str, str)] = [], + start_include: list[TimeBound] = [], + start_exclude: list[TimeBound] = [], + startend_include: list[TimeBound] = [], + startend_exclude: list[TimeBound] = [], + end_include: list[TimeBound] = [], + end_exclude: list[TimeBound] = [], ): - self._start_include = self._str_list_to_datetime_list(start_include) - self._start_exclude = self._str_list_to_datetime_list(start_exclude) - self._startend_include = 
self._str_list_to_datetime_list(startend_include) - self._startend_exclude = self._str_list_to_datetime_list(startend_exclude) - self._end_include = self._str_list_to_datetime_list(end_include) - self._end_exclude = self._str_list_to_datetime_list(end_exclude) - return + self._start_include = self._timebounds_canonicalise(start_include) + self._start_exclude = self._timebounds_canonicalise(start_exclude) + self._startend_include = self._timebounds_canonicalise(startend_include) + self._startend_exclude = self._timebounds_canonicalise(startend_exclude) + self._end_include = self._timebounds_canonicalise(end_include) + self._end_exclude = self._timebounds_canonicalise(end_exclude) def name(self): return "time_bounds" - def _str_list_to_datetime_list(self, tuple_list: list[(str, str)]): + def _timebounds_canonicalise(self, tuple_list: list[TimeBound]) -> list[_TimeBound]: retlist = [] for start, end in tuple_list: - start_dt = datetime.strptime(start, self.time_format) - end_dt = datetime.strptime(end, self.time_format) + if isinstance(start, str): + start_dt = np.datetime64(datetime.strptime(start, self.time_format)) + else: + start_dt = np.datetime64(start) + if isinstance(end, str): + end_dt = np.datetime64(datetime.strptime(end, self.time_format)) + else: + end_dt = np.datetime64(end) + if start_dt > end_dt: raise TimeBoundsException( - f"(start later than end) for (f{start} > f{end})" + f"(start later than end) for ({start} > {end})" @@ -563,15 +602,18 @@ def _str_list_to_datetime_list(self, tuple_list: list[(str, str)]): retlist.append((start_dt, end_dt)) return retlist - def _datetime_list_to_str_list(self, tuple_list) -> list[(str, str)]: + def _datetime_list_to_str_list(self, tuple_list) -> list[tuple[str, str]]: retlist = [] for start_dt, end_dt in tuple_list: retlist.append( - (start_dt.strftime(self.time_format), end_dt.strftime(self.time_format)) + ( + start_dt.astype(datetime).strftime(self.time_format), + end_dt.astype(datetime).strftime(self.time_format), + ) ) return retlist - def init_kwargs(self): + 
def init_kwargs(self) -> dict[str, list[tuple[str, str]]]: return { "start_include": self._datetime_list_to_str_list(self._start_include), "start_exclude": self._datetime_list_to_str_list(self._start_exclude), @@ -581,22 +623,28 @@ def init_kwargs(self): - "end_exclude": self._datetime_list_to_str_list(self._startend_exclude), + "end_exclude": self._datetime_list_to_str_list(self._end_exclude), } - def _index_from_include_exclude(self, times1, times2, includes, excludes): + def _index_from_include_exclude( + self, + times1: npt.NDArray[np.datetime64], + times2: npt.NDArray[np.datetime64], + includes: list[_TimeBound], + excludes: list[_TimeBound], + ): if len(includes) == 0: idx = np.repeat(True, len(times1)) else: idx = np.repeat(False, len(times1)) for start, end in includes: - idx |= (np.datetime64(start) <= times1) & (times2 <= np.datetime64(end)) + idx |= (start <= times1) & (times2 <= end) for start, end in excludes: - idx &= (times1 < np.datetime64(start)) | (np.datetime64(end) < times2) + idx &= (times1 < start) | (end < times2) return idx - def has_envelope(self): + def has_envelope(self) -> bool: """Check if this filter has an envelope, i.e. 
a earliest and latest time""" - return ( + return bool( len(self._start_include) or len(self._startend_include) or len(self._end_include) @@ -612,8 +660,8 @@ def envelope(self) -> tuple[datetime, datetime]: raise TimeBoundsException( "TimeBounds-envelope called but no envelope exists" ) - start = datetime.max - end = datetime.min + start = np.datetime64(datetime.max) + end = np.datetime64(datetime.min) for s, e in self._start_include + self._startend_include + self._end_include: start = min(start, s) end = max(end, e) @@ -621,13 +669,15 @@ def envelope(self) -> tuple[datetime, datetime]: raise TimeBoundsException( f"TimeBoundsEnvelope end < start: {end} < {start}" ) - return (start, end) + return (start.astype(datetime), end.astype(datetime)) - def contains(self, dt_start, dt_end): + def contains( + self, dt_start: npt.NDArray[np.datetime64], dt_end: npt.NDArray[np.datetime64] + ) -> npt.NDArray[np.bool_]: """Test if datetimes in dt_start, dt_end belong to this filter - :param dt_start: numpy array of datetimes - :param dt_end: numpy array of datetimes + :param dt_start: start of each observation as a numpy array of datetimes + :param dt_end: end of each observation as a numpy array of datetimes :return: numpy boolean array with True/False values """ idx = self._index_from_include_exclude( @@ -641,7 +691,9 @@ def contains(self, dt_start, dt_end): ) return idx - def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data_idx( + self, data: Data, stations: dict[str, Station], variables: list[str] + ) -> npt.NDArray[np.bool_]: return self.contains(data.start_times, data.end_times) @@ -715,7 +767,9 @@ def init_kwargs(self): def name(self): return "time_variable_station" - def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data_idx( + self, data: Data, stations: dict[str, Station], variables: list[str] + ): idx = data.start_times.astype(bool) idx |= True if data.variable in 
self._exclude: @@ -757,7 +811,9 @@ def init_kwargs(self): def name(self): return "duplicates" - def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data_idx( + self, data: Data, stations: dict[str, Station], variables: list[str] + ): if self._keys is None: xkeys = self.default_keys else: @@ -821,7 +877,9 @@ def init_kwargs(self): def name(self): return "time_resolution" - def filter_data_idx(self, data: Data, stations: dict[str, Station], variables: str): + def filter_data_idx( + self, data: Data, stations: dict[str, Station], variables: list[str] + ): idx = data.start_times.astype(bool) idx[:] = True if len(self._minmax) > 0: @@ -1068,7 +1126,7 @@ def _gridded_altitude_from_lat_lon( def _is_close( self, alt_gridded: np.ndarray, alt_station: np.ndarray - ) -> np.ndarray[bool]: + ) -> npt.NDArray[np.bool_]: """ Function to check if two altitudes are within a relative tolerance of each other. diff --git a/tests/test_CSVTimeSeriesReader.py b/tests/test_CSVTimeSeriesReader.py index 81f14e7..c635986 100644 --- a/tests/test_CSVTimeSeriesReader.py +++ b/tests/test_CSVTimeSeriesReader.py @@ -438,7 +438,7 @@ def test_altitude_filter_1(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + }, ) as ts: self.assertEqual(len(ts.stations()), 1) @@ -460,7 +460,7 @@ def test_altitude_filter_2(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + }, ) as ts: self.assertEqual(len(ts.stations()), 1) @@ -468,7 +468,11 @@ def test_altitude_filter_3(self): engines = pyaro.list_timeseries_engines() with engines["csv_timeseries"].open( filename=self.elevation_file, - filters=[pyaro.timeseries.filters.get("altitude", min_altitude=150, max_altitude=250)], + filters=[ + pyaro.timeseries.filters.get( + "altitude", min_altitude=150, max_altitude=250 + ) + ], columns={ "variable": 0, "station": 1, @@ -482,7 +486,7 @@ def test_altitude_filter_3(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + 
}, ) as ts: self.assertEqual(len(ts.stations()), 1) @@ -490,7 +494,13 @@ def test_relaltitude_filter_emep_1(self): engines = pyaro.list_timeseries_engines() with engines["csv_timeseries"].open( filename=self.elevation_file, - filters=[pyaro.timeseries.filters.get("relaltitude", topo_file = "./tests/testdata/datadir_elevation/topography.nc", rdiff=0)], + filters=[ + pyaro.timeseries.filters.get( + "relaltitude", + topo_file="./tests/testdata/datadir_elevation/topography.nc", + rdiff=0, + ) + ], columns={ "variable": 0, "station": 1, @@ -504,7 +514,7 @@ def test_relaltitude_filter_emep_1(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + }, ) as ts: # Altitudes in test dataset: # Station | Alt_obs | Modeobs | rdiff | @@ -518,7 +528,13 @@ def test_relaltitude_filter_emep_2(self): engines = pyaro.list_timeseries_engines() with engines["csv_timeseries"].open( filename=self.elevation_file, - filters=[pyaro.timeseries.filters.get("relaltitude", topo_file = "./tests/testdata/datadir_elevation/topography.nc", rdiff=90)], + filters=[ + pyaro.timeseries.filters.get( + "relaltitude", + topo_file="./tests/testdata/datadir_elevation/topography.nc", + rdiff=90, + ) + ], columns={ "variable": 0, "station": 1, @@ -532,7 +548,7 @@ def test_relaltitude_filter_emep_2(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + }, ) as ts: # At rdiff = 90, only the first station should be included. 
self.assertEqual(len(ts.stations()), 1) @@ -541,7 +557,13 @@ def test_relaltitude_filter_emep_3(self): engines = pyaro.list_timeseries_engines() with engines["csv_timeseries"].open( filename=self.elevation_file, - filters=[pyaro.timeseries.filters.get("relaltitude", topo_file = "./tests/testdata/datadir_elevation/topography.nc", rdiff=300)], + filters=[ + pyaro.timeseries.filters.get( + "relaltitude", + topo_file="./tests/testdata/datadir_elevation/topography.nc", + rdiff=300, + ) + ], columns={ "variable": 0, "station": 1, @@ -555,7 +577,7 @@ def test_relaltitude_filter_emep_3(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + }, ) as ts: # Since rdiff=300, all stations should be included. self.assertEqual(len(ts.stations()), 3) @@ -564,7 +586,13 @@ def test_relaltitude_filter_1(self): engines = pyaro.list_timeseries_engines() with engines["csv_timeseries"].open( filename=self.elevation_file, - filters=[pyaro.timeseries.filters.get("relaltitude", topo_file = "./tests/testdata/datadir_elevation/topography.nc", rdiff=0)], + filters=[ + pyaro.timeseries.filters.get( + "relaltitude", + topo_file="./tests/testdata/datadir_elevation/topography.nc", + rdiff=0, + ) + ], columns={ "variable": 0, "station": 1, @@ -578,7 +606,7 @@ def test_relaltitude_filter_1(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + }, ) as ts: self.assertEqual(len(ts.stations()), 0) @@ -586,7 +614,13 @@ def test_relaltitude_filter_2(self): engines = pyaro.list_timeseries_engines() with engines["csv_timeseries"].open( filename=self.elevation_file, - filters=[pyaro.timeseries.filters.get("relaltitude", topo_file = "./tests/testdata/datadir_elevation/topography.nc", rdiff=90)], + filters=[ + pyaro.timeseries.filters.get( + "relaltitude", + topo_file="./tests/testdata/datadir_elevation/topography.nc", + rdiff=90, + ) + ], columns={ "variable": 0, "station": 1, @@ -600,7 +634,7 @@ def test_relaltitude_filter_2(self): "country": "NO", "standard_deviation": 
"NaN", "flag": "0", - } + }, ) as ts: # At rdiff = 90, only the first station should be included. self.assertEqual(len(ts.stations()), 1) @@ -609,7 +643,13 @@ def test_relaltitude_filter_3(self): engines = pyaro.list_timeseries_engines() with engines["csv_timeseries"].open( filename=self.elevation_file, - filters=[pyaro.timeseries.filters.get("relaltitude", topo_file = "./tests/testdata/datadir_elevation/topography.nc", rdiff=300)], + filters=[ + pyaro.timeseries.filters.get( + "relaltitude", + topo_file="./tests/testdata/datadir_elevation/topography.nc", + rdiff=300, + ) + ], columns={ "variable": 0, "station": 1, @@ -623,12 +663,11 @@ def test_relaltitude_filter_3(self): "country": "NO", "standard_deviation": "NaN", "flag": "0", - } + }, ) as ts: # Since rdiff=300, all stations should be included. self.assertEqual(len(ts.stations()), 3) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_timefilter.py b/tests/test_timefilter.py index 337631e..e6c4430 100644 --- a/tests/test_timefilter.py +++ b/tests/test_timefilter.py @@ -6,13 +6,26 @@ def test_timemax(): - bounds = TimeBoundsFilter(start_include=[("2023-01-01 00:00:00", "2024-01-01 00:00:00")]) + bounds = TimeBoundsFilter( + start_include=[("2023-01-01 00:00:00", "2024-01-01 00:00:00")] + ) envelope = bounds.envelope() assert envelope[0] == datetime.fromisoformat("2023-01-01 00:00:00") assert envelope[1] == datetime.fromisoformat("2024-01-01 00:00:00") - dt_start = np.arange(np.datetime64("2023-01-30"), np.datetime64("2023-03-10"), np.timedelta64(1, "D")) + dt_start = np.arange( + np.datetime64("2023-01-30"), np.datetime64("2023-03-10"), np.timedelta64(1, "D") + ) dt_end = dt_start + np.timedelta64(1, "h") idx = bounds.contains(dt_start, dt_end) assert len(idx) == len(dt_start) + + +def test_roundtrip(): + bounds = TimeBoundsFilter( + start_include=[("2023-01-01 00:00:03", "2024-01-01 00:10:00")] + ) + + init = bounds.init_kwargs() + assert init["start_include"] == [("2023-01-01 00:00:03", 
"2024-01-01 00:10:00")]