From 15bc9e52d29580d876eba6caa5778a96365e2238 Mon Sep 17 00:00:00 2001 From: Shubh Bapna Date: Thu, 5 Dec 2024 16:15:41 -0500 Subject: [PATCH] fix cycles in when traversing graph due to cycles graph traversal would cause maximum recursive depth reached error Signed-off-by: Shubh Bapna --- src/fromager/dependency_graph.py | 11 ++++++--- tests/test_graph.py | 38 ++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/fromager/dependency_graph.py b/src/fromager/dependency_graph.py index bb9d36b9..08fa4558 100644 --- a/src/fromager/dependency_graph.py +++ b/src/fromager/dependency_graph.py @@ -241,6 +241,7 @@ def get_dependency_edges( visited = set() for edge in self._depth_first_traversal( self.nodes[ROOT].children, + set(), match_dep_types=match_dep_types, ): if edge.destination_node.key not in visited: @@ -278,13 +279,17 @@ def get_install_dependency_versions( def _depth_first_traversal( self, - start_node: list[DependencyEdge], + start_edges: list[DependencyEdge], + visited: set[str], match_dep_types: list[RequirementType] | None = None, ) -> typing.Iterable[DependencyEdge]: - for edge in start_node: + for edge in start_edges: + if edge.destination_node.key in visited: + continue if match_dep_types and edge.req_type not in match_dep_types: continue + visited.add(edge.destination_node.key) yield edge yield from self._depth_first_traversal( - edge.destination_node.children, match_dep_types + edge.destination_node.children, visited, match_dep_types ) diff --git a/tests/test_graph.py b/tests/test_graph.py index 71d84de6..96889cd4 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -143,3 +143,41 @@ def test_get_install_dependencies(): for node in graph.get_install_dependencies() ] assert install_nodes == ["a==2.0", "d==6.0", "b==3.0", "e==6.0"] + + +def test_cycles_get_install_dependencies(): + graph = dependency_graph.DependencyGraph.from_dict(raw_graph) + # create cycle: a depends on d and d depends on a + graph.add_dependency( + parent_name=canonicalize_name("a"), + parent_version=Version("2.0"), + req_type=requirements_file.RequirementType.INSTALL, + req=Requirement("d>=4.0"), + req_version=Version("6.0"), + download_url="url for d", + ) + + graph.add_dependency( + parent_name=canonicalize_name("d"), + parent_version=Version("6.0"), + req_type=requirements_file.RequirementType.INSTALL, + req=Requirement("a<=2.0"), + req_version=Version("2.0"), + download_url="url for a", + ) + + # add another duplicate toplevel + graph.add_dependency( + parent_name=None, + parent_version=None, + req_type=requirements_file.RequirementType.TOP_LEVEL, + req=Requirement("a<=2.0"), + req_version=Version("2.0"), + download_url="url for a", + ) + + install_nodes = [ + f"{node.to_dict()['canonicalized_name']}=={node.to_dict()['version']}" + for node in graph.get_install_dependencies() + ] + assert install_nodes == ["a==2.0", "d==6.0"]