Skip to content

Commit

Permalink
metal lowbit kernels: pip install
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelcandales committed Feb 26, 2025
1 parent 8706d3f commit e626dd5
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 11 deletions.
44 changes: 43 additions & 1 deletion .github/workflows/torchao_experimental_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:
- 'gh/**'

jobs:
test:
test-cpu-ops:
strategy:
matrix:
runner: [macos-14]
Expand All @@ -36,6 +36,7 @@ jobs:
pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104"
pip install numpy
pip install pytest
pip install pyyaml
USE_CPP=1 pip install .
- name: Run python tests
run: |
Expand All @@ -56,3 +57,44 @@ jobs:
sh build_and_run_tests.sh
rm -rf /tmp/cmake-out
popd
test-mps-ops:
strategy:
matrix:
runner: [macos-m1-stable]
runs-on: ${{matrix.runner}}
steps:
- name: Print machine info
run: |
uname -a
if [ $(uname -s) == Darwin ]; then
sysctl machdep.cpu.brand_string
sysctl machdep.cpu.core_count
fi
- name: Checkout repo
uses: actions/checkout@v3
with:
submodules: true
- name: Setup environment
uses: conda-incubator/setup-miniconda@v3
with:
python-version: "3.11"
miniconda-version: "latest"
activate-environment: venv
- name: Install requirements
run: |
conda init
conda activate venv
pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.7.0.dev20250131"
pip install numpy
pip install parameterized
pip install pytest
pip install pyyaml
USE_CPP=1 pip install .
- name: Run mps tests
run: |
conda activate venv
pushd torchao/experimental/ops/mps/test
python test_lowbit.py
python test_quantizer.py
popd
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,4 @@ checkpoints/

# Experimental
torchao/experimental/cmake-out
torchao/experimental/deps
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,19 @@ def build_cmake(self, ext):
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

build_mps_ops = "ON" if torch.mps.is_available() else "OFF"

subprocess.check_call(
[
"cmake",
ext.sourcedir,
"-DCMAKE_BUILD_TYPE=" + build_type,
# Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16
"-DTORCHAO_BUILD_KLEIDIAI=OFF",
"-DTORCHAO_BUILD_MPS_OPS=" + build_mps_ops,
"-DTorch_DIR=" + torch_dir,
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DCMAKE_INSTALL_PREFIX=cmake-out",
],
cwd=self.build_temp,
)
Expand Down
4 changes: 4 additions & 0 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
if len(experimental_lib) == 0 and Path.cwd().name == "ao":
experimental_lib = list(
(Path(__file__).parent.parent / "build").rglob("libtorchao_ops_aten.*")
)
if len(experimental_lib) > 0:
assert (
len(experimental_lib) == 1
Expand Down
7 changes: 7 additions & 0 deletions torchao/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ if (NOT CMAKE_BUILD_TYPE)
endif()

option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)


if(NOT TORCHAO_INCLUDE_DIRS)
Expand Down Expand Up @@ -51,6 +52,12 @@ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
torchao_ops_linear_8bit_act_xbit_weight_aten
torchao_ops_embedding_xbit_aten
)
if (TORCHAO_BUILD_MPS_OPS)
message(STATUS "Building with MPS support")
add_subdirectory(ops/mps)
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
endif()

install(
TARGETS torchao_ops_aten
EXPORT _targets
Expand Down
3 changes: 2 additions & 1 deletion torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ find_package(Torch REQUIRED)
# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal)
set(METAL_SHADERS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal.yaml)
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
OUTPUT ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
DEPENDS ${METAL_FILES} ${GEN_SCRIPT}
DEPENDS ${METAL_FILES} ${METAL_SHADERS_YAML} ${GEN_SCRIPT}
COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
)
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
Expand Down
9 changes: 5 additions & 4 deletions torchao/experimental/ops/mps/test/test_lowbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
import torch
from parameterized import parameterized

libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
import torchao # noqa: F401

try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
Expand Down
10 changes: 5 additions & 5 deletions torchao/experimental/ops/mps/test/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
import torch
from parameterized import parameterized

import torchao # noqa: F401
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize

libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)

try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
try:
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
Expand Down

0 comments on commit e626dd5

Please sign in to comment.