Skip to content

Commit

Permalink
add type hinting to prevent future similar issues
Browse files Browse the repository at this point in the history
Signed-off-by: Martijn Govers <[email protected]>
  • Loading branch information
mgovers committed Jan 22, 2025
1 parent 78b5783 commit b46a6f5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
19 changes: 14 additions & 5 deletions src/power_grid_model/_core/power_grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Power grid model raw dataset handler
"""

from typing import Any, Mapping
from typing import Any, Mapping, cast

from power_grid_model._core.buffer_handling import (
BufferProperties,
Expand All @@ -27,7 +27,15 @@
)
from power_grid_model._core.power_grid_meta import ComponentMetaData, DatasetMetaData, power_grid_meta_data
from power_grid_model._utils import get_dataset_type, is_columnar, is_nan_or_equivalent, is_sparse, process_data_filter
from power_grid_model.data_types import AttributeType, ComponentData, Dataset
from power_grid_model.data_types import (
AttributeType,
ColumnarData,
ComponentData,
Dataset,
DenseBatchColumnarData,
SingleColumnarData,
SparseBatchColumnarData,
)
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.typing import ComponentAttributeMapping, _ComponentAttributeMappingDict

Expand Down Expand Up @@ -455,11 +463,12 @@ def _get_buffer_properties(self, info: CDatasetInfo) -> Mapping[ComponentType, B
if component in self._data_filter
}

def _filter_attributes(self, buffer):
def _filter_attributes(self, buffer: ColumnarData):
if is_sparse(buffer):
attributes = buffer["data"]
attributes = cast(SparseBatchColumnarData, buffer)["data"]
else:
attributes = buffer
attributes = cast(SingleColumnarData | DenseBatchColumnarData, buffer)

keys_to_remove = []
for attr, array in attributes.items():
if is_nan_or_equivalent(array):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from power_grid_model import DatasetType
from power_grid_model._utils import get_dataset_type, is_columnar, is_sparse
from power_grid_model.data_types import BatchDataset, Dataset, SingleDataset
from power_grid_model.data_types import BatchDataset, Dataset, DenseBatchData, SingleComponentData, SingleDataset
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.utils import json_deserialize, json_serialize, msgpack_deserialize, msgpack_serialize

Expand Down Expand Up @@ -617,7 +617,7 @@ def assert_serialization_correct(deserialized_dataset: Dataset, serialized_datas
)


def _check_only_relevant_attributes_present(component_values) -> bool:
def _check_only_relevant_attributes_present(component_values: SingleComponentData | DenseBatchData) -> bool:
if isinstance(component_values, np.ndarray):
return True
for array in component_values.values():
Expand Down

0 comments on commit b46a6f5

Please sign in to comment.