Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: add tmp_remove_nodes method to graph #17

Open
wants to merge 14 commits into
base: release1_1
Choose a base branch
from
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.0
1.1
50 changes: 50 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Generator

import numpy as np
from numpy._typing import NDArray
Expand Down Expand Up @@ -34,6 +36,14 @@ def nr_nodes(self):
def nr_branches(self):
"""Returns the number of branches in the graph"""

@property
def all_branches(self) -> Generator[tuple[int, int], None, None]:
"""Returns all branches in the graph."""
return (
(self.internal_to_external(source), self.internal_to_external(target))
for source, target in self._all_branches()
)

@abstractmethod
def external_to_internal(self, ext_node_id: int) -> int:
"""Convert external node id to internal node id (internal)
Expand Down Expand Up @@ -63,6 +73,14 @@ def has_node(self, node_id: int) -> bool:

return self._has_node(node_id=internal_node_id)

def in_edges(self, node_id: int) -> Generator[tuple[int, int], None, None]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps it is better to call it in_branches?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think in_branches is better. More consistent with the other public functions.

"""Return all edges a node occurs in."""
Thijss marked this conversation as resolved.
Show resolved Hide resolved
int_node_id = self.external_to_internal(node_id)
internal_edges = self._in_edges(int_node_id=int_node_id)
return (
(self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges
)

def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
"""Add a node to the graph."""
if self.has_node(ext_node_id):
Expand Down Expand Up @@ -164,6 +182,28 @@ def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool =
branches = _get_branch3_branches(branch3)
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)

@contextmanager
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
"""Context manager that temporarily removes nodes and their branches from the graph.
Example:
>>> with graph.tmp_remove_nodes([1, 2, 3]):
>>> assert not graph.has_node(1)
>>> assert graph.has_node(1)
In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without
considering certain nodes.
Thijss marked this conversation as resolved.
Show resolved Hide resolved
"""
edge_list = []
for node in nodes:
internal_node = self.external_to_internal(node)
edge_list += list(self.in_edges(node))
self._delete_node(internal_node)
Thijss marked this conversation as resolved.
Show resolved Hide resolved
yield

for node in nodes:
self.add_node(node)
for source, target in edge_list:
self.add_branch(source, target)

def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
"""Calculate the shortest path between two nodes

Expand Down Expand Up @@ -270,6 +310,13 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
return branch.is_active.item()
return True

@abstractmethod
def _in_edges(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
"""Return all edges a node occurs in.
Return a list of tuples with the source and target node id.
These are internal node ids.
"""

@abstractmethod
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

Expand Down Expand Up @@ -307,6 +354,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
@abstractmethod
def _find_fundamental_cycles(self) -> list[list[int]]: ...

@abstractmethod
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...


def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
node_1 = branch3.node_1.item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MPL-2.0

import logging
from typing import Generator

import rustworkx as rx
from rustworkx import NoEdgeBetweenNodes
Expand Down Expand Up @@ -99,6 +100,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo

return connected_nodes

def _in_edges(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id))

def _find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph using Rustworkx.

Expand All @@ -107,6 +111,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
"""
return find_fundamental_cycles_rustworkx(self._graph)

def _all_branches(self) -> Generator[tuple[int, int], None, None]:
return ((source, target) for source, target in self._graph.edge_list())


class _NodeVisitor(BFSVisitor):
def __init__(self, nodes_to_ignore: list[int]):
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

"""Grid tests"""

from collections import Counter

import numpy as np
import pytest
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -37,6 +39,35 @@ def test_graph_has_branch(graph):
assert not graph.has_branch(1, 3)


def test_graph_all_branches(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)

assert [(1, 2)] == list(graph.all_branches)


def test_graph_all_branches_parallel(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)
graph.add_branch(1, 2)
graph.add_branch(2, 1)

assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)


def test_graph_in_edges(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)
graph.add_branch(1, 2)
graph.add_branch(2, 1)

assert [(2, 1), (2, 1), (2, 1)] == list(graph.in_edges(1))
assert [(1, 2), (1, 2), (1, 2)] == list(graph.in_edges(2))


def test_graph_delete_branch(graph):
"""Test whether a branch is deleted correctly"""
graph.add_node(1)
Expand Down Expand Up @@ -320,3 +351,30 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes):
connected_nodes = graph.get_connected(node_id=1, nodes_to_ignore=[2, 4])

assert {5} == set(connected_nodes)


def test_tmp_remove_nodes(graph_with_2_routes) -> None:
graph = graph_with_2_routes

assert graph.nr_branches == 4

# add parallel branches to test whether they are restored correctly
graph.add_branch(1, 5)
graph.add_branch(5, 1)

assert graph.nr_nodes == 5
assert graph.nr_branches == 6

before_sets = [frozenset(branch) for branch in graph.all_branches]
vincentkoppen marked this conversation as resolved.
Show resolved Hide resolved
counter_before = Counter(before_sets)

with graph.tmp_remove_nodes([1, 2]):
assert graph.nr_nodes == 3
assert list(graph.all_branches) == [(5, 4)]

assert graph.nr_nodes == 5
assert graph.nr_branches == 6

after_sets = [frozenset(branch) for branch in graph.all_branches]
counter_after = Counter(after_sets)
assert counter_before == counter_after
Loading