From f0d43f9ffe2064211683f5979f34fdd02befd60f Mon Sep 17 00:00:00 2001 From: Fabien Dupont Date: Fri, 25 Oct 2024 07:46:24 -0400 Subject: [PATCH] Support variant specific patches When building wheels for different variants, the code base might be different from one variant to the other, for example when the code comes from a fork with variant specific features. And with a different code base, the patches may fail to be applied. This changes proposes to support patch files suffixed by the variant name for variant specific patches. The global patches are suffixed by `.patch` and the variant specific patches are suffixed by `.patch.{variant}`. The sorting is done on all patches, so that the lexical order is maintained between global and specific patches. Signed-off-by: Fabien Dupont --- src/fromager/overrides.py | 16 +++++++-- src/fromager/sources.py | 1 + tests/test_overrides.py | 27 ++++++++++++---- tests/test_sources.py | 68 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 8 deletions(-) diff --git a/src/fromager/overrides.py b/src/fromager/overrides.py index c7c649c5..e97db83d 100644 --- a/src/fromager/overrides.py +++ b/src/fromager/overrides.py @@ -98,6 +98,7 @@ def patches_for_requirement( patches_dir: pathlib.Path, req: Requirement, version: Version, + variant: str = "", ) -> typing.Iterable[pathlib.Path]: """Iterator producing patches to apply to the source for a given version of a requirement. @@ -108,12 +109,23 @@ def patches_for_requirement( override_name = pkgname_to_override_module(req.name) unversioned_patch_dir = patches_dir / override_name versioned_patch_dir = patches_dir / f"{override_name}-{version}" + + unversioned_patch_files = list(unversioned_patch_dir.glob("*.patch")) + versioned_patch_files = list(versioned_patch_dir.glob("*.patch")) + + # The list of files must exist to be joined to the global patch files + if variant: + unversioned_patch_files += list( + unversioned_patch_dir.glob(f"*.patch.{variant}") + ) + versioned_patch_files += list(versioned_patch_dir.glob(f"*.patch.{variant}")) + return itertools.chain( # Apply all of the unversioned patches first, in order based on # filename. - sorted(unversioned_patch_dir.glob("*.patch")), + sorted(unversioned_patch_files), # Then apply any for this specific version, in order based on filename. - sorted(versioned_patch_dir.glob("*.patch")), + sorted(versioned_patch_files), ) diff --git a/src/fromager/sources.py b/src/fromager/sources.py index ca62b988..7c2f4534 100644 --- a/src/fromager/sources.py +++ b/src/fromager/sources.py @@ -299,6 +299,7 @@ def patch_source( patches_dir=ctx.settings.patches_dir, req=req, version=version, + variant=ctx.variant, ): _apply_patch(p, source_root_dir) patch_count += 1 diff --git a/tests/test_overrides.py b/tests/test_overrides.py index 4f4bacf8..11a0f9c1 100644 --- a/tests/test_overrides.py +++ b/tests/test_overrides.py @@ -16,24 +16,39 @@ def test_patches_for_requirement(tmp_path: pathlib.Path): project_patch_dir = patches_dir / "project-1.2.3" project_patch_dir.mkdir() - p1 = project_patch_dir / "001.patch" - p2 = project_patch_dir / "002.patch" + gp1 = project_patch_dir / "001.patch" + gp2 = project_patch_dir / "002.patch" + sp1 = project_patch_dir / "001.patch.brie" + sp2 = project_patch_dir / "001.patch.feta" np1 = project_patch_dir / "not-a-patch.txt" # Create all of the test files - for p in [p1, p2]: - p.write_text("this is a patch file") + for gp in [gp1, gp2]: + gp.write_text("this is a global patch file") + for sp in [sp1, sp2]: + sp.write_text("this is a specific patch file") for f in [np1]: f.write_text("this is not a patch file") - results = list( + results_without_variant = list( overrides.patches_for_requirement( patches_dir=patches_dir, req=Requirement("project"), version=Version("1.2.3"), ) ) - assert results == [p1, p2] + + results_with_variant = list( + overrides.patches_for_requirement( + patches_dir=patches_dir, + req=Requirement("project"), + version=Version("1.2.3"), + variant="brie", + ) + ) + + assert results_without_variant == [gp1, gp2] + assert results_with_variant == [gp1, sp1, gp2] def test_invoke_override_with_exact_args(): diff --git a/tests/test_sources.py b/tests/test_sources.py index 9d0aa9b7..a2bffda7 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -179,6 +179,74 @@ def test_patch_sources_apply_only_unversioned( ) +@patch("fromager.sources._apply_patch") +def test_patch_sources_apply_global_and_variant_specific_unversioned_and_versioned( + apply_patch: Mock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +): + patches_dir = tmp_path / "patches_dir" + patches_dir.mkdir() + tmp_context.settings.patches_dir = patches_dir + + tmp_context.variant = "brie" + + deepspeed_versioned_patch_dir = patches_dir / "deepspeed-0.5.0" + deepspeed_versioned_patch_dir.mkdir() + global_versioned_patch_file_1 = deepspeed_versioned_patch_dir / "01-deepspeed.patch" + global_versioned_patch_file_1.write_text("This is a test patch") + global_versioned_patch_file_2 = deepspeed_versioned_patch_dir / "02-deepspeed.patch" + global_versioned_patch_file_2.write_text("This is a test patch") + specific_versioned_patch_file_1 = ( + deepspeed_versioned_patch_dir / "01-deepspeed.patch.brie" + ) + specific_versioned_patch_file_1.write_text("This is a test patch") + specific_versioned_patch_file_2 = ( + deepspeed_versioned_patch_dir / "01-deepspeed.patch.feta" + ) + specific_versioned_patch_file_2.write_text("This is a test patch") + + deepspeed_unversioned_patch_dir = patches_dir / "deepspeed" + deepspeed_unversioned_patch_dir.mkdir() + global_unversioned_patch_file_1 = ( + deepspeed_unversioned_patch_dir / "01-deepspeed.patch" + ) + global_unversioned_patch_file_1.write_text("This is a test patch") + global_unversioned_patch_file_2 = ( + deepspeed_unversioned_patch_dir / "02-deepspeed.patch" + ) + global_unversioned_patch_file_2.write_text("This is a test patch") + specific_unversioned_patch_file_1 = ( + deepspeed_unversioned_patch_dir / "02-deepspeed.patch.brie" + ) + specific_unversioned_patch_file_1.write_text("This is a test patch") + specific_unversioned_patch_file_2 = ( + deepspeed_unversioned_patch_dir / "01-deepspeed.patch.feta" + ) + specific_unversioned_patch_file_2.write_text("This is a test patch") + + source_root_dir = tmp_path / "deepspeed-0.5.0" + source_root_dir.mkdir() + + sources.patch_source( + ctx=tmp_context, + source_root_dir=source_root_dir, + req=Requirement("deepspeed"), + version=Version("0.5.0"), + ) + assert apply_patch.call_count == 6 + apply_patch.asset_has_calls( + [ + call(global_unversioned_patch_file_1, source_root_dir), + call(global_unversioned_patch_file_2, source_root_dir), + call(specific_unversioned_patch_file_1, source_root_dir), + call(global_versioned_patch_file_1, source_root_dir), + call(specific_versioned_patch_file_1, source_root_dir), + call(global_versioned_patch_file_2, source_root_dir), + ] + ) + + @patch("logging.Logger.warning") def test_warning_for_older_patch(mock, tmp_path: pathlib.Path): # create patches dir