Skip to content

Commit

Permalink
Feature: add .all_branches property to graph (#16)
Browse files Browse the repository at this point in the history
* Add .all_branches property to BaseGraphModel

Signed-off-by: Thijs Baaijen <[email protected]>

* Apply suggestions from code review

Co-authored-by: Vincent Koppen <[email protected]>
Signed-off-by: Thijs Baaijen <[email protected]>

---------

Signed-off-by: Thijs Baaijen <[email protected]>
Co-authored-by: Vincent Koppen <[email protected]>
  • Loading branch information
Thijss and vincentkoppen authored Feb 5, 2025
1 parent d28aba7 commit d61f85d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0
1.1
12 changes: 12 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod
from typing import Generator

import numpy as np
from numpy._typing import NDArray
Expand Down Expand Up @@ -34,6 +35,14 @@ def nr_nodes(self):
def nr_branches(self):
"""Returns the number of branches in the graph"""

@property
def all_branches(self) -> Generator[tuple[int, int], None, None]:
"""Returns all branches in the graph."""
return (
(self.internal_to_external(source), self.internal_to_external(target))
for source, target in self._all_branches()
)

@abstractmethod
def external_to_internal(self, ext_node_id: int) -> int:
"""Convert external node id to internal node id (internal)
Expand Down Expand Up @@ -307,6 +316,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
@abstractmethod
def _find_fundamental_cycles(self) -> list[list[int]]: ...

@abstractmethod
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...


def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
node_1 = branch3.node_1.item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MPL-2.0

import logging
from typing import Generator

import rustworkx as rx
from rustworkx import NoEdgeBetweenNodes
Expand Down Expand Up @@ -107,6 +108,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
"""
return find_fundamental_cycles_rustworkx(self._graph)

def _all_branches(self) -> Generator[tuple[int, int], None, None]:
return ((source, target) for source, target in self._graph.edge_list())


class _NodeVisitor(BFSVisitor):
def __init__(self, nodes_to_ignore: list[int]):
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ def test_graph_has_branch(graph):
assert not graph.has_branch(1, 3)


def test_graph_all_branches(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)

assert [(1, 2)] == list(graph.all_branches)


def test_graph_all_branches_parallel(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)
graph.add_branch(1, 2)
graph.add_branch(2, 1)

assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)


def test_graph_delete_branch(graph):
"""Test whether a branch is deleted correctly"""
graph.add_node(1)
Expand Down

0 comments on commit d61f85d

Please sign in to comment.