diff --git a/VERSION b/VERSION index 9f8e9b6..b123147 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0 \ No newline at end of file +1.1 \ No newline at end of file diff --git a/src/power_grid_model_ds/_core/model/graphs/models/base.py b/src/power_grid_model_ds/_core/model/graphs/models/base.py index 2ce1b88..612e556 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/base.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/base.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MPL-2.0 from abc import ABC, abstractmethod +from typing import Generator import numpy as np from numpy._typing import NDArray @@ -34,6 +35,14 @@ def nr_nodes(self): def nr_branches(self): """Returns the number of branches in the graph""" + @property + def all_branches(self) -> Generator[tuple[int, int], None, None]: + """Returns all branches in the graph.""" + return ( + (self.internal_to_external(source), self.internal_to_external(target)) + for source, target in self._all_branches() + ) + @abstractmethod def external_to_internal(self, ext_node_id: int) -> int: """Convert external node id to internal node id (internal) @@ -307,6 +316,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ... @abstractmethod def _find_fundamental_cycles(self) -> list[list[int]]: ... + @abstractmethod + def _all_branches(self) -> Generator[tuple[int, int], None, None]: ... + def _get_branch3_branches(branch3: Branch3Array) -> BranchArray: node_1 = branch3.node_1.item() diff --git a/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py b/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py index 6e1b7cd..ee2cf46 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MPL-2.0 import logging +from typing import Generator import rustworkx as rx from rustworkx import NoEdgeBetweenNodes @@ -107,6 +108,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]: """ return find_fundamental_cycles_rustworkx(self._graph) + def _all_branches(self) -> Generator[tuple[int, int], None, None]: + return ((source, target) for source, target in self._graph.edge_list()) + class _NodeVisitor(BFSVisitor): def __init__(self, nodes_to_ignore: list[int]): diff --git a/tests/unit/model/graphs/test_graph_model.py b/tests/unit/model/graphs/test_graph_model.py index 628c77b..15e250e 100644 --- a/tests/unit/model/graphs/test_graph_model.py +++ b/tests/unit/model/graphs/test_graph_model.py @@ -37,6 +37,24 @@ def test_graph_has_branch(graph): assert not graph.has_branch(1, 3) +def test_graph_all_branches(graph): + graph.add_node(1) + graph.add_node(2) + graph.add_branch(1, 2) + + assert [(1, 2)] == list(graph.all_branches) + + +def test_graph_all_branches_parallel(graph): + graph.add_node(1) + graph.add_node(2) + graph.add_branch(1, 2) + graph.add_branch(1, 2) + graph.add_branch(2, 1) + + assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches) + + def test_graph_delete_branch(graph): """Test whether a branch is deleted correctly""" graph.add_node(1)