diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py index 5bee1b8e..64003ff7 100644 --- a/src/fromager/resolver.py +++ b/src/fromager/resolver.py @@ -233,7 +233,9 @@ def get_preference( def is_satisfied_by(self, requirement: Requirement, candidate: Candidate) -> bool: if canonicalize_name(requirement.name) != candidate.name: return False - allow_prerelease = self.constraints.allow_prerelease(requirement.name) + allow_prerelease = self.constraints.allow_prerelease(requirement.name) or bool( + requirement.specifier.prereleases + ) return requirement.specifier.contains( candidate.version, prereleases=allow_prerelease ) and self.constraints.is_satisfied_by(requirement.name, candidate.version) @@ -291,7 +293,10 @@ def find_matches( continue # Skip versions that do not match the requirement. Allow prereleases only if constraints allow prereleases if not all( - r.specifier.contains(candidate.version, prereleases=allow_prerelease) + r.specifier.contains( + candidate.version, + prereleases=(allow_prerelease or bool(r.specifier.prereleases)), + ) for r in identifier_reqs ): if DEBUG_RESOLVER: @@ -372,7 +377,10 @@ def find_matches( continue # Skip versions that do not match the requirement if not all( - r.specifier.contains(version, prereleases=allow_prerelease) + r.specifier.contains( + version, + prereleases=(allow_prerelease or bool(r.specifier.prereleases)), + ) for r in identifier_reqs ): if DEBUG_RESOLVER: diff --git a/tests/test_resolver.py b/tests/test_resolver.py index da2f8106..ef61b9f4 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -23,6 +23,8 @@ hydra_core-1.3.2-1-py3-none-any.whl
hydra_core-1.3.2-2-py3-none-any.whl +
+hydra_core-2.0.0a1-py3-none-any.whl @@ -51,6 +53,28 @@ def test_provider_choose_wheel(): assert str(candidate.version) == "1.3.2" +def test_provider_choose_wheel_prereleases(): + with requests_mock.Mocker() as r: + r.get( + "https://pypi.org/simple/hydra-core/", + text=_hydra_core_simple_response, + ) + + provider = resolver.PyPIProvider(include_sdists=False) + reporter = resolvelib.BaseReporter() + rslvr = resolvelib.Resolver(provider, reporter) + + result = rslvr.resolve([Requirement("hydra-core==2.0.0a1")]) + assert "hydra-core" in result.mapping + + candidate = result.mapping["hydra-core"] + assert ( + candidate.url + == "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-2.0.0a1-py3-none-any.whl" + ) + assert str(candidate.version) == "2.0.0a1" + + def test_provider_choose_sdist(): with requests_mock.Mocker() as r: r.get(