Add XPU support for AOT inductor #1503

Merged Feb 27, 2025 · 25 commits
13 changes: 12 additions & 1 deletion install/install_requirements.sh
@@ -81,7 +81,7 @@ then
REQUIREMENTS_TO_INSTALL=(
torch=="2.7.0.${PYTORCH_NIGHTLY_VERSION}"
torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-torchtune=="0.6.0"
+#torchtune=="0.6.0" # no 0.6.0 on xpu nightly
)
else
REQUIREMENTS_TO_INSTALL=(
@@ -115,6 +115,17 @@ fi
"${REQUIREMENTS_TO_INSTALL[@]}"
)

+# Temporarily install torchtune nightly from the cpu nightly link, since there is no torchtune nightly for xpu yet
+# TODO: Change to install torchtune from the xpu nightly link, once torchtune xpu nightly is ready
+if [[ -x "$(command -v xpu-smi)" ]];
+then
+(
+set -x
Contributor: Can we reuse the conditional on line 84?

Contributor Author: No, because line 84 installs using the PyTorch xpu link: https://download.pytorch.org/whl/nightly/xpu. Since we don't have a torchtune nightly for xpu yet, we temporarily install the torchtune cpu nightly from the cpu link: https://download.pytorch.org/whl/nightly/cpu. Once the torchtune xpu nightly is ready, we'll update line 84 to install from the PyTorch xpu link.

Contributor (@Jack-Khuu, Feb 26, 2025): Ah, I meant reuse the if [[ -x "$(command -v xpu-smi)" ]]; check.

This PR is fine to merge as-is though, since after we get the torchtune nightly this is going to be removed anyway 😃

+$PIP_EXECUTABLE install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" \
+torchtune=="0.6.0.${TUNE_NIGHTLY_VERSION}"
+)
+fi
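The reviewer's suggestion above (reusing the xpu-smi conditional) could look roughly like the following. This is a hedged sketch, not the PR's actual code: PIP_EXECUTABLE and TUNE_NIGHTLY_VERSION are assumed to be set earlier in install_requirements.sh (placeholder defaults are used here so the sketch runs standalone), and the pip command is only assembled into a variable, never executed.

```shell
# Placeholder defaults; the real script sets these earlier.
PIP_EXECUTABLE="${PIP_EXECUTABLE:-pip3}"
TUNE_NIGHTLY_VERSION="${TUNE_NIGHTLY_VERSION:-dev20250226}"

if [[ -x "$(command -v xpu-smi)" ]]; then
  # XPU machine: torch/torchvision come from the xpu nightly index, but
  # torchtune has no xpu nightly yet, so pull it from the cpu index.
  TORCHTUNE_INSTALL_CMD="$PIP_EXECUTABLE install --extra-index-url https://download.pytorch.org/whl/nightly/cpu torchtune==0.6.0.${TUNE_NIGHTLY_VERSION}"
else
  # Non-XPU machine: torchtune is already covered by REQUIREMENTS_TO_INSTALL.
  TORCHTUNE_INSTALL_CMD=""
fi
echo "torchtune install command: ${TORCHTUNE_INSTALL_CMD:-<handled by default requirements>}"
```

Once a torchtune xpu nightly exists, both the conditional and this extra install step go away, which is why the PR keeps the duplicate check for now.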

# For torchao need to install from github since nightly build doesn't have macos build.
# TODO: Remove this and install nightly build, once it supports macos
(
6 changes: 3 additions & 3 deletions torchchat/cli/builder.py
@@ -29,7 +29,7 @@
from torchchat.utils.build_utils import (
device_sync,
is_cpu_device,
-is_cuda_or_cpu_device,
+is_cuda_or_cpu_or_xpu_device,
name_to_dtype,
)
from torchchat.utils.measure_time import measure_time
@@ -539,7 +539,7 @@ def _initialize_model(
_set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

if builder_args.dso_path:
-if not is_cuda_or_cpu_device(builder_args.device):
+if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
print(
f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
)
@@ -573,7 +573,7 @@ def do_nothing(max_batch_size, max_seq_length):
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")

elif builder_args.aoti_package_path:
-if not is_cuda_or_cpu_device(builder_args.device):
+if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
print(
f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
)
5 changes: 2 additions & 3 deletions torchchat/utils/build_utils.py
@@ -303,6 +303,5 @@ def get_device(device) -> str:
def is_cpu_device(device) -> bool:
return device == "" or str(device) == "cpu"


-def is_cuda_or_cpu_device(device) -> bool:
-    return is_cpu_device(device) or ("cuda" in str(device))
+def is_cuda_or_cpu_or_xpu_device(device) -> bool:
+    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
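The rename above is the core of the device gating change. A minimal standalone sketch of the new helper (functions copied from the diff so their behavior on device strings is easy to verify):

```python
# Sketch of the helpers in torchchat/utils/build_utils.py after this PR.
def is_cpu_device(device) -> bool:
    return device == "" or str(device) == "cpu"

def is_cuda_or_cpu_or_xpu_device(device) -> bool:
    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))

# With this change, "xpu" devices (including indexed forms like "xpu:0") pass
# the AOTI load check in builder.py, while e.g. "mps" still falls back to CPU.
print(is_cuda_or_cpu_or_xpu_device("xpu:0"))   # True
print(is_cuda_or_cpu_or_xpu_device("cuda:1"))  # True
print(is_cuda_or_cpu_or_xpu_device("mps"))     # False
```

Because the check is a substring match on the device string, it accepts any device spelling containing "cuda" or "xpu", which is the same permissive convention the old helper used for "cuda".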