diff --git a/install/install_requirements.sh b/install/install_requirements.sh
index e9a858ab8..b8baa72ce 100755
--- a/install/install_requirements.sh
+++ b/install/install_requirements.sh
@@ -81,7 +81,7 @@ then
   REQUIREMENTS_TO_INSTALL=(
     torch=="2.7.0.${PYTORCH_NIGHTLY_VERSION}"
     torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-    torchtune=="0.6.0"
+    # torchtune=="0.6.0"  # torchtune 0.6.0 is not available on the XPU nightly index
   )
 else
   REQUIREMENTS_TO_INSTALL=(
@@ -115,6 +115,17 @@ fi
     "${REQUIREMENTS_TO_INSTALL[@]}"
 )
 
+# Temporarily install the torchtune nightly from the CPU nightly index, since there is no torchtune nightly for XPU yet
+# TODO: Install torchtune from the XPU nightly index once a torchtune XPU nightly is available
+if [[ -x "$(command -v xpu-smi)" ]];
+then
+(
+  set -x
+  $PIP_EXECUTABLE install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" \
+    torchtune=="0.6.0.${TUNE_NIGHTLY_VERSION}"
+)
+fi
+
 # For torchao need to install from github since nightly build doesn't have macos build.
 # TODO: Remove this and install nightly build, once it supports macos
 (
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 1e04800ab..f40936b81 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -29,7 +29,7 @@
 from torchchat.utils.build_utils import (
     device_sync,
     is_cpu_device,
-    is_cuda_or_cpu_device,
+    is_cuda_or_cpu_or_xpu_device,
     name_to_dtype,
 )
 from torchchat.utils.measure_time import measure_time
@@ -539,7 +539,7 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
 
     if builder_args.dso_path:
-        if not is_cuda_or_cpu_device(builder_args.device):
+        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
             print(
                 f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
             )
@@ -573,7 +573,7 @@ def do_nothing(max_batch_size, max_seq_length):
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
 
     elif builder_args.aoti_package_path:
-        if not is_cuda_or_cpu_device(builder_args.device):
+        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
             print(
                 f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
             )
diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py
index a0862ff94..b9c32a7fe 100644
--- a/torchchat/utils/build_utils.py
+++ b/torchchat/utils/build_utils.py
@@ -303,6 +303,5 @@ def get_device(device) -> str:
 def is_cpu_device(device) -> bool:
     return device == "" or str(device) == "cpu"
 
-
-def is_cuda_or_cpu_device(device) -> bool:
-    return is_cpu_device(device) or ("cuda" in str(device))
+def is_cuda_or_cpu_or_xpu_device(device) -> bool:
+    return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))
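For illustration, a minimal sketch (not part of the patch) of how the updated device check behaves, using the helpers exactly as patched in torchchat/utils/build_utils.py; the example device strings are assumptions for demonstration only:

    # Mirrors the patched helpers: "xpu" device strings now pass the AOTI device check.
    def is_cpu_device(device) -> bool:
        return device == "" or str(device) == "cpu"

    def is_cuda_or_cpu_or_xpu_device(device) -> bool:
        return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))

    assert is_cuda_or_cpu_or_xpu_device("xpu:0")    # newly accepted, no CPU fallback in _initialize_model
    assert is_cuda_or_cpu_or_xpu_device("cuda")     # unchanged behavior
    assert not is_cuda_or_cpu_or_xpu_device("mps")  # still falls back to CPU for DSO/PT2 loading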