diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index db79fa4..8e17cf4 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -40,27 +40,17 @@ jobs: strategy: fail-fast: false matrix: - # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the - # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] - python-version: ['3.9', '3.10', '3.11', '3.12'] - jax-version: ['0.4.24'] - cuda-version: ['11.8.0', '12.3.1'] + python-version: ['cp39', 'cp310', 'cp311', 'cp312'] + cuda-version: ['11.8', '12.3'] steps: - name: Checkout uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Set CUDA and PyTorch versions run: | echo "MATRIX_CUDA_MAJOR=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV - echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - echo "MATRIX_JAX_VERSION=$(echo ${{ matrix.jax-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV - name: Free up disk space if: ${{ runner.os == 'Linux' }} @@ -77,53 +67,17 @@ jobs: with: swap-size-gb: 10 - - name: Install CUDA ${{ matrix.cuda-version }} - if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.14 - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda-version }} - linux-local-args: '["--toolkit"]' - # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 - # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} - method: 'network' - # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions, - # not just nvcc - # sub-packages: '["nvcc"]' - - - name: Install Jax ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} - run: | - pip install --upgrade pip - pip install --upgrade "jax[cuda${MATRIX_CUDA_MAJOR}_local] == ${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - shell: - bash - - - name: Build wheel - run: | - # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 - # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 - # However this still fails so I'm using a newer version of setuptools - pip install setuptools==68.0.0 - # setuptools-cuda-cpp on pypi has a bug that breaks ninja - pip install git+https://github.com/nshepperd/setuptools-cuda-cpp - pip install ninja packaging wheel pybind11 - export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH - export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - # Set MAX_JOBS to allocate 8GB per job, which should be enough to build comfortably - free -h - export MAX_JOBS=3 - echo "Building with ${MAX_JOBS} jobs" - python setup.py bdist_wheel --dist-dir=dist - tmpname=cu${MATRIX_CUDA_VERSION}jax${{ matrix.jax-version }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") - ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - shell: - bash + - name: Build wheels + uses: pypa/cibuildwheel@v2.17.0 + env: + CIBW_BUILD: ${{ matrix.python-version }}-manylinux_x86_64 + CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014_x86_64_cuda_${{ matrix.cuda-version }} - name: Log Built Wheels run: | - ls dist + ls wheelhouse + wheel_name=$(basename wheelhouse/*.whl) + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - name: Get the tag version id: extract_branch @@ -144,7 +98,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ steps.get_current_release.outputs.upload_url }} - asset_path: ./dist/${{env.wheel_name}} + asset_path: ./wheelhouse/${{env.wheel_name}} asset_name: ${{env.wheel_name}} asset_content_type: application/* @@ -152,7 +106,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: ${{env.wheel_name}} - path: ./dist/${{env.wheel_name}} + path: ./wheelhouse/${{env.wheel_name}} publish_package: name: Publish package diff --git a/pyproject.toml b/pyproject.toml index 808827c..0237616 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,2 @@ [build-system] -requires = ["setuptools", "wheel", "setuptools-cuda-cpp", "packaging", "pybind11"] \ No newline at end of file +requires = ["setuptools", "wheel", "setuptools-cuda-cpp @ git+https://github.com/nshepperd/setuptools-cuda-cpp", "packaging", "pybind11"] \ No newline at end of file diff --git a/setup.py b/setup.py index adfa464..2f0a5f4 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,6 @@ from setuptools import setup, find_packages from setuptools_cuda_cpp import CUDAExtension, BuildExtension, fix_dll -# from setuptools_cuda.inspections import find_cuda_home import pybind11 import subprocess @@ -53,14 +52,31 @@ def get_platform(): else: raise ValueError("Unsupported platform: {}".format(sys.platform)) - -def get_cuda_bare_metal_version(cuda_dir): +def locate_cuda(): + if 'sdist' in sys.argv: + return None + cuda_dir = os.environ.get("CUDA_HOME", None) + if cuda_dir is None: + if os.path.exists("/usr/local/cuda"): + cuda_dir = "/usr/local/cuda" + os.environ["CUDA_HOME"] = cuda_dir + elif os.path.exists("/opt/cuda"): + cuda_dir = "/opt/cuda" + os.environ["CUDA_HOME"] = cuda_dir + else: + raise RuntimeError("CUDA_HOME not set and no CUDA installation found") + return cuda_dir + + +def get_cuda_version(): + cuda_dir = locate_cuda() + if cuda_dir is None: + return "" raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version + version = output[release_idx].split(",")[0].split('.')[0] # should be 11 or 12 + return f'+cu{version}' def append_nvcc_threads(nvcc_extra_args): @@ -180,9 +196,9 @@ def get_package_version(): public_version = ast.literal_eval(version_match.group(1)) local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") if local_version: - return f"{public_version}+{local_version}" + return f"{public_version}+{local_version}{get_cuda_version()}" else: - return str(public_version) + return f"{public_version}{get_cuda_version()}" class NinjaBuildExtension(BuildExtension):