Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: add tmp_remove_nodes method to graph #17

Merged
merged 15 commits into from
Feb 5, 2025
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0
1.1
41 changes: 41 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Generator

import numpy as np
from numpy._typing import NDArray
Expand Down Expand Up @@ -34,6 +36,13 @@ def nr_nodes(self):
def nr_branches(self):
"""Returns the number of branches in the graph"""

@property
@abstractmethod
def all_branches(self) -> list[frozenset[int]]:
"""Returns all branches in the graph as a list of node pairs (frozensets).
Warning: Depending on graph engine, performance could be slow for large graphs
"""

@abstractmethod
def external_to_internal(self, ext_node_id: int) -> int:
"""Convert external node id to internal node id (internal)
Expand Down Expand Up @@ -164,6 +173,31 @@ def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool =
branches = _get_branch3_branches(branch3)
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)

@contextmanager
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
"""Context manager that temporarily removes nodes and their branches from the graph.
Example:
>>> with graph.tmp_remove_nodes([1, 2, 3]):
>>> assert not graph.has_node(1)
>>> assert graph.has_node(1)
In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without
considering certain nodes.
Thijss marked this conversation as resolved.
Show resolved Hide resolved
"""
edge_list = []
for node in nodes:
internal_node = self.external_to_internal(node)
node_edges = [
(self.internal_to_external(source), self.internal_to_external(target))
for source, target in self._in_edges(internal_node)
]
edge_list += node_edges
self._delete_node(internal_node)
yield edge_list
Thijss marked this conversation as resolved.
Show resolved Hide resolved
for node in nodes:
self.add_node(node)
for source, target in edge_list:
self.add_branch(source, target)

def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
"""Calculate the shortest path between two nodes

Expand Down Expand Up @@ -270,6 +304,13 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
return branch.is_active.item()
return True

@abstractmethod
def _in_edges(self, internal_node: int) -> list[tuple[int, int]]:
"""Return all edges a node occurs in.

Return a list of tuples with the source and target node id. These are internal node ids.
"""
vincentkoppen marked this conversation as resolved.
Show resolved Hide resolved

@abstractmethod
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

Expand Down
12 changes: 12 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def nr_nodes(self):
def nr_branches(self):
return self._graph.num_edges()

@property
def all_branches(self) -> list[frozenset[int]]:
internal_branches = ((source, target) for source, target in self._graph.edge_list())
external_branches = [
frozenset([self.internal_to_external(source), self.internal_to_external(target)])
for source, target in internal_branches
]
return external_branches

@property
def external_ids(self) -> list[int]:
return list(self._external_to_internal.keys())
Expand Down Expand Up @@ -99,6 +108,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo

return connected_nodes

def _in_edges(self, internal_node: int) -> list[tuple[int, int]]:
return [(source, target) for source, target, _ in self._graph.in_edges(internal_node)]

def _find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph using Rustworkx.

Expand Down
53 changes: 53 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

"""Grid tests"""

from collections import Counter

import numpy as np
import pytest
from numpy.testing import assert_array_equal
Expand All @@ -27,6 +29,34 @@ def test_graph_initialize(graph):
assert 1 == graph.nr_branches


def test_graph_has_branch(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)

assert graph.has_branch(1, 2)
assert graph.has_branch(2, 1) # reversed should work too
assert not graph.has_branch(1, 3)


def test_graph_all_branches(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)

assert [{1, 2}] == graph.all_branches


def test_graph_all_branches_parallel(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)
graph.add_branch(1, 2)
graph.add_branch(2, 1)

assert [{1, 2}, {1, 2}, {1, 2}] == graph.all_branches


def test_graph_delete_branch(graph):
"""Test whether a branch is deleted correctly"""
graph.add_node(1)
Expand Down Expand Up @@ -310,3 +340,26 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes):
connected_nodes = graph.get_connected(node_id=1, nodes_to_ignore=[2, 4])

assert {5} == set(connected_nodes)


def test_tmp_remove_nodes(graph_with_2_routes) -> None:
graph = graph_with_2_routes

assert graph.nr_branches == 4

# add parallel branches to test whether they are restored correctly
graph.add_branch(1, 5)
graph.add_branch(5, 1)

assert graph.nr_nodes == 5
assert graph.nr_branches == 6
counter_before: Counter[frozenset] = Counter(graph.all_branches)

with graph.tmp_remove_nodes([1, 2]):
assert graph.nr_nodes == 3
assert graph.all_branches == [{4, 5}]

assert graph.nr_nodes == 5
assert graph.nr_branches == 6
counter_after: Counter[frozenset] = Counter(graph.all_branches)
assert counter_before == counter_after
Loading