From c4fb7bbd36b7f3c61f3303d1a7dbd20478635d0c Mon Sep 17 00:00:00 2001 From: Vincent Koppen Date: Wed, 29 Jan 2025 11:40:14 +0100 Subject: [PATCH] refactor(BaseGraph): improve performance of get_components Signed-off-by: Vincent Koppen --- src/power_grid_model_ds/_core/model/graphs/models/base.py | 3 ++- .../_core/model/graphs/models/rustworkx.py | 7 ++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/power_grid_model_ds/_core/model/graphs/models/base.py b/src/power_grid_model_ds/_core/model/graphs/models/base.py index 4f0a7bd..b1d112b 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/base.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/base.py @@ -250,7 +250,8 @@ def get_all_paths(self, ext_start_node_id: int, ext_end_node_id: int) -> list[li def get_components(self, substation_nodes: NDArray[np.int32]) -> list[list[int]]: """Returns all separate components when the substation_nodes are removed of the graph as lists""" - internal_components = self._get_components(substation_nodes=self._externals_to_internals(substation_nodes)) + with self.tmp_remove_nodes(substation_nodes): + internal_components = self._get_components() return [self._internals_to_externals(component) for component in internal_components] def get_connected( diff --git a/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py b/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py index b509391..39c66e5 100644 --- a/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py +++ b/src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py @@ -84,11 +84,8 @@ def _get_shortest_path(self, source: int, target: int) -> tuple[list[int], int]: def _get_all_paths(self, source: int, target: int) -> list[list[int]]: return list(rx.all_simple_paths(self._graph, source, target)) - def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: - no_os_graph = self._graph.copy() - for os_node in substation_nodes: - no_os_graph.remove_node(os_node) - components = rx.connected_components(no_os_graph) + def _get_components(self) -> list[list[int]]: + components = rx.connected_components(self._graph) return [list(component) for component in components] def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: