From 78b578327c79d5b4a920b57ba6f8cc0671fbceca Mon Sep 17 00:00:00 2001 From: Tony Xiang Date: Wed, 22 Jan 2025 15:02:12 +0100 Subject: [PATCH 1/2] fix a bug on relevant filter Signed-off-by: Tony Xiang --- src/power_grid_model/_core/power_grid_dataset.py | 16 +++++++++------- tests/unit/test_serialization.py | 7 ++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/power_grid_model/_core/power_grid_dataset.py b/src/power_grid_model/_core/power_grid_dataset.py index 92285cb8d..ad0724b9e 100644 --- a/src/power_grid_model/_core/power_grid_dataset.py +++ b/src/power_grid_model/_core/power_grid_dataset.py @@ -26,7 +26,7 @@ power_grid_core as pgc, ) from power_grid_model._core.power_grid_meta import ComponentMetaData, DatasetMetaData, power_grid_meta_data -from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, process_data_filter +from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, is_sparse, process_data_filter from power_grid_model.data_types import AttributeType, ComponentData, Dataset from power_grid_model.enum import ComponentAttributeFilterOptions from power_grid_model.typing import ComponentAttributeMapping, _ComponentAttributeMappingDict @@ -455,22 +455,24 @@ def _get_buffer_properties(self, info: CDatasetInfo) -> Mapping[ComponentType, B if component in self._data_filter } - def _filter_attributes(self, attributes): + def _filter_attributes(self, buffer): + if is_sparse(buffer): + attributes = buffer["data"] + else: + attributes = buffer keys_to_remove = [] for attr, array in attributes.items(): - if is_columnar(array): - continue if is_nan_or_equivalent(array): keys_to_remove.append(attr) for key in keys_to_remove: del attributes[key] def _filter_with_mapping(self): - for component_type, attributes in self._data.items(): + for component_type, component_buffer in self._data.items(): if component_type in self._data_filter: filter_option = self._data_filter[component_type] - if filter_option is ComponentAttributeFilterOptions.relevant: - self._filter_attributes(attributes) + if filter_option is ComponentAttributeFilterOptions.relevant and is_columnar(component_buffer): + self._filter_attributes(component_buffer) def _post_filtering(self): if isinstance(self._data_filter, dict): diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index c805c3881..bf7ae7867 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -618,9 +618,9 @@ def assert_serialization_correct(deserialized_dataset: Dataset, serialized_datas def _check_only_relevant_attributes_present(component_values) -> bool: + if isinstance(component_values, np.ndarray): + return True for array in component_values.values(): - if not isinstance(array, np.ndarray): - continue if (array.dtype == np.float64 and np.isnan(array).all()) or ( array.dtype in (np.int32, np.int8) and np.all(array == np.iinfo(array.dtype).min) ): @@ -633,7 +633,8 @@ def assert_deserialization_filtering_correct(deserialized_dataset: Dataset, data return True if data_filter is ComponentAttributeFilterOptions.relevant: for component_values in deserialized_dataset.values(): - if not _check_only_relevant_attributes_present(component_values): + buffer = component_values if not is_sparse(component_values) else component_values["data"] + if not _check_only_relevant_attributes_present(buffer): return False return True From b46a6f56497cfd88c53d3a13de30024b4a5629ed Mon Sep 17 00:00:00 2001 From: Martijn Govers Date: Wed, 22 Jan 2025 16:35:17 +0100 Subject: [PATCH 2/2] add type hinting to prevent future similar issues Signed-off-by: Martijn Govers --- .../_core/power_grid_dataset.py | 19 ++++++++++++++----- tests/unit/test_serialization.py | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/power_grid_model/_core/power_grid_dataset.py b/src/power_grid_model/_core/power_grid_dataset.py index ad0724b9e..28a5549f9 100644 --- a/src/power_grid_model/_core/power_grid_dataset.py +++ b/src/power_grid_model/_core/power_grid_dataset.py @@ -6,7 +6,7 @@ Power grid model raw dataset handler """ -from typing import Any, Mapping +from typing import Any, Mapping, cast from power_grid_model._core.buffer_handling import ( BufferProperties, @@ -27,7 +27,15 @@ ) from power_grid_model._core.power_grid_meta import ComponentMetaData, DatasetMetaData, power_grid_meta_data from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, is_sparse, process_data_filter -from power_grid_model.data_types import AttributeType, ComponentData, Dataset +from power_grid_model.data_types import ( + AttributeType, + ColumnarData, + ComponentData, + Dataset, + DenseBatchColumnarData, + SingleColumnarData, + SparseBatchColumnarData, +) from power_grid_model.enum import ComponentAttributeFilterOptions from power_grid_model.typing import ComponentAttributeMapping, _ComponentAttributeMappingDict @@ -455,11 +463,12 @@ def _get_buffer_properties(self, info: CDatasetInfo) -> Mapping[ComponentType, B if component in self._data_filter } - def _filter_attributes(self, buffer): + def _filter_attributes(self, buffer: ColumnarData): if is_sparse(buffer): - attributes = buffer["data"] + attributes = cast(SparseBatchColumnarData, buffer)["data"] else: - attributes = buffer + attributes = cast(SingleColumnarData | DenseBatchColumnarData, buffer) + keys_to_remove = [] for attr, array in attributes.items(): if is_nan_or_equivalent(array): diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index bf7ae7867..a7cc2663a 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -11,7 +11,7 @@ from power_grid_model import DatasetType from power_grid_model._utils import get_dataset_type, is_columnar, is_sparse -from power_grid_model.data_types import BatchDataset, Dataset, SingleDataset +from power_grid_model.data_types import BatchDataset, Dataset, DenseBatchData, SingleComponentData, SingleDataset from power_grid_model.enum import ComponentAttributeFilterOptions from power_grid_model.utils import json_deserialize, json_serialize, msgpack_deserialize, msgpack_serialize @@ -617,7 +617,7 @@ def assert_serialization_correct(deserialized_dataset: Dataset, serialized_datas ) -def _check_only_relevant_attributes_present(component_values) -> bool: +def _check_only_relevant_attributes_present(component_values: SingleComponentData | DenseBatchData) -> bool: if isinstance(component_values, np.ndarray): return True for array in component_values.values():