From 2e5579bd1270a9189e8aa5ff4aee798f49b2912d Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 26 Feb 2025 13:56:18 -0500 Subject: [PATCH] metal lowbit kernels: pip install --- .github/workflows/torchao_experimental_test.yml | 9 +++++++++ .gitignore | 1 + setup.py | 4 ++++ torchao/__init__.py | 4 ++++ torchao/experimental/CMakeLists.txt | 7 +++++++ torchao/experimental/ops/mps/CMakeLists.txt | 3 ++- torchao/experimental/ops/mps/test/test_lowbit.py | 9 +++++---- torchao/experimental/ops/mps/test/test_quantizer.py | 10 +++++----- 8 files changed, 37 insertions(+), 10 deletions(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index e1511ffe9a..a9c96afd95 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -35,7 +35,9 @@ jobs: conda activate venv pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104" pip install numpy + pip install parameterized pip install pytest + pip install pyyaml USE_CPP=1 pip install . - name: Run python tests run: | @@ -56,3 +58,10 @@ jobs: sh build_and_run_tests.sh rm -rf /tmp/cmake-out popd + - name: Run mps tests + run: | + conda activate venv + pushd torchao/experimental/ops/mps/test + python test_lowbit.py + python test_quantizer.py + popd diff --git a/.gitignore b/.gitignore index 726d2976f6..d8c3199a1e 100644 --- a/.gitignore +++ b/.gitignore @@ -375,3 +375,4 @@ checkpoints/ # Experimental torchao/experimental/cmake-out +torchao/experimental/deps diff --git a/setup.py b/setup.py index ee3ebbf453..fa0bcc5b0b 100644 --- a/setup.py +++ b/setup.py @@ -174,6 +174,8 @@ def build_cmake(self, ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) + build_mps_ops = "ON" if torch.mps.is_available() else "OFF" + subprocess.check_call( [ "cmake", @@ -181,8 +183,10 @@ def build_cmake(self, ext): "-DCMAKE_BUILD_TYPE=" + build_type, # Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16 "-DTORCHAO_BUILD_KLEIDIAI=OFF", + "-DTORCHAO_BUILD_MPS_OPS=" + build_mps_ops, "-DTorch_DIR=" + torch_dir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, + "-DCMAKE_INSTALL_PREFIX=cmake-out", ], cwd=self.build_temp, ) diff --git a/torchao/__init__.py b/torchao/__init__.py index cc453e2d14..e3e03df5e1 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -35,6 +35,10 @@ # running the script `torchao/experimental/build_torchao_ops.sh ` # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*")) + if len(experimental_lib) == 0 and Path.cwd().name == "ao": + experimental_lib = list( + (Path(__file__).parent.parent / "build").rglob("libtorchao_ops_aten.*") + ) if len(experimental_lib) > 0: assert ( len(experimental_lib) == 1 diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index a90cc5884a..67dfc7b779 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -16,6 +16,7 @@ if (NOT CMAKE_BUILD_TYPE) endif() option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF) +option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF) if(NOT TORCHAO_INCLUDE_DIRS) @@ -51,6 +52,12 @@ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") torchao_ops_linear_8bit_act_xbit_weight_aten torchao_ops_embedding_xbit_aten ) + if (TORCHAO_BUILD_MPS_OPS) + message(STATUS "Building with MPS support") + add_subdirectory(ops/mps) + target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten) + endif() + install( TARGETS torchao_ops_aten EXPORT _targets diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt index 820205fa27..8dcdec523e 100644 --- a/torchao/experimental/ops/mps/CMakeLists.txt +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -28,12 +28,13 @@ find_package(Torch REQUIRED) # Generate metal_shader_lib.h by running gen_metal_shader_lib.py set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal) file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal) +set(METAL_SHADERS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal.yaml) set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py) set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h) add_custom_command( OUTPUT ${GENERATED_METAL_SHADER_LIB} COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB} - DEPENDS ${METAL_FILES} ${GEN_SCRIPT} + DEPENDS ${METAL_FILES} ${METAL_SHADERS_YAML} ${GEN_SCRIPT} COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py" ) add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB}) diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index 437fb7578f..d5ffad53e4 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -10,10 +10,7 @@ import torch from parameterized import parameterized -libname = "libtorchao_ops_mps_aten.dylib" -libpath = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) -) +import torchao # noqa: F401 try: for nbit in range(1, 8): @@ -21,6 +18,10 @@ getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") except AttributeError: try: + libname = "libtorchao_ops_mps_aten.dylib" + libpath = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) + ) torch.ops.load_library(libpath) except: raise RuntimeError(f"Failed to load library {libpath}") diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index b530c6ea83..7afa91183e 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -12,19 +12,19 @@ import torch from parameterized import parameterized +import torchao # noqa: F401 from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize -libname = "libtorchao_ops_mps_aten.dylib" -libpath = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) -) - try: for nbit in range(1, 8): getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") except AttributeError: try: + libname = "libtorchao_ops_mps_aten.dylib" + libpath = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) + ) torch.ops.load_library(libpath) except: raise RuntimeError(f"Failed to load library {libpath}")