Skip to content

Commit

Permalink
Merge branch 'feat/tmp-remove-nodes' of https://github.com/PowerGridM…
Browse files Browse the repository at this point in the history
…odel/power-grid-model-ds into feat/improve_get_components
  • Loading branch information
vincentkoppen committed Jan 29, 2025
2 parents 76fadb3 + c223e83 commit a732914
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 1 deletion.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0
1.1
46 changes: 46 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Generator

import numpy as np
from numpy._typing import NDArray
Expand Down Expand Up @@ -34,6 +36,14 @@ def nr_nodes(self):
def nr_branches(self):
"""Returns the number of branches in the graph"""

@property
def all_branches(self) -> Generator[tuple[int, int], None, None]:
"""Returns all branches in the graph."""
return (
(self.internal_to_external(source), self.internal_to_external(target))
for source, target in self._all_branches()
)

@abstractmethod
def external_to_internal(self, ext_node_id: int) -> int:
"""Convert external node id to internal node id (internal)
Expand Down Expand Up @@ -63,6 +73,14 @@ def has_node(self, node_id: int) -> bool:

return self._has_node(node_id=internal_node_id)

def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]:
"""Return all branches that have the node as an endpoint."""
int_node_id = self.external_to_internal(node_id)
internal_edges = self._in_branches(int_node_id=int_node_id)
return (
(self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges
)

def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
"""Add a node to the graph."""
if self.has_node(ext_node_id):
Expand Down Expand Up @@ -164,6 +182,28 @@ def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool =
branches = _get_branch3_branches(branch3)
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)

@contextmanager
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
"""Context manager that temporarily removes nodes and their branches from the graph.
Example:
>>> with graph.tmp_remove_nodes([1, 2, 3]):
>>> assert not graph.has_node(1)
>>> assert graph.has_node(1)
In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without
considering certain nodes.
"""
edge_list = []
for node in nodes:
edge_list += list(self.in_branches(node))
self.delete_node(node)

yield

for node in nodes:
self.add_node(node)
for source, target in edge_list:
self.add_branch(source, target)

def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
"""Calculate the shortest path between two nodes
Expand Down Expand Up @@ -270,6 +310,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
return branch.is_active.item()
return True

@abstractmethod
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ...

@abstractmethod
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

Expand Down Expand Up @@ -307,6 +350,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
@abstractmethod
def _find_fundamental_cycles(self) -> list[list[int]]: ...

@abstractmethod
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...


def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
node_1 = branch3.node_1.item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MPL-2.0

import logging
from typing import Generator

import rustworkx as rx
from rustworkx import NoEdgeBetweenNodes
Expand Down Expand Up @@ -99,6 +100,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo

return connected_nodes

def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id))

def _find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph using Rustworkx.
Expand All @@ -107,6 +111,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
"""
return find_fundamental_cycles_rustworkx(self._graph)

def _all_branches(self) -> Generator[tuple[int, int], None, None]:
return ((source, target) for source, target in self._graph.edge_list())


class _NodeVisitor(BFSVisitor):
def __init__(self, nodes_to_ignore: list[int]):
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

"""Grid tests"""

from collections import Counter

import numpy as np
import pytest
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -37,6 +39,35 @@ def test_graph_has_branch(graph):
assert not graph.has_branch(1, 3)


def test_graph_all_branches(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)

assert [(1, 2)] == list(graph.all_branches)


def test_graph_all_branches_parallel(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)
graph.add_branch(1, 2)
graph.add_branch(2, 1)

assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)


def test_graph_in_branches(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)
graph.add_branch(1, 2)
graph.add_branch(2, 1)

assert [(2, 1), (2, 1), (2, 1)] == list(graph.in_branches(1))
assert [(1, 2), (1, 2), (1, 2)] == list(graph.in_branches(2))


def test_graph_delete_branch(graph):
"""Test whether a branch is deleted correctly"""
graph.add_node(1)
Expand Down Expand Up @@ -320,3 +351,30 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes):
connected_nodes = graph.get_connected(node_id=1, nodes_to_ignore=[2, 4])

assert {5} == set(connected_nodes)


def test_tmp_remove_nodes(graph_with_2_routes) -> None:
graph = graph_with_2_routes

assert graph.nr_branches == 4

# add parallel branches to test whether they are restored correctly
graph.add_branch(1, 5)
graph.add_branch(5, 1)

assert graph.nr_nodes == 5
assert graph.nr_branches == 6

before_sets = [frozenset(branch) for branch in graph.all_branches]
counter_before = Counter(before_sets)

with graph.tmp_remove_nodes([1, 2]):
assert graph.nr_nodes == 3
assert list(graph.all_branches) == [(5, 4)]

assert graph.nr_nodes == 5
assert graph.nr_branches == 6

after_sets = [frozenset(branch) for branch in graph.all_branches]
counter_after = Counter(after_sets)
assert counter_before == counter_after

0 comments on commit a732914

Please sign in to comment.