Skip to content

Commit

Permalink
Merge pull request #881 from PowerGridModel/fix/attribute-relevant-filter
Browse files Browse the repository at this point in the history

Fix a bug on relevant filter
  • Loading branch information
mgovers authored Jan 23, 2025
2 parents a4be26e + b46a6f5 commit 9fc25ef
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
29 changes: 20 additions & 9 deletions src/power_grid_model/_core/power_grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Power grid model raw dataset handler
"""

from typing import Any, Mapping
from typing import Any, Mapping, cast

from power_grid_model._core.buffer_handling import (
BufferProperties,
Expand All @@ -26,8 +26,16 @@
power_grid_core as pgc,
)
from power_grid_model._core.power_grid_meta import ComponentMetaData, DatasetMetaData, power_grid_meta_data
from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, process_data_filter
from power_grid_model.data_types import AttributeType, ComponentData, Dataset
from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, is_sparse, process_data_filter
from power_grid_model.data_types import (
AttributeType,
ColumnarData,
ComponentData,
Dataset,
DenseBatchColumnarData,
SingleColumnarData,
SparseBatchColumnarData,
)
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.typing import ComponentAttributeMapping, _ComponentAttributeMappingDict

Expand Down Expand Up @@ -455,22 +463,25 @@ def _get_buffer_properties(self, info: CDatasetInfo) -> Mapping[ComponentType, B
if component in self._data_filter
}

def _filter_attributes(self, attributes):
def _filter_attributes(self, buffer: ColumnarData):
if is_sparse(buffer):
attributes = cast(SparseBatchColumnarData, buffer)["data"]
else:
attributes = cast(SingleColumnarData | DenseBatchColumnarData, buffer)

keys_to_remove = []
for attr, array in attributes.items():
if is_columnar(array):
continue
if is_nan_or_equivalent(array):
keys_to_remove.append(attr)
for key in keys_to_remove:
del attributes[key]

def _filter_with_mapping(self):
for component_type, attributes in self._data.items():
for component_type, component_buffer in self._data.items():
if component_type in self._data_filter:
filter_option = self._data_filter[component_type]
if filter_option is ComponentAttributeFilterOptions.relevant:
self._filter_attributes(attributes)
if filter_option is ComponentAttributeFilterOptions.relevant and is_columnar(component_buffer):
self._filter_attributes(component_buffer)

def _post_filtering(self):
if isinstance(self._data_filter, dict):
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from power_grid_model import DatasetType
from power_grid_model._utils import get_dataset_type, is_columnar, is_sparse
from power_grid_model.data_types import BatchDataset, Dataset, SingleDataset
from power_grid_model.data_types import BatchDataset, Dataset, DenseBatchData, SingleComponentData, SingleDataset
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.utils import json_deserialize, json_serialize, msgpack_deserialize, msgpack_serialize

Expand Down Expand Up @@ -617,10 +617,10 @@ def assert_serialization_correct(deserialized_dataset: Dataset, serialized_datas
)


def _check_only_relevant_attributes_present(component_values) -> bool:
def _check_only_relevant_attributes_present(component_values: SingleComponentData | DenseBatchData) -> bool:
if isinstance(component_values, np.ndarray):
return True
for array in component_values.values():
if not isinstance(array, np.ndarray):
continue
if (array.dtype == np.float64 and np.isnan(array).all()) or (
array.dtype in (np.int32, np.int8) and np.all(array == np.iinfo(array.dtype).min)
):
Expand All @@ -633,7 +633,8 @@ def assert_deserialization_filtering_correct(deserialized_dataset: Dataset, data
return True
if data_filter is ComponentAttributeFilterOptions.relevant:
for component_values in deserialized_dataset.values():
if not _check_only_relevant_attributes_present(component_values):
buffer = component_values if not is_sparse(component_values) else component_values["data"]
if not _check_only_relevant_attributes_present(buffer):
return False
return True

Expand Down

0 comments on commit 9fc25ef

Please sign in to comment.