Skip to content

Commit

Permalink
Merge pull request #537 from tiran/refactor-constraints
Browse files Browse the repository at this point in the history
Refactor constraints module
  • Loading branch information
mergify[bot] authored Jan 23, 2025
2 parents c84ee7a + a3cf190 commit d7f5395
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 44 deletions.
48 changes: 28 additions & 20 deletions src/fromager/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing

from packaging.requirements import Requirement
from packaging.utils import canonicalize_name
from packaging.utils import NormalizedName, canonicalize_name
from packaging.version import Version

from . import requirements_file
Expand All @@ -12,8 +12,33 @@


class Constraints:
def __init__(self, data: dict[str, Requirement]):
self._data = {canonicalize_name(n): v for n, v in data.items()}
def __init__(self) -> None:
# mapping of canonical names to requirements
# NOTE: Requirement.name is not normalized
self._data: dict[NormalizedName, Requirement] = {}

def __iter__(self) -> typing.Iterable[NormalizedName]:
yield from self._data

def add_constraint(self, unparsed: str) -> None:
"""Add new constraint, must not conflict with any existing constraints"""
req = Requirement(unparsed)
canon_name = canonicalize_name(req.name)
previous = self._data.get(canon_name)
if previous is not None:
raise KeyError(
f"{canon_name}: new constraint '{req}' conflicts with '{previous}'"
)
if requirements_file.evaluate_marker(req, req):
logger.debug(f"adding constraint {req}")
self._data[canon_name] = req

def load_constraints_file(self, constraints_file: str | pathlib.Path) -> None:
"""Load constraints from a constraints file"""
logger.info("loading constraints from %s", constraints_file)
content = requirements_file.parse_requirements_file(constraints_file)
for line in content:
self.add_constraint(line)

def get_constraint(self, name: str) -> Requirement | None:
return self._data.get(canonicalize_name(name))
Expand All @@ -29,20 +54,3 @@ def is_satisfied_by(self, pkg_name: str, version: Version) -> bool:
if constraint:
return constraint.specifier.contains(version, prereleases=True)
return True


def _parse(content: typing.Iterable[str]) -> Constraints:
constraints = {}
for line in content:
req = Requirement(line)
if requirements_file.evaluate_marker(req, req):
constraints[req.name] = req
return Constraints(constraints)


def load(constraints_file: str | pathlib.Path | None) -> Constraints:
if not constraints_file:
return Constraints({})
logger.info("loading constraints from %s", constraints_file)
parsed_req_file = requirements_file.parse_requirements_file(constraints_file)
return _parse(parsed_req_file)
4 changes: 2 additions & 2 deletions src/fromager/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def __init__(
)
self.settings = active_settings
self.input_constraints_uri: str | None
self.constraints = constraints.Constraints()
if constraints_file is not None:
self.input_constraints_uri = constraints_file
self.constraints = constraints.load(constraints_file)
self.constraints.load_constraints_file(constraints_file)
else:
self.input_constraints_uri = None
self.constraints = constraints.Constraints({})
self.sdists_repo = pathlib.Path(sdists_repo).absolute()
self.sdists_downloads = self.sdists_repo / "downloads"
self.sdists_builds = self.sdists_repo / "builds"
Expand Down
2 changes: 1 addition & 1 deletion src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __init__(
self.include_sdists = include_sdists
self.include_wheels = include_wheels
self.sdist_server_url = sdist_server_url
self.constraints = constraints or Constraints({})
self.constraints = constraints or Constraints()
self.req_type = req_type

def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:
Expand Down
53 changes: 39 additions & 14 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pathlib
from unittest.mock import Mock, patch

import pytest
from packaging.requirements import Requirement
Expand All @@ -9,40 +8,66 @@


def test_constraint_is_satisfied_by():
c = constraints.Constraints({"foo": Requirement("foo<=1.1")})
c = constraints.Constraints()
c.add_constraint("foo<=1.1")
assert c.is_satisfied_by("foo", "1.1")
assert c.is_satisfied_by("foo", Version("1.0"))
assert c.is_satisfied_by("bar", Version("2.0"))


def test_constraint_canonical_name():
c = constraints.Constraints({"flash_attn": Requirement("flash_attn<=1.1")})
c = constraints.Constraints()
c.add_constraint("flash_attn<=1.1")
assert c.is_satisfied_by("flash_attn", "1.1")
assert c.is_satisfied_by("flash-attn", "1.1")
assert c.is_satisfied_by("Flash-ATTN", "1.1")
assert list(c) == ["flash-attn"]


def test_constraint_not_is_satisfied_by():
c = constraints.Constraints({"foo": Requirement("foo<=1.1")})
c = constraints.Constraints()
c.add_constraint("foo<=1.1")
c.add_constraint("bar>=2.0")
assert not c.is_satisfied_by("foo", "1.2")
assert not c.is_satisfied_by("foo", Version("2.0"))
assert not c.is_satisfied_by("bar", Version("1.0"))


def test_load_empty_constraints_file():
assert constraints.load(None)._data == {}
def test_add_constraint_conflict():
c = constraints.Constraints()
c.add_constraint("foo<=1.1")
c.add_constraint("flit_core==2.0rc3")
with pytest.raises(KeyError):
c.add_constraint("foo<=1.1")
with pytest.raises(KeyError):
c.add_constraint("foo>1.1")
with pytest.raises(KeyError):
c.add_constraint("flit_core>2.0.0")
with pytest.raises(KeyError):
c.add_constraint("flit-core>2.0.0")


def test_allow_prerelease():
c = constraints.Constraints()
c.add_constraint("foo>=1.1")
assert not c.allow_prerelease("foo")
c.add_constraint("bar>=1.1a0")
assert c.allow_prerelease("bar")
c.add_constraint("flit_core==2.0rc3")
assert c.allow_prerelease("flit_core")


def test_load_non_existant_constraints_file(tmp_path: pathlib.Path):
non_existant_file = tmp_path / "non_existant.txt"
c = constraints.Constraints()
with pytest.raises(FileNotFoundError):
constraints.load(non_existant_file)
c.load_constraints_file(non_existant_file)


@patch("fromager.requirements_file.parse_requirements_file")
def test_load_constraints_file(parse_requirements_file: Mock, tmp_path: pathlib.Path):
def test_load_constraints_file(tmp_path: pathlib.Path):
constraint_file = tmp_path / "constraint.txt"
constraint_file.write_text("a\n")
parse_requirements_file.return_value = ["torch==3.1.0"]
assert constraints.load(constraint_file)._data == {
"torch": Requirement("torch==3.1.0")
}
constraint_file.write_text("egg\ntorch==3.1.0 # comment\n")
c = constraints.Constraints()
c.load_constraints_file(constraint_file)
assert list(c) == ["egg", "torch"] # type: ignore
assert c.get_constraint("torch") == Requirement("torch==3.1.0")
17 changes: 10 additions & 7 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,8 @@ def test_provider_choose_sdist():


def test_provider_choose_either_with_constraint():
constraint = constraints.Constraints(
{"hydra-core": Requirement("hydra-core==1.3.2")}
)
constraint = constraints.Constraints()
constraint.add_constraint("hydra-core==1.3.2")
with requests_mock.Mocker() as r:
r.get(
"https://pypi.org/simple/hydra-core/",
Expand All @@ -204,7 +203,8 @@ def test_provider_choose_either_with_constraint():


def test_provider_constraint_mismatch():
constraint = constraints.Constraints({"hydra-core": Requirement("hydra-core<=1.1")})
constraint = constraints.Constraints()
constraint.add_constraint("hydra-core<=1.1")
with requests_mock.Mocker() as r:
r.get(
"https://pypi.org/simple/hydra-core/",
Expand All @@ -220,7 +220,8 @@ def test_provider_constraint_mismatch():


def test_provider_constraint_match():
constraint = constraints.Constraints({"hydra-core": Requirement("hydra-core<=1.3")})
constraint = constraints.Constraints()
constraint.add_constraint("hydra-core<=1.3")
with requests_mock.Mocker() as r:
r.get(
"https://pypi.org/simple/hydra-core/",
Expand Down Expand Up @@ -525,7 +526,8 @@ def test_resolve_github():


def test_github_constraint_mismatch():
constraint = constraints.Constraints({"fromager": Requirement("fromager>=1.0")})
constraint = constraints.Constraints()
constraint.add_constraint("fromager>=1.0")
with requests_mock.Mocker() as r:
r.get(
"https://api.github.com:443/repos/python-wheel-build/fromager",
Expand All @@ -547,7 +549,8 @@ def test_github_constraint_mismatch():


def test_github_constraint_match():
constraint = constraints.Constraints({"fromager": Requirement("fromager<0.9")})
constraint = constraints.Constraints()
constraint.add_constraint("fromager<0.9")
with requests_mock.Mocker() as r:
r.get(
"https://api.github.com:443/repos/python-wheel-build/fromager",
Expand Down

0 comments on commit d7f5395

Please sign in to comment.