Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/dgc 2055 get batch size function #429

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

repos:
- repo: https://github.com/fsfe/reuse-tool
rev: v1.0.0
rev: v2.1.0
hooks:
- id: reuse
- repo: https://github.com/pycqa/isort
Expand Down
21 changes: 20 additions & 1 deletion src/power_grid_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import numpy as np

from power_grid_model.core.power_grid_dataset import get_dataset_type
from power_grid_model._utils import get_and_verify_batch_sizes
Jerry-Jinfeng-Guo marked this conversation as resolved.
Show resolved Hide resolved
from power_grid_model.core.power_grid_dataset import CConstDataset, get_dataset_type
Jerry-Jinfeng-Guo marked this conversation as resolved.
Show resolved Hide resolved
from power_grid_model.core.serialization import ( # pylint: disable=unused-import
json_deserialize,
json_serialize,
Expand Down Expand Up @@ -54,6 +55,24 @@ def _get_component_scenario(component_scenarios: BatchArray) -> np.ndarray:
return {component: _get_component_scenario(component_data) for component, component_data in dataset.items()}


def get_dataset_batch_size(dataset: BatchDataset) -> int:
    """
    Get the number of scenarios in the batch dataset.

    Args:
        dataset: the batch dataset

    Raises:
        ValueError: if the batch dataset is inconsistent.

    Returns:
        The size of the batch dataset.
    """
    # A CConstDataset already knows its own batch size via its info object.
    # For a plain dict-based batch dataset, derive the size from the component
    # arrays (cross-checking that all components agree).
    if not isinstance(dataset, CConstDataset):
        return get_and_verify_batch_sizes(dataset)
    return dataset.get_info().batch_size()


def json_deserialize_from_file(file_path: Path) -> Dataset:
"""
Load and deserialize a JSON file to a new dataset.
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
# SPDX-License-Identifier: MPL-2.0

from pathlib import Path
from typing import Dict
from unittest.mock import MagicMock, mock_open, patch

import numpy as np
import pytest

from power_grid_model import LoadGenType, initialize_array
from power_grid_model.core.power_grid_dataset import CConstDataset
from power_grid_model.core.power_grid_meta import power_grid_meta_data
from power_grid_model.data_types import Dataset
from power_grid_model.utils import (
export_json_data,
get_dataset_batch_size,
get_dataset_scenario,
json_deserialize_from_file,
json_serialize_to_file,
Expand Down Expand Up @@ -41,6 +46,46 @@ def test_get_dataset_scenario():
get_dataset_scenario(data, 2)


@pytest.fixture
def batch_data() -> Dict[str, np.ndarray]:
    """Dense batch update data: 3 scenarios, 2 components per scenario."""
    line_update = initialize_array("update", "line", (3, 2))
    line_update["id"] = [[5, 6], [6, 7], [7, 5]]
    line_update["from_status"] = [[1, 1], [1, 1], [1, 1]]

    # asym_load exercises a component whose p_specified is a 2-D array
    asym_load_update = initialize_array("update", "asym_load", (3, 2))
    asym_load_update["id"] = [[9, 10], [9, 10], [9, 10]]

    return {"line": line_update, "asym_load": asym_load_update}

def test_get_dataset_batch_size(batch_data):
    """The dense batch fixture contains exactly 3 scenarios."""
    expected_batch_size = 3
    assert get_dataset_batch_size(batch_data) == expected_batch_size


def test_get_dataset_batch_size_sparse():
    """Batch size is read from a CConstDataset built from sparse batch data."""
    batch_size = 3
    # Sparse components: a flat "data" buffer plus an "indptr" of
    # batch_size + 1 entries slicing the buffer per scenario.
    # "link" is a dense component whose first axis is the scenario axis.
    update_data = {
        "node": {
            "data": np.zeros(shape=3, dtype=power_grid_meta_data["input"]["node"]),
            "indptr": np.array([0, 2, 3, 3]),
        },
        "sym_load": {
            "data": np.zeros(shape=2, dtype=power_grid_meta_data["input"]["sym_load"]),
            "indptr": np.array([0, 0, 1, 2]),
        },
        "asym_load": {
            "data": np.zeros(shape=4, dtype=power_grid_meta_data["input"]["asym_load"]),
            "indptr": np.array([0, 2, 3, 4]),
        },
        "link": np.zeros(shape=(batch_size, 4), dtype=power_grid_meta_data["input"]["link"]),
    }

    const_dataset = CConstDataset(update_data, ["input"])

    assert get_dataset_batch_size(const_dataset) == batch_size


@patch("builtins.open", new_callable=mock_open)
@patch("power_grid_model.utils.json_deserialize")
def test_json_deserialize_from_file(deserialize_mock: MagicMock, open_mock: MagicMock):
Expand Down