Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variant specific patches #485

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ name and use the suffix `.patch`. The filenames are sorted lexicographically, so
any text between the prefix and suffix can be used to ensure the patches are
applied in a specific order.

Patch files can also be placed in a variant specific subdirectory, in order
to allow variant specific patches, e.g. when the code base is a variant specific
fork of the package and the global patches don't apply.

Patches are applied by running `patch -p1 filename` while inside the root of the
source tree.

Expand All @@ -189,6 +193,7 @@ pytorch-v2.2.1/003-fbgemm-no-maybe-uninitialized.patch
pytorch-v2.2.1/004-fix-release-version.patch
pytorch-v2.2.2/001-remove-cmake-build-requirement.patch
pytorch-v2.2.2/002-dist-info-no-run-build-deps.patch
pytorch-v2.2.2/cuda/002-enforce-cudnn.patch
pytorch-v2.2.2/003-fbgemm-no-maybe-uninitialized.patch
pytorch-v2.2.2/004-fix-release-version.patch
xformers-0.0.26.post1/pyproject.toml.patch
Expand Down
22 changes: 20 additions & 2 deletions src/fromager/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def patches_for_requirement(
patches_dir: pathlib.Path,
req: Requirement,
version: Version,
variant: str = "",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variant name should be passed by passing a Context object as first argument.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be the only method with a Context object in this package. This is mostly helpers. And an extra variable with a default value reduces the size of the change.

Not saying I'm against the idea. Just giving more context about why I did it this way.

) -> typing.Iterable[pathlib.Path]:
"""Iterator producing patches to apply to the source for a given version of a requirement.
Expand All @@ -108,12 +109,29 @@ def patches_for_requirement(
override_name = pkgname_to_override_module(req.name)
unversioned_patch_dir = patches_dir / override_name
versioned_patch_dir = patches_dir / f"{override_name}-{version}"

unversioned_patch_files = list(unversioned_patch_dir.glob("*.patch"))
versioned_patch_files = list(versioned_patch_dir.glob("*.patch"))

# The list of files must exist to be joined to the global patch files
fabiendupont marked this conversation as resolved.
Show resolved Hide resolved
if variant:
unversioned_variant_patch_dir = unversioned_patch_dir / variant
if unversioned_variant_patch_dir.exists():
unversioned_patch_files.extend(
list(unversioned_variant_patch_dir.glob("*.patch"))
)
versioned_variant_patch_dir = versioned_patch_dir / variant
if versioned_variant_patch_dir.exists():
versioned_patch_files.extend(
list(versioned_variant_patch_dir.glob("*.patch"))
)

return itertools.chain(
# Apply all of the unversioned patches first, in order based on
# filename.
sorted(unversioned_patch_dir.glob("*.patch")),
sorted(unversioned_patch_files, key=lambda f: f.name),
# Then apply any for this specific version, in order based on filename.
sorted(versioned_patch_dir.glob("*.patch")),
sorted(versioned_patch_files, key=lambda f: f.name),
)


Expand Down
6 changes: 5 additions & 1 deletion src/fromager/packagesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,11 @@ def get_patches(self) -> PatchMap:
prefix_len = len(pattern) - 1
for patchdir in self._patches_dir.glob(pattern):
version = Version(patchdir.name[prefix_len:])
patches[version] = sorted(patchdir.glob("*.patch"))
versioned_patches = list(patchdir.glob("*.patch"))
variant_patchdir = patchdir / self._variant
if variant_patchdir.exists():
versioned_patches.extend(list(variant_patchdir.glob("*.patch")))
patches[version] = sorted(versioned_patches, key=lambda f: f.name)
self._patches = patches
return self._patches

Expand Down
1 change: 1 addition & 0 deletions src/fromager/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def patch_source(
patches_dir=ctx.settings.patches_dir,
req=req,
version=version,
variant=ctx.variant,
):
_apply_patch(p, source_root_dir)
patch_count += 1
Expand Down
33 changes: 27 additions & 6 deletions tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,45 @@ def test_patches_for_requirement(tmp_path: pathlib.Path):
project_patch_dir = patches_dir / "project-1.2.3"
project_patch_dir.mkdir()

p1 = project_patch_dir / "001.patch"
p2 = project_patch_dir / "002.patch"
variant_1_patch_dir = project_patch_dir / "brie"
variant_1_patch_dir.mkdir()

variant_2_patch_dir = project_patch_dir / "feta"
variant_2_patch_dir.mkdir()

gp1 = project_patch_dir / "001.patch"
gp2 = project_patch_dir / "002.patch"
sp1 = variant_1_patch_dir / "001.patch"
sp2 = variant_2_patch_dir / "001.patch"
np1 = project_patch_dir / "not-a-patch.txt"

# Create all of the test files
for p in [p1, p2]:
p.write_text("this is a patch file")
for gp in [gp1, gp2]:
gp.write_text("this is a global patch file")
for sp in [sp1, sp2]:
sp.write_text("this is a specific patch file")
for f in [np1]:
f.write_text("this is not a patch file")

results = list(
results_without_variant = list(
overrides.patches_for_requirement(
patches_dir=patches_dir,
req=Requirement("project"),
version=Version("1.2.3"),
)
)

results_with_variant = list(
overrides.patches_for_requirement(
patches_dir=patches_dir,
req=Requirement("project"),
version=Version("1.2.3"),
variant="brie",
)
)
assert results == [p1, p2]

assert results_without_variant == [gp1, gp2]
assert results_with_variant == [gp1, sp1, gp2]


def test_invoke_override_with_exact_args():
Expand Down
3 changes: 3 additions & 0 deletions tests/test_packagesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def test_pbi_test_pkg(testdata_context: context.WorkContext) -> None:
assert pbi.get_patches() == {
Version("1.0.2"): [
patchdir / "001-somepatch.patch",
patchdir / pbi.variant / "002-myvariantpatch.patch",
patchdir / "002-otherpatch.patch",
],
}
Expand Down Expand Up @@ -292,6 +293,8 @@ def test_pbi_other(testdata_context: context.WorkContext) -> None:
assert pbi.get_patches() == {
Version("1.0.0"): [
patchdir / "001-mypatch.patch",
patchdir / pbi.variant / "001-myvariantpatch.patch",
patchdir / "002-myotherpatch.patch",
],
}
assert pbi.get_patches() is pbi.get_patches()
Expand Down
87 changes: 87 additions & 0 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,93 @@ def test_patch_sources_apply_only_unversioned(
)


@patch("fromager.sources._apply_patch")
def test_patch_sources_apply_global_and_variant_specific_unversioned_and_versioned(
apply_patch: Mock,
tmp_path: pathlib.Path,
tmp_context: context.WorkContext,
):
patches_dir = tmp_path / "patches_dir"
patches_dir.mkdir()
tmp_context.settings.patches_dir = patches_dir

tmp_context.variant = "brie"

deepspeed_versioned_patch_dir = patches_dir / "deepspeed-0.5.0"
deepspeed_versioned_patch_dir.mkdir()

deepspeed_versioned_brie_patch_dir = deepspeed_versioned_patch_dir / "brie"
deepspeed_versioned_brie_patch_dir.mkdir()

deepspeed_versioned_feta_patch_dir = deepspeed_versioned_patch_dir / "feta"
deepspeed_versioned_feta_patch_dir.mkdir()

global_versioned_patch_file_1 = deepspeed_versioned_patch_dir / "01-deepspeed.patch"
global_versioned_patch_file_1.write_text("This is a test patch")
global_versioned_patch_file_2 = deepspeed_versioned_patch_dir / "02-deepspeed.patch"
global_versioned_patch_file_2.write_text("This is a test patch")

specific_versioned_brie_patch_file_1 = (
deepspeed_versioned_brie_patch_dir / "01-deepspeed.patch"
)
specific_versioned_brie_patch_file_1.write_text("This is a test patch for brie")

specific_versioned_feta_patch_file_1 = (
deepspeed_versioned_feta_patch_dir / "01-deepspeed.patch"
)
specific_versioned_feta_patch_file_1.write_text("This is a test patch for feta")

deepspeed_unversioned_patch_dir = patches_dir / "deepspeed"
deepspeed_unversioned_patch_dir.mkdir()

deepspeed_unversioned_brie_patch_dir = deepspeed_unversioned_patch_dir / "brie"
deepspeed_unversioned_brie_patch_dir.mkdir()

deepspeed_unversioned_feta_patch_dir = deepspeed_unversioned_patch_dir / "feta"
deepspeed_unversioned_feta_patch_dir.mkdir()

global_unversioned_patch_file_1 = (
deepspeed_unversioned_patch_dir / "01-deepspeed.patch"
)
global_unversioned_patch_file_1.write_text("This is a test patch")

global_unversioned_patch_file_2 = (
deepspeed_unversioned_patch_dir / "02-deepspeed.patch"
)
global_unversioned_patch_file_2.write_text("This is a test patch")

specific_unversioned_brie_patch_file_1 = (
deepspeed_unversioned_brie_patch_dir / "02-deepspeed.patch"
)
specific_unversioned_brie_patch_file_1.write_text("This is a test patch for brie")

specific_unversioned_feta_patch_file_1 = (
deepspeed_unversioned_feta_patch_dir / "01-deepspeed.patch"
)
specific_unversioned_feta_patch_file_1.write_text("This is a test patch for feta")

source_root_dir = tmp_path / "deepspeed-0.5.0"
source_root_dir.mkdir()

sources.patch_source(
ctx=tmp_context,
source_root_dir=source_root_dir,
req=Requirement("deepspeed"),
version=Version("0.5.0"),
)
assert apply_patch.call_count == 6
apply_patch.asset_has_calls(
[
call(global_unversioned_patch_file_1, source_root_dir),
call(global_unversioned_patch_file_2, source_root_dir),
call(specific_unversioned_brie_patch_file_1, source_root_dir),
call(global_versioned_patch_file_1, source_root_dir),
call(specific_versioned_brie_patch_file_1, source_root_dir),
call(global_versioned_patch_file_2, source_root_dir),
]
)


@patch("logging.Logger.warning")
def test_warning_for_older_patch(mock, tmp_path: pathlib.Path):
# create patches dir
Expand Down
Loading