Skip to content

Commit

Permalink
Support variant specific patches
Browse files Browse the repository at this point in the history
When building wheels for different variants, the code base might be
different from one variant to the other, for example when the code comes
from a fork with variant specific features. And with a different code
base, the patches may fail to be applied.

This changes proposes to support patch files suffixed by the variant
name for variant specific patches. The global patches are suffixed by
 `.patch` and the variant specific patches are suffixed by
`.patch.{variant}`. The sorting is done on all patches, so that the
lexical order is maintained between global and specific patches.

Signed-off-by: Fabien Dupont <[email protected]>
  • Loading branch information
fabiendupont committed Oct 25, 2024
1 parent c7b3834 commit f0d43f9
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 8 deletions.
16 changes: 14 additions & 2 deletions src/fromager/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def patches_for_requirement(
patches_dir: pathlib.Path,
req: Requirement,
version: Version,
variant: str = "",
) -> typing.Iterable[pathlib.Path]:
"""Iterator producing patches to apply to the source for a given version of a requirement.
Expand All @@ -108,12 +109,23 @@ def patches_for_requirement(
override_name = pkgname_to_override_module(req.name)
unversioned_patch_dir = patches_dir / override_name
versioned_patch_dir = patches_dir / f"{override_name}-{version}"

unversioned_patch_files = list(unversioned_patch_dir.glob("*.patch"))
versioned_patch_files = list(versioned_patch_dir.glob("*.patch"))

# The list of files must exist to be joined to the global patch files
if variant:
unversioned_patch_files += list(
unversioned_patch_dir.glob(f"*.patch.{variant}")
)
versioned_patch_files += list(versioned_patch_dir.glob(f"*.patch.{variant}"))

return itertools.chain(
# Apply all of the unversioned patches first, in order based on
# filename.
sorted(unversioned_patch_dir.glob("*.patch")),
sorted(unversioned_patch_files),
# Then apply any for this specific version, in order based on filename.
sorted(versioned_patch_dir.glob("*.patch")),
sorted(versioned_patch_files),
)


Expand Down
1 change: 1 addition & 0 deletions src/fromager/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def patch_source(
patches_dir=ctx.settings.patches_dir,
req=req,
version=version,
variant=ctx.variant,
):
_apply_patch(p, source_root_dir)
patch_count += 1
Expand Down
27 changes: 21 additions & 6 deletions tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,39 @@ def test_patches_for_requirement(tmp_path: pathlib.Path):
project_patch_dir = patches_dir / "project-1.2.3"
project_patch_dir.mkdir()

p1 = project_patch_dir / "001.patch"
p2 = project_patch_dir / "002.patch"
gp1 = project_patch_dir / "001.patch"
gp2 = project_patch_dir / "002.patch"
sp1 = project_patch_dir / "001.patch.brie"
sp2 = project_patch_dir / "001.patch.feta"
np1 = project_patch_dir / "not-a-patch.txt"

# Create all of the test files
for p in [p1, p2]:
p.write_text("this is a patch file")
for gp in [gp1, gp2]:
gp.write_text("this is a global patch file")
for sp in [sp1, sp2]:
sp.write_text("this is a specific patch file")
for f in [np1]:
f.write_text("this is not a patch file")

results = list(
results_without_variant = list(
overrides.patches_for_requirement(
patches_dir=patches_dir,
req=Requirement("project"),
version=Version("1.2.3"),
)
)
assert results == [p1, p2]

results_with_variant = list(
overrides.patches_for_requirement(
patches_dir=patches_dir,
req=Requirement("project"),
version=Version("1.2.3"),
variant="brie",
)
)

assert results_without_variant == [gp1, gp2]
assert results_with_variant == [gp1, sp1, gp2]


def test_invoke_override_with_exact_args():
Expand Down
68 changes: 68 additions & 0 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,74 @@ def test_patch_sources_apply_only_unversioned(
)


@patch("fromager.sources._apply_patch")
def test_patch_sources_apply_global_and_variant_specific_unversioned_and_versioned(
apply_patch: Mock,
tmp_path: pathlib.Path,
tmp_context: context.WorkContext,
):
patches_dir = tmp_path / "patches_dir"
patches_dir.mkdir()
tmp_context.settings.patches_dir = patches_dir

tmp_context.variant = "brie"

deepspeed_versioned_patch_dir = patches_dir / "deepspeed-0.5.0"
deepspeed_versioned_patch_dir.mkdir()
global_versioned_patch_file_1 = deepspeed_versioned_patch_dir / "01-deepspeed.patch"
global_versioned_patch_file_1.write_text("This is a test patch")
global_versioned_patch_file_2 = deepspeed_versioned_patch_dir / "02-deepspeed.patch"
global_versioned_patch_file_2.write_text("This is a test patch")
specific_versioned_patch_file_1 = (
deepspeed_versioned_patch_dir / "01-deepspeed.patch.brie"
)
specific_versioned_patch_file_1.write_text("This is a test patch")
specific_versioned_patch_file_2 = (
deepspeed_versioned_patch_dir / "01-deepspeed.patch.feta"
)
specific_versioned_patch_file_2.write_text("This is a test patch")

deepspeed_unversioned_patch_dir = patches_dir / "deepspeed"
deepspeed_unversioned_patch_dir.mkdir()
global_unversioned_patch_file_1 = (
deepspeed_unversioned_patch_dir / "01-deepspeed.patch"
)
global_unversioned_patch_file_1.write_text("This is a test patch")
global_unversioned_patch_file_2 = (
deepspeed_unversioned_patch_dir / "02-deepspeed.patch"
)
global_unversioned_patch_file_2.write_text("This is a test patch")
specific_unversioned_patch_file_1 = (
deepspeed_unversioned_patch_dir / "02-deepspeed.patch.brie"
)
specific_unversioned_patch_file_1.write_text("This is a test patch")
specific_unversioned_patch_file_2 = (
deepspeed_unversioned_patch_dir / "01-deepspeed.patch.feta"
)
specific_unversioned_patch_file_2.write_text("This is a test patch")

source_root_dir = tmp_path / "deepspeed-0.5.0"
source_root_dir.mkdir()

sources.patch_source(
ctx=tmp_context,
source_root_dir=source_root_dir,
req=Requirement("deepspeed"),
version=Version("0.5.0"),
)
assert apply_patch.call_count == 6
apply_patch.asset_has_calls(
[
call(global_unversioned_patch_file_1, source_root_dir),
call(global_unversioned_patch_file_2, source_root_dir),
call(specific_unversioned_patch_file_1, source_root_dir),
call(global_versioned_patch_file_1, source_root_dir),
call(specific_versioned_patch_file_1, source_root_dir),
call(global_versioned_patch_file_2, source_root_dir),
]
)


@patch("logging.Logger.warning")
def test_warning_for_older_patch(mock, tmp_path: pathlib.Path):
# create patches dir
Expand Down

0 comments on commit f0d43f9

Please sign in to comment.