diff --git a/VERSION b/VERSION index 9f8e9b6..b123147 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0 \ No newline at end of file +1.1 \ No newline at end of file diff --git a/src/power_grid_model_ds/_core/model/graphs/models/base.py b/src/power_grid_model_ds/_core/model/graphs/models/base.py index 2a4a577..4f0a7bd 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/base.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/base.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: MPL-2.0 from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Generator import numpy as np from numpy._typing import NDArray @@ -34,6 +36,14 @@ def nr_nodes(self): def nr_branches(self): """Returns the number of branches in the graph""" + @property + def all_branches(self) -> Generator[tuple[int, int], None, None]: + """Returns all branches in the graph.""" + return ( + (self.internal_to_external(source), self.internal_to_external(target)) + for source, target in self._all_branches() + ) + @abstractmethod def external_to_internal(self, ext_node_id: int) -> int: """Convert external node id to internal node id (internal) @@ -63,6 +73,14 @@ def has_node(self, node_id: int) -> bool: return self._has_node(node_id=internal_node_id) + def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]: + """Return all branches that have the node as an endpoint.""" + int_node_id = self.external_to_internal(node_id) + internal_edges = self._in_branches(int_node_id=int_node_id) + return ( + (self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges + ) + def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None: """Add a node to the graph.""" if self.has_node(ext_node_id): @@ -164,6 +182,28 @@ def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool = branches = _get_branch3_branches(branch3) self.delete_branch_array(branches, raise_on_fail=raise_on_fail) + @contextmanager + def tmp_remove_nodes(self, nodes: list[int]) -> Generator: + """Context manager that temporarily removes nodes and their branches from the graph. + Example: + >>> with graph.tmp_remove_nodes([1, 2, 3]): + >>> assert not graph.has_node(1) + >>> assert graph.has_node(1) + In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without + considering certain nodes. + """ + edge_list = [] + for node in nodes: + edge_list += list(self.in_branches(node)) + self.delete_node(node) + + yield + + for node in nodes: + self.add_node(node) + for source, target in edge_list: + self.add_branch(source, target) + def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]: """Calculate the shortest path between two nodes @@ -270,6 +310,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool: return branch.is_active.item() return True + @abstractmethod + def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ... + @abstractmethod def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ... @@ -307,6 +350,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ... @abstractmethod def _find_fundamental_cycles(self) -> list[list[int]]: ... + @abstractmethod + def _all_branches(self) -> Generator[tuple[int, int], None, None]: ... + def _get_branch3_branches(branch3: Branch3Array) -> BranchArray: node_1 = branch3.node_1.item() diff --git a/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py b/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py index 6e1b7cd..b509391 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MPL-2.0 import logging +from typing import Generator import rustworkx as rx from rustworkx import NoEdgeBetweenNodes @@ -99,6 +100,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo return connected_nodes + def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: + return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id)) + def _find_fundamental_cycles(self) -> list[list[int]]: """Find all fundamental cycles in the graph using Rustworkx. @@ -107,6 +111,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]: """ return find_fundamental_cycles_rustworkx(self._graph) + def _all_branches(self) -> Generator[tuple[int, int], None, None]: + return ((source, target) for source, target in self._graph.edge_list()) + class _NodeVisitor(BFSVisitor): def __init__(self, nodes_to_ignore: list[int]): diff --git a/tests/unit/model/graphs/test_graph_model.py b/tests/unit/model/graphs/test_graph_model.py index 628c77b..ca9e137 100644 --- a/tests/unit/model/graphs/test_graph_model.py +++ b/tests/unit/model/graphs/test_graph_model.py @@ -4,6 +4,8 @@ """Grid tests""" +from collections import Counter + import numpy as np import pytest from numpy.testing import assert_array_equal @@ -37,6 +39,35 @@ def test_graph_has_branch(graph): assert not graph.has_branch(1, 3) +def test_graph_all_branches(graph): + graph.add_node(1) + graph.add_node(2) + graph.add_branch(1, 2) + + assert [(1, 2)] == list(graph.all_branches) + + +def test_graph_all_branches_parallel(graph): + graph.add_node(1) + graph.add_node(2) + graph.add_branch(1, 2) + graph.add_branch(1, 2) + graph.add_branch(2, 1) + + assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches) + + +def test_graph_in_branches(graph): + graph.add_node(1) + graph.add_node(2) + graph.add_branch(1, 2) + graph.add_branch(1, 2) + graph.add_branch(2, 1) + + assert [(2, 1), (2, 1), (2, 1)] == list(graph.in_branches(1)) + assert [(1, 2), (1, 2), (1, 2)] == list(graph.in_branches(2)) + + def test_graph_delete_branch(graph): """Test whether a branch is deleted correctly""" graph.add_node(1) @@ -320,3 +351,30 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes): connected_nodes = graph.get_connected(node_id=1, nodes_to_ignore=[2, 4]) assert {5} == set(connected_nodes) + + +def test_tmp_remove_nodes(graph_with_2_routes) -> None: + graph = graph_with_2_routes + + assert graph.nr_branches == 4 + + # add parallel branches to test whether they are restored correctly + graph.add_branch(1, 5) + graph.add_branch(5, 1) + + assert graph.nr_nodes == 5 + assert graph.nr_branches == 6 + + before_sets = [frozenset(branch) for branch in graph.all_branches] + counter_before = Counter(before_sets) + + with graph.tmp_remove_nodes([1, 2]): + assert graph.nr_nodes == 3 + assert list(graph.all_branches) == [(5, 4)] + + assert graph.nr_nodes == 5 + assert graph.nr_branches == 6 + + after_sets = [frozenset(branch) for branch in graph.all_branches] + counter_after = Counter(after_sets) + assert counter_before == counter_after