From b46a6f56497cfd88c53d3a13de30024b4a5629ed Mon Sep 17 00:00:00 2001 From: Martijn Govers Date: Wed, 22 Jan 2025 16:35:17 +0100 Subject: [PATCH] add type hinting to prevent future similar issues Signed-off-by: Martijn Govers --- .../_core/power_grid_dataset.py | 19 ++++++++++++++----- tests/unit/test_serialization.py | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/power_grid_model/_core/power_grid_dataset.py b/src/power_grid_model/_core/power_grid_dataset.py index ad0724b9e..28a5549f9 100644 --- a/src/power_grid_model/_core/power_grid_dataset.py +++ b/src/power_grid_model/_core/power_grid_dataset.py @@ -6,7 +6,7 @@ Power grid model raw dataset handler """ -from typing import Any, Mapping +from typing import Any, Mapping, cast from power_grid_model._core.buffer_handling import ( BufferProperties, @@ -27,7 +27,15 @@ ) from power_grid_model._core.power_grid_meta import ComponentMetaData, DatasetMetaData, power_grid_meta_data from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, is_sparse, process_data_filter -from power_grid_model.data_types import AttributeType, ComponentData, Dataset +from power_grid_model.data_types import ( + AttributeType, + ColumnarData, + ComponentData, + Dataset, + DenseBatchColumnarData, + SingleColumnarData, + SparseBatchColumnarData, +) from power_grid_model.enum import ComponentAttributeFilterOptions from power_grid_model.typing import ComponentAttributeMapping, _ComponentAttributeMappingDict @@ -455,11 +463,12 @@ def _get_buffer_properties(self, info: CDatasetInfo) -> Mapping[ComponentType, B if component in self._data_filter } - def _filter_attributes(self, buffer): + def _filter_attributes(self, buffer: ColumnarData): if is_sparse(buffer): - attributes = buffer["data"] + attributes = cast(SparseBatchColumnarData, buffer)["data"] else: - attributes = buffer + attributes = cast(SingleColumnarData | DenseBatchColumnarData, buffer) + keys_to_remove = [] for attr, array in attributes.items(): if is_nan_or_equivalent(array): diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index bf7ae7867..a7cc2663a 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -11,7 +11,7 @@ from power_grid_model import DatasetType from power_grid_model._utils import get_dataset_type, is_columnar, is_sparse -from power_grid_model.data_types import BatchDataset, Dataset, SingleDataset +from power_grid_model.data_types import BatchDataset, Dataset, DenseBatchData, SingleComponentData, SingleDataset from power_grid_model.enum import ComponentAttributeFilterOptions from power_grid_model.utils import json_deserialize, json_serialize, msgpack_deserialize, msgpack_serialize @@ -617,7 +617,7 @@ def assert_serialization_correct(deserialized_dataset: Dataset, serialized_datas ) -def _check_only_relevant_attributes_present(component_values) -> bool: +def _check_only_relevant_attributes_present(component_values: SingleComponentData | DenseBatchData) -> bool: if isinstance(component_values, np.ndarray): return True for array in component_values.values():