diff --git a/.github/workflows/build-ci.sh b/.github/workflows/build-ci.sh
index 2190804..cb85c67 100755
--- a/.github/workflows/build-ci.sh
+++ b/.github/workflows/build-ci.sh
@@ -2,8 +2,28 @@
 
 set -e
 
+function status {
+    echo -e "\033[1m$1\033[0m"
+}
+
+function info {
+    echo -e "\033[1;34m$1\033[0m"
+}
+
+function warning {
+    echo -e "\033[1;33m$1\033[0m"
+}
+
+function verbose_cmd {
+    if [ $verbose -eq 1 ]; then
+        "$@"
+    else
+        "$@" > /dev/null
+    fi
+}
+
 project_root="$( cd -- "$(dirname "$0")/../.." >/dev/null 2>&1 ; pwd -P )"
-echo "Project root: $project_root"
+status "Project root: $project_root"
 
 py_venv_path="$project_root/.venv"
 cinnamon_path="$project_root/cinnamon"
@@ -11,43 +31,119 @@
 llvm_path="$project_root/llvm"
 torch_mlir_path="$project_root/torch-mlir"
 upmem_path="$project_root/upmem"
 
+verbose=0
+reconfigure=0
+
 setup_python_venv=1
 checkout_and_build_llvm=1
 checkout_and_build_torch_mlir=1
 checkout_upmem=1
+build_cinnamon_wheel=1
+
 enable_cuda=0
 enable_roc=0
 
-if echo "$@" | grep -q "no-python-venv"; then
+# Section for configuring based on legacy environment variables
+###############################################################
+
+if [ -n "$LLVM_BUILD_DIR" ]; then
+    checkout_and_build_llvm=external
+    TORCH_MLIR_CMAKE_OPTIONS="$TORCH_MLIR_CMAKE_OPTIONS -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm"
+    CINNAMON_CMAKE_OPTIONS="$CINNAMON_CMAKE_OPTIONS -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm"
+
+    info "Using environment variable LLVM_BUILD_DIR for configuration"
+    info "Dependent targets will use '$LLVM_BUILD_DIR'"
+
+    if [ ! -d "$LLVM_BUILD_DIR" ]; then
+        warning "Directory '$LLVM_BUILD_DIR' does not exist"
+    fi
+fi
+
+if [ -n "$TORCH_MLIR_INSTALL_DIR" ]; then
+    checkout_and_build_torch_mlir=external
+    CINNAMON_CMAKE_OPTIONS="$CINNAMON_CMAKE_OPTIONS -DTORCH_MLIR_DIR=$TORCH_MLIR_INSTALL_DIR"
+
+    info "Using environment variable TORCH_MLIR_INSTALL_DIR for configuration"
+    info "Dependent targets will use '$TORCH_MLIR_INSTALL_DIR'"
+
+    if [ ! -d "$TORCH_MLIR_INSTALL_DIR" ]; then
+        warning "Directory '$TORCH_MLIR_INSTALL_DIR' does not exist"
+    fi
+fi
+
+if [ -n "$UPMEM_HOME" ]; then
+    checkout_upmem=external
+    CINNAMON_CMAKE_OPTIONS="$CINNAMON_CMAKE_OPTIONS -DUPMEM_DIR=$UPMEM_HOME"
+
+    info "Using environment variable UPMEM_HOME for configuration"
+    info "Dependent targets will use '$UPMEM_HOME'"
+
+    if [ ! -d "$UPMEM_HOME" ]; then
+        warning "Directory '$UPMEM_HOME' does not exist"
+    fi
+fi
+
+###############################################################
+
+if echo "$@" | grep -q -- "-verbose"; then
+    verbose=1
+else
+    info "Some steps will be run in quiet mode, use -verbose to see all output"
+fi
+
+if echo "$@" | grep -q -- "-reconfigure"; then
+    reconfigure=1
+fi
+
+if echo "$@" | grep -q -- "-no-python-venv"; then
     setup_python_venv=0
 fi
 
-if echo "$@" | grep -q "no-llvm"; then
+if echo "$@" | grep -q -- "-no-llvm"; then
    checkout_and_build_llvm=0
 fi
 
-if echo "$@" | grep -q "no-torch-mlir"; then
+if echo "$@" | grep -q -- "-no-torch-mlir"; then
    checkout_and_build_torch_mlir=0
 fi
 
-if echo "$@" | grep -q "no-upmem"; then
+if echo "$@" | grep -q -- "-no-upmem"; then
    checkout_upmem=0
 fi
 
-if echo "$@" | grep -q "enable-cuda"; then
+if echo "$@" | grep -q -- "-no-cinnamon-wheel"; then
+    build_cinnamon_wheel=0
+fi
+
+if echo "$@" | grep -q -- "-enable-cuda"; then
    enable_cuda=1
 fi
 
-if echo "$@" | grep -q "enable-roc"; then
+if echo "$@" | grep -q -- "-enable-roc"; then
    enable_roc=1
 fi
 
 if [[ $setup_python_venv -eq 1 ]]; then
+    # NOTE: This is a temporary workaround as some distros ship python3.13 which does not yet provide a torch package
+    supported_python_executable=python3
+    if command -v python3.12 &> /dev/null; then
+        supported_python_executable=python3.12
+    fi
+
+    reconfigure_python_venv=0
     if [ ! -d "$py_venv_path" ]; then
-        python3 -m venv "$py_venv_path"
+        status "Creating Python venv"
+        $supported_python_executable -m venv "$py_venv_path"
         source "$py_venv_path/bin/activate"
+        reconfigure_python_venv=1
+    else
+        status "Enabling Python venv"
+        source "$py_venv_path/bin/activate"
+    fi
 
+    if [ $reconfigure -eq 1 ] || [ $reconfigure_python_venv -eq 1 ]; then
+        status "Installing Python dependencies"
         # https://pytorch.org/get-started/locally/
         if [[ $enable_cuda -eq 1 ]]; then
             torch_source=https://download.pytorch.org/whl/cu124
@@ -57,150 +153,196 @@ if [[ $setup_python_venv -eq 1 ]]; then
             torch_source=https://download.pytorch.org/whl/cpu
         fi
 
-        pip install --upgrade pip
-        pip install ninja-build
-        pip install torch torchvision torchaudio --index-url $torch_source
-        pip install pybind11
-    else
-        source "$py_venv_path/bin/activate"
+        verbose_cmd pip install --upgrade pip
+        verbose_cmd pip install torch torchvision torchaudio --index-url $torch_source
+        verbose_cmd pip install pybind11
+        verbose_cmd pip install build
     fi
-else
-    echo "Skipping Python venv setup"
-    echo "Make sure to have a correct Python environment set up"
+elif [[ $setup_python_venv -eq 0 ]]; then
+    warning "Skipping Python venv setup"
+    warning "Make sure to have a correct Python environment set up"
 fi
 
 if [[ $checkout_and_build_llvm -eq 1 ]]; then
+    reconfigure_llvm=0
    if [ ! -d "$llvm_path" ]; then
-        git clone https://github.com/oowekyala/llvm-project "$llvm_path"
+        status "Checking out LLVM"
+        git clone --branch llvmorg-19.1.3 --depth 1 https://github.com/llvm/llvm-project "$llvm_path"
+
+        status "Applying patches to LLVM"
         cd "$llvm_path"
+        patch_dir="$project_root/patches/llvm"
+        for patch in $(ls $patch_dir); do
+            git apply $patch_dir/$patch
+        done
+
+        reconfigure_llvm=1
+    fi
+
+    cd "$llvm_path"
 
-        git checkout cinnamon-llvm
-        cmake -S llvm -B build -GNinja \
+    if [ $reconfigure -eq 1 ] || [ $reconfigure_llvm -eq 1 ]; then
+        status "Configuring LLVM"
+        cmake -S llvm -B build \
+            -Wno-dev \
            -DLLVM_ENABLE_PROJECTS="mlir;llvm;clang" \
            -DLLVM_TARGETS_TO_BUILD="host" \
            -DLLVM_ENABLE_ASSERTIONS=ON \
            -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-            -DLLVM_BUILD_TOOLS=OFF \
            -DCMAKE_BUILD_TYPE=Release \
            -DBUILD_SHARED_LIBS=ON \
+            -DLLVM_INCLUDE_TESTS=OFF \
+            -DLLVM_INCLUDE_BENCHMARKS=OFF \
            -DLLVM_OPTIMIZED_TABLEGEN=ON \
            -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD=SPIRV \
            $LLVM_CMAKE_OPTIONS
    fi
 
-    cd "$llvm_path"
-    git pull
+    status "Building LLVM"
    cmake --build build --target all llc opt
 
    export PATH=$llvm_path/build/bin:$PATH
-else
-    echo "Skipping LLVM checkout and build"
-fi
-
-llvm_dir=${LLVM_BUILD_DIR:-"$llvm_path/build"}
-echo "Using LLVM installation in $llvm_dir"
-if [[ ! -d "$llvm_dir" ]]; then
-    echo "No LLVM installation found"
-    exit 1
+elif [[ $checkout_and_build_llvm -eq 0 ]]; then
+    warning "Skipping LLVM checkout and build"
+    warning "The following steps will need LLVM_DIR and MLIR_DIR to be set in their respective _CMAKE_OPTIONS"
 fi
 
 if [[ $checkout_and_build_torch_mlir -eq 1 ]]; then
+    reconfigure_torch_mlir=0
    if [ ! -d "$torch_mlir_path" ]; then
-        git clone https://github.com/llvm/torch-mlir "$torch_mlir_path"
+        status "Checking out Torch-MLIR"
+        git clone --depth 1 https://github.com/llvm/torch-mlir "$torch_mlir_path"
+
+        cd "$torch_mlir_path"
+        git fetch --depth 1 origin 98e08023bbf71e00ab81e980eac9f7c96f1f24b4
+        git checkout 98e08023bbf71e00ab81e980eac9f7c96f1f24b4
+
+        reconfigure_torch_mlir=1
    fi
+
    cd "$torch_mlir_path"
 
-    if [ ! -d "build" ]; then
-        mkdir build
-        git checkout snapshot-20240127.1096
-        cmake -S . -B build -GNinja \
-            -DLLVM_DIR="$llvm_dir/lib/cmake/llvm" \
-            -DMLIR_DIR="$llvm_dir/lib/cmake/mlir" \
+    if [ $reconfigure -eq 1 ] || [ $reconfigure_torch_mlir -eq 1 ]; then
+        status "Configuring Torch-MLIR"
+        dependency_paths=""
+
+        if [[ $setup_python_venv -eq 1 ]]; then
+            dependency_paths="$dependency_paths -DPython3_FIND_VIRTUALENV=ONLY"
+        fi
+
+        if [[ $checkout_and_build_llvm -eq 1 ]]; then
+            dependency_paths="$dependency_paths -DLLVM_DIR=$llvm_path/build/lib/cmake/llvm"
+            dependency_paths="$dependency_paths -DMLIR_DIR=$llvm_path/build/lib/cmake/mlir"
+        fi
+
+        cmake -S . -B build \
+            $dependency_paths \
+            -Wno-dev \
            -DCMAKE_BUILD_TYPE=Release \
            -DTORCH_MLIR_OUT_OF_TREE_BUILD=ON \
            -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
            $TORCH_MLIR_CMAKE_OPTIONS
    fi
 
+    status "Building Torch-MLIR"
    cmake --build build --target all TorchMLIRPythonModules
-    cmake --install build --prefix install
+    verbose_cmd cmake --install build --prefix install
 
    if [[ $setup_python_venv -eq 1 ]]; then
+        status "Building and installing Torch-MLIR Python package"
        python_package_dir=build/tools/torch-mlir/python_packages/torch_mlir
        python_package_rel_build_dir=../../../python_packages/torch_mlir
 
        mkdir -p $(dirname $python_package_dir)
        ln -s "$python_package_rel_build_dir" "$python_package_dir" 2> /dev/null || true
 
-        TORCH_MLIR_CMAKE_BUILD_DIR_ALREADY_BUILT=1 TORCH_MLIR_CMAKE_BUILD_DIR=build python setup.py build install
+        TORCH_MLIR_CMAKE_ALREADY_BUILT=1 TORCH_MLIR_CMAKE_BUILD_DIR=build PYTHONWARNINGS=ignore verbose_cmd python setup.py build install
+    elif [[ $setup_python_venv -eq 0 ]]; then
+        warning "Skipping Torch-MLIR Python package build"
+        warning "Make sure to have a correct Python environment set up"
    fi
-else
-    echo "Skipping Torch-MLIR checkout and build"
-fi
-
-torch_mlir_dir=${TORCH_MLIR_INSTALL_DIR:-"$torch_mlir_path/install"}
-echo "Using torch-mlir installation in $torch_mlir_dir"
-if [[ ! -d "$torch_mlir_dir" ]]; then
-    echo "(warning) No torch-mlir installation found, project will be built without torch-mlir support"
+elif [[ $checkout_and_build_torch_mlir -eq 0 ]]; then
+    warning "Skipping Torch-MLIR checkout and build"
+    warning "The following steps will need TORCH_MLIR_DIR to be set in their respective _CMAKE_OPTIONS"
 fi
 
 if [[ $checkout_upmem -eq 1 ]]; then
    if [ ! -d "$upmem_path" ]; then
+        status "Downloading UpMem SDK"
        upmem_archive="upmem.tar.gz"
        curl http://sdk-releases.upmem.com/2024.1.0/ubuntu_22.04/upmem-2024.1.0-Linux-x86_64.tar.gz --output "$upmem_archive"
        mkdir "$upmem_path"
        tar xf "$upmem_archive" -C "$upmem_path" --strip-components=1
        rm "$upmem_archive"
    fi
-else
-    echo "Skipping UpMem checkout"
-fi
-
-upmem_dir=${UPMEM_HOME:-"$upmem_path"}
-echo "Using UPMEM installation in $upmem_dir"
-if [[ ! -d "$upmem_dir" ]]; then
-    echo "(warning) No UPMEM installation found, project will be built without UPMEM support"
+elif [[ $checkout_upmem -eq 0 ]]; then
+    warning "Skipping UpMem checkout"
+    warning "The following steps will need UPMEM_DIR to be set in their respective _CMAKE_OPTIONS"
 fi
 
 cd "$cinnamon_path"
 
-if [ ! -d "build" ]; then
+if [ ! -d "build" ] || [ $reconfigure -eq 1 ]; then
+    status "Configuring Cinnamon"
+    ln -s "$project_root/LICENSE" "$cinnamon_path/python/" 2>/dev/null || true
+
    dependency_paths=""
-    if [[ -d "$torch_mlir_dir" ]]; then
-        dependency_paths="$dependency_paths -DTORCH_MLIR_DIR=$torch_mlir_dir"
+    if [[ $checkout_and_build_llvm -eq 1 ]]; then
+        dependency_paths="$dependency_paths -DLLVM_DIR=$llvm_path/build/lib/cmake/llvm"
+        dependency_paths="$dependency_paths -DMLIR_DIR=$llvm_path/build/lib/cmake/mlir"
    fi
 
-    if [[ -d "$upmem_dir" ]]; then
-        dependency_paths="$dependency_paths -DUPMEM_DIR=$upmem_dir"
+    if [[ $checkout_and_build_torch_mlir -eq 1 ]]; then
+        dependency_paths="$dependency_paths -DTORCH_MLIR_DIR=$torch_mlir_path/install"
    fi
 
-    cmake -S . -B "build" -GNinja \
+    if [[ $checkout_upmem -eq 1 ]]; then
+        dependency_paths="$dependency_paths -DUPMEM_DIR=$upmem_path"
+    fi
+
+    cmake -S . -B "build" \
        -DCMAKE_BUILD_TYPE=RelWithDebInfo \
        $dependency_paths \
-        -DLLVM_DIR="$llvm_dir/lib/cmake/llvm" \
-        -DMLIR_DIR="$llvm_dir/lib/cmake/mlir" \
-        -DTORCH_MLIR_DIR="$torch_mlir_dir" \
        -DCINM_BUILD_GPU_SUPPORT=ON \
        -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
        $CINNAMON_CMAKE_OPTIONS
 fi
 
+status "Building Cinnamon"
 cmake --build build --target all
 
-if [[ $setup_python_venv -eq 1 ]]; then
+if [[ $setup_python_venv -eq 1 ]] && [[ -n "$llvm_path" ]] && [[ -n "$torch_mlir_path" ]]; then
+    status "Building Cinnamon Python package"
    site_packages_dir="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
 
-    cinnamon_python_package_dir_src="$project_root/cinnamon/python/cinnamon"
-    cinnamon_binaries_dir_src="$project_root/cinnamon/build/bin"
+    cinnamon_python_package_dir_src="$project_root/cinnamon/python/src/cinnamon"
 
    cinnamon_python_package_dir_dest="$site_packages_dir/cinnamon"
-    cinnamon_binaries_dir_dest="$site_packages_dir/cinnamon/_resources"
+    cinnamon_python_package_resource_dir="$cinnamon_python_package_dir_dest/_resources"
+
+    cinnamon_python_resources=""
+
+    cinnamon_python_resources="$cinnamon_python_resources $cinnamon_path/build/bin/cinm-opt"
+    cinnamon_python_resources="$cinnamon_python_resources $cinnamon_path/build/lib/libMemristorDialectRuntime.so"
+
+    cinnamon_python_resources="$cinnamon_python_resources $torch_mlir_path/build/bin/torch-mlir-opt"
+
+    cinnamon_python_resources="$cinnamon_python_resources $llvm_path/build/bin/mlir-translate"
+    cinnamon_python_resources="$cinnamon_python_resources $llvm_path/build/bin/clang"
 
    if [ ! -d "$cinnamon_python_package_dir_dest" ]; then
        ln -s "$cinnamon_python_package_dir_src" "$cinnamon_python_package_dir_dest"
    fi
 
-    if [ ! -d "$cinnamon_binaries_dir_dest" ]; then
-        ln -s "$cinnamon_binaries_dir_src" "$cinnamon_binaries_dir_dest"
+    mkdir -p "$cinnamon_python_package_resource_dir" || true
+
+    for resource in $cinnamon_python_resources; do
+        ln -s "$resource" "$cinnamon_python_package_resource_dir" 2>/dev/null || true
+    done
+
+    if [[ $build_cinnamon_wheel -eq 1 ]]; then
+        cd "$cinnamon_path/python"
+        PYTHONWARNINGS=ignore verbose_cmd python -m build
    fi
+elif [[ $setup_python_venv -eq 0 ]]; then
+    warning "Skipping Cinnamon Python package build"
+    warning "Make sure to have a correct Python environment set up"
 fi
\ No newline at end of file
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 75fcf9d..d898a94 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -3,6 +3,8 @@ run-name: 'Build and Test: ${{ github.event.head_commit.message }}'
 on:
   workflow_dispatch:
   push:
+  pull_request:
+    types: [opened, reopened]
 jobs:
   main:
     name: Build and test
@@ -32,7 +34,7 @@ jobs:
           key: cinnamon-dependencies-${{ runner.os }}
 
       - name: Build
-        run: .github/workflows/build-ci.sh
+        run: .github/workflows/build-ci.sh -reconfigure
 
      - name: Test
        working-directory: cinnamon/build
diff --git a/.gitignore b/.gitignore
index c0658df..7c9d959 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,6 @@
 .vscode
 .directory
 .venv
-llvm
-torch-mlir
-upmem
\ No newline at end of file
+/llvm
+/torch-mlir
+/upmem
\ No newline at end of file
diff --git a/build.sh b/build.sh
deleted file mode 100755
index cc3f08e..0000000
--- a/build.sh
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/bin/bash
-
-project_root="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
-llvm_path="$project_root/llvm"
-cinnamon_path="$project_root/cinnamon"
-
-export PATH=$llvm_path/build/bin:$PATH
-
-if [[ $1 != "no-llvm" ]]; then
-    if [ ! -d "$llvm_path" ]; then
-        git clone https://github.com/oowekyala/llvm-project "$llvm_path"
-
-        cd "$llvm_path"
-
-        git checkout cinnamon-llvm
-        cmake -S llvm -B build \
-            -DLLVM_ENABLE_PROJECTS="mlir;llvm;clang" \
-            -DLLVM_TARGETS_TO_BUILD="host" \
-            -DLLVM_ENABLE_ASSERTIONS=ON \
-            -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \
-            -DLLVM_BUILD_TOOLS=OFF \
-            -DCMAKE_BUILD_TYPE=Release \
-            -DBUILD_SHARED_LIBS=ON \
-            -DLLVM_OPTIMIZED_TABLEGEN=ON
-    fi
-
-    cd "$llvm_path"
-    git pull
-    cmake --build build --target all llc opt
-fi
-
-cd "$cinnamon_path"
-
-if [ ! -d "build" ]; then
-    cmake -S . -B "build" \
-        -DCMAKE_BUILD_TYPE=RelWithDebInfo \
-        -DLLVM_DIR="$llvm_path"/build/lib/cmake/llvm \
-        -DMLIR_DIR="$llvm_path"/build/lib/cmake/mlir \
-        -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-        $CINNAMON_CMAKE_OPTIONS
-fi
-
-cmake --build build --target all
\ No newline at end of file
diff --git a/cinnamon/include/cinm-mlir/Dialect/Cinm/Interfaces/TilingInterface.h b/cinnamon/include/cinm-mlir/Dialect/Cinm/Interfaces/TilingInterface.h
index 38b28b1..8ab2a73 100644
--- a/cinnamon/include/cinm-mlir/Dialect/Cinm/Interfaces/TilingInterface.h
+++ b/cinnamon/include/cinm-mlir/Dialect/Cinm/Interfaces/TilingInterface.h
@@ -84,7 +84,7 @@ Value createArithIntOrFloatOp(OpBuilder &builder, Location loc, Value a,
                               Value b) {
   assert(a.getType() == b.getType() && "Mismatched type");
   assert(a.getType().isIntOrIndexOrFloat() && "Expected scalar type");
-  if (a.getType().isa()) {
+  if (isa(a.getType())) {
     return builder.create(loc, a, b);
   } else {
     return builder.create(loc, a, b);
diff --git a/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMOps.td b/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMOps.td
index bb3124d..87f6d7b 100644
--- a/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMOps.td
+++ b/cinnamon/include/cinm-mlir/Dialect/UPMEM/IR/UPMEMOps.td
@@ -33,7 +33,7 @@ include "mlir/IR/RegionKindInterface.td"
 
 class UPMEM_IndexOp traits = []> :
-    UPMEM_Op])>,
+    Pure, DeclareOpInterfaceMethods])>,
     Results<(outs Index)> {
   let assemblyFormat = "attr-dict";
 }
@@ -128,7 +128,7 @@ def UPMEM_PrivateWRAMAllocOp : UPMEM_Op<"pwram_alloc", [
 
 def UPMEM_LaunchOp : UPMEM_Op<"launch", [
     AutomaticAllocationScope, AttrSizedOperandSegments, UPMEM_AsyncOpInterface,
     IsolatedFromAbove,
-    DeclareOpInterfaceMethods]> {
+    DeclareOpInterfaceMethods]> {
   let summary = "UPMEM kernel launch operation";
 
   let arguments = (ins
diff --git a/cinnamon/lib/Conversion/CimToMemristor/CimToMemristor.cpp b/cinnamon/lib/Conversion/CimToMemristor/CimToMemristor.cpp
index 313e3a5..18cfa49 100644
--- a/cinnamon/lib/Conversion/CimToMemristor/CimToMemristor.cpp
+++ b/cinnamon/lib/Conversion/CimToMemristor/CimToMemristor.cpp
@@ -40,7 +40,7 @@ struct ConvertCimOpToMemristor : OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     auto tileId = op.getOperand(0);
 
-    auto resultShape = op.getResult().getType().template cast();
+    auto resultShape = cast(op.getResult().getType());
 
     auto resultAllocOp = rewriter.create(
         op.getLoc(),
@@ -49,7 +49,7 @@ struct ConvertCimOpToMemristor : OpConversionPattern {
         ValueRange{});
 
     auto createBufferizeOp = [&](Value value) {
-      auto shapedType = value.getType().cast();
+      auto shapedType = cast(value.getType());
       return rewriter.create(
           op.getLoc(),
           MemRefType::get(shapedType.getShape(), shapedType.getElementType()),
diff --git a/cinnamon/lib/Conversion/CinmToCim/CinmToCim.cpp b/cinnamon/lib/Conversion/CinmToCim/CinmToCim.cpp
index 1367183..73c19d6 100644
--- a/cinnamon/lib/Conversion/CinmToCim/CinmToCim.cpp
+++ b/cinnamon/lib/Conversion/CinmToCim/CinmToCim.cpp
@@ -32,7 +32,7 @@ namespace {
 // Creates the specified type for a value with correct shape and element type
 // Condition: The value must be shaped type
 template  static T getShapedType(Value value) {
-  auto shapedType = value.getType().cast();
+  auto shapedType = cast(value.getType());
   return T::get(shapedType.getShape(), shapedType.getElementType());
 }
diff --git a/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp b/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
index accbf6b..ee1fc2e 100644
--- a/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
+++ b/cinnamon/lib/Conversion/CinmToCnm/CinmToCnm.cpp
@@ -298,14 +298,14 @@ LogicalResult convertInputIntoAlloc(Location loc, Value &inputBuf,
 
   // For each input of the reduce, we need to
   // convert single element to tensor<1xelementTy>
-  if (!inputBuf.getType().dyn_cast()) {
+  if (!isa(inputBuf.getType())) {
     inputBuf = rewriter.create(
         RankedTensorType::get(SmallVector(wgTy.getShape().size(), 1),
                               inputBuf.getType()),
         ValueRange{inputBuf});
   }
 
-  auto inputType = inputBuf.getType().cast();
+  auto inputType = cast(inputBuf.getType());
   llvm::SmallVector shapeOfBuffer;
   std::optional> reshapeInto;
@@ -318,9 +318,9 @@ LogicalResult convertInputIntoAlloc(Location loc, Value &inputBuf,
     return failure();
 
   if (reshapeInto) {
-    inputBuf =
-        cinm::reshapeStatic(rewriter, rewriter.getLoc(), inputBuf,
-                            inputType.cast(), *reshapeInto);
+    inputBuf = cinm::reshapeStatic(rewriter, rewriter.getLoc(), inputBuf,
+                                   cast(inputType),
+                                   *reshapeInto);
   }
 
   // Allocate a cinm buffer
@@ -350,7 +350,7 @@ cnm::LaunchOp createLaunchOp(
   auto &launchBlock = launchOp.getBody().emplaceBlock();
   // arguments are memrefs with same shape as inputs
   for (auto input : launchOp.getParams()) {
-    if (auto inputTy = input.getType().dyn_cast()) {
+    if (auto inputTy = dyn_cast(input.getType())) {
       auto mappedTy =
           MemRefType::get(inputTy.getShape(), inputTy.getElementType());
       launchBlock.addArgument(mappedTy, input.getLoc());
@@ -428,8 +428,8 @@ LogicalResult convertCinmToCnm(
       auto res = builder.create(alloc, workgroup, map, outBuf);
       auto shapedBack = cinm::reshapeStatic(
           builder, builder.getLoc(),
-          res.getOutput().cast>(),
-          result.getType().cast().getShape());
+          cast>(res.getOutput()),
+          cast(result.getType()).getShape());
       resultValues.push_back(shapedBack);
     }
@@ -514,7 +514,7 @@ struct ConvertElementWiseToCnm : public OpConversionPattern {
                     ValueRange outputs) {
         SmallVector affineMaps;
         for (const auto &i : inputs) {
-          MemRefType t = i.getType().cast();
+          MemRefType t = cast(i.getType());
           affineMaps.push_back(AffineMap::getMultiDimIdentityMap(
               t.getRank(), op.getContext()));
 
@@ -541,7 +541,7 @@ struct ConvertElementWiseToCnm : public OpConversionPattern {
         Value rhs = IsScalarOp ? inputs[1u] : args[1u];
         if constexpr (IsScalarOp) {
           if (const auto memrefType =
-                  rhs.getType().dyn_cast()) {
+                  dyn_cast(rhs.getType())) {
             const Value zero = builder.create(loc, 0);
             rhs = builder.create(
@@ -622,7 +622,7 @@ struct ConvertCinmGemmToCnm : public OpConversionPattern {
   using OpConversionPattern::OpConversionPattern;
 
   static Value transpose(ImplicitLocOpBuilder &builder, Value tensor) {
-    auto inTy = tensor.getType().cast();
+    auto inTy = cast(tensor.getType());
     auto shape = inTy.getShape();
     SmallVector newShape{shape[1], shape[0]};
     SmallVector perms{1, 0};
@@ -785,10 +785,8 @@ struct ConvertCinmReduceToCnm : public OpConversionPattern {
         op.getResult().getType(),
         builder.getZeroAttr(op.getResult().getType()));
 
-    const bool isFloatOp = op.getType()
-                               .cast()
-                               .getElementType()
-                               .dyn_cast() != nullptr;
+    const bool isFloatOp = isa(
+        cast(op.getType()).getElementType());
 
     llvm::SmallVector newResults;
     if (convertCinmToCnm(
diff --git a/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp b/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp
index cb797e5..41977e3 100644
--- a/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp
+++ b/cinnamon/lib/Conversion/CnmToGPU/CnmToGPU.cpp
@@ -51,11 +51,11 @@ MemRefType convertCnmBufferToMemRefType(cnm::BufferType bufferType) {
 void convertLaunchParameter(ConversionPatternRewriter &rewriter, Location loc,
                             Value buffer, ValueRange threadIds,
                             BlockArgument arg) {
-  if (!buffer.getType().dyn_cast()) {
+  const auto bufferType = dyn_cast(buffer.getType());
+
+  if (!bufferType)
     return;
-  }
 
-  const BufferType bufferType = buffer.getType().dyn_cast();
   const MemRefType memrefType = convertCnmBufferToMemRefType(bufferType);
 
   const Value source = createOrFoldUnrealizedConversionCast(
@@ -122,8 +122,8 @@ struct ConvertCnmScatterToGPU : public OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     const WorkgroupType workgroupType = op.getWg().getType();
     const ArrayRef workgroupShape = workgroupType.getShape();
-    const cnm::BufferType bufferType =
-        op.getOperandTypes()[1].dyn_cast();
+    const auto bufferType =
+        dyn_cast(op.getOperand(1).getType());
 
     Value src = rewriter.getRemappedValue(op.getOperand(0));
     Value dst = rewriter.getRemappedValue(op.getOperand(1));
@@ -155,8 +155,8 @@ struct ConvertCnmGatherToGPU : public OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     const WorkgroupType workgroupType = op.getWg().getType();
     const ArrayRef workgroupShape = workgroupType.getShape();
-    const cnm::BufferType bufferType =
-        op.getOperandTypes()[0].dyn_cast();
+    const auto bufferType =
+        dyn_cast(op.getOperand(0).getType());
 
     Value src = rewriter.getRemappedValue(op.getOperand(0));
     src = createOrFoldUnrealizedConversionCast(
@@ -282,7 +282,6 @@ struct ConvertCnmToGPUPass
 
     RewritePatternSet patterns(&getContext());
     populateCnmToGPUConversionPatterns(patterns, &getContext());
-    populateReconcileUnrealizedCastsPatterns(patterns);
 
     ConversionTarget target(getContext());
     target.addIllegalDialect();
diff --git a/cinnamon/lib/Conversion/CnmToUPMEM/CnmToUPMEM.cpp b/cinnamon/lib/Conversion/CnmToUPMEM/CnmToUPMEM.cpp
index a40a897..0a7e51c 100644
--- a/cinnamon/lib/Conversion/CnmToUPMEM/CnmToUPMEM.cpp
+++ b/cinnamon/lib/Conversion/CnmToUPMEM/CnmToUPMEM.cpp
@@ -24,6 +24,7 @@
 #include 
 #include 
 #include 
+#include 
 
 #define GEN_PASS_DEF_CONVERTCNMTOUPMEMPASS
 #include "cinm-mlir/Conversion/CnmPasses.h.inc"
@@ -40,8 +41,8 @@ template  T reduceMul(ArrayRef arr) {
 }
 
 MemRefType convertTensorToMemref(ShapedType ty) {
-  if (ty.isa())
-    return ty.cast();
+  if (isa(ty))
+    return cast(ty);
 
   return MemRefType::get(ty.getShape(), ty.getElementType());
 }
@@ -126,7 +127,7 @@ struct ConvertCnmGatherToUPMEM : public OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
     Value outputBuf = adaptor.getOutputBuf();
-    bool isBufferized = op.getOutputBuf().getType().isa();
+    bool isBufferized = isa(op.getOutputBuf().getType());
     if (!isBufferized) {
       outputBuf = rewriter.create(
           op->getLoc(), convertTensorToMemref(op.getOutputBuf().getType()));
@@ -165,7 +166,7 @@ struct ConvertCnmLaunchToUPMEM : public OpConversionPattern {
     const size_t availableWRAM = 32 * 1024;
     size_t requiredWRAM = 0;
     for (Value buffer : op.getParams()) {
-      const BufferType bufferType = buffer.getType().cast();
+      const BufferType bufferType = cast(buffer.getType());
       const size_t elementSize =
           bufferType.getElementType().getIntOrFloatBitWidth() / 8;
       requiredWRAM += reduceMul(bufferType.getShape()) * elementSize;
@@ -207,7 +208,7 @@ struct ConvertCnmLaunchToUPMEM : public OpConversionPattern {
         continue;
       }
 
-      const BufferType bufferType = buffer.getType().cast();
+      const BufferType bufferType = cast(buffer.getType());
       const size_t chunkSize = reduceMul(bufferType.getShape());
       const size_t memoryPerTasklet = chunksPerTasklet * chunkSize;
       const size_t memoryPerDPU = wgShape[2] * memoryPerTasklet;
@@ -355,7 +356,6 @@ struct ConvertCnmToUPMEMPass
 
     RewritePatternSet patterns(&getContext());
     populateCnmToUPMEMConversionPatterns(converter, patterns);
-    populateReconcileUnrealizedCastsPatterns(patterns);
     populateFinalBufferizationPatterns(patterns);
 
     ConversionTarget target(getContext());
diff --git a/cinnamon/lib/Conversion/CommonPatterns.cpp b/cinnamon/lib/Conversion/CommonPatterns.cpp
index 40bf147..8f628b9 100644
--- a/cinnamon/lib/Conversion/CommonPatterns.cpp
+++ b/cinnamon/lib/Conversion/CommonPatterns.cpp
@@ -94,7 +94,7 @@ LogicalResult ConvertCnmSetZeroToAffine::matchAndRewrite(
     cnm::SetZeroOp op, OpAdaptor,
     ConversionPatternRewriter &rewriter) const {
   const Value dst = rewriter.getRemappedValue(op.getOperand());
-  const MemRefType type = dst.getType().cast();
+  const MemRefType type = cast(dst.getType());
 
   const SmallVector loopSizes{type.getShape()};
   const SmallVector loopSteps(loopSizes.size(), 1);
@@ -125,8 +125,8 @@ SmallVector createAffineApply(OpBuilder &builder, Location loc,
 void createMemrefSubviewCopy(OpBuilder &builder, Location loc, Value src,
                              Value dst, ArrayRef sliceShape,
                              ValueRange srcOffsets, ValueRange dstOffsets) {
-  MemRefType srcType = src.getType().cast();
-  MemRefType dstType = dst.getType().cast();
+  MemRefType srcType = cast(src.getType());
+  MemRefType dstType = cast(dst.getType());
 
   SmallVector srcStaticOffsets(srcType.getRank(), 0);
   SmallVector srcStaticSizes{srcType.getShape()};
diff --git a/cinnamon/lib/Conversion/TorchToCinm/TorchToCinm.cpp b/cinnamon/lib/Conversion/TorchToCinm/TorchToCinm.cpp
index f4b7e96..8e9c477 100644
--- a/cinnamon/lib/Conversion/TorchToCinm/TorchToCinm.cpp
+++ b/cinnamon/lib/Conversion/TorchToCinm/TorchToCinm.cpp
@@ -49,20 +49,20 @@ struct ConvertTorchTensorOpToCinm : OpConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
 
     auto lhs = op.getOperand(0);
-    auto lhsType = lhs.getType().template cast();
+    auto lhsType = cast(lhs.getType());
     auto lhsConversionOp = rewriter.create(
         op.getLoc(), lhsType.toBuiltinTensor(), lhs);
 
     auto rhs = op.getOperand(1);
-    auto rhsType = rhs.getType().template cast();
+    auto rhsType = cast(rhs.getType());
     auto rhsConversionOp = rewriter.create(
         op.getLoc(), rhsType.toBuiltinTensor(), rhs);
 
     auto result = op.getResult();
     auto resultType =
-        result.getType().template cast();
+        cast(result.getType());
 
     auto cinmComputeOp = rewriter.create(
         op.getLoc(), resultType.toBuiltinTensor());
diff --git a/cinnamon/lib/Conversion/UPMEMToLLVM/UPMEMToLLVM.cpp b/cinnamon/lib/Conversion/UPMEMToLLVM/UPMEMToLLVM.cpp
index 6f8f15e..6a34cfa 100644
--- a/cinnamon/lib/Conversion/UPMEMToLLVM/UPMEMToLLVM.cpp
+++ b/cinnamon/lib/Conversion/UPMEMToLLVM/UPMEMToLLVM.cpp
@@ -43,6 +43,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 
 namespace mlir {
@@ -147,7 +148,7 @@ static FailureOr linearizeAffineMap(AffineMap map,
   }
 
   auto layoutMap = bufferTy.getLayout().getAffineMap();
-  if (bufferTy.getLayout().isa()) {
+  if (isa(bufferTy.getLayout())) {
     // Replace offsets with 0 to delete the symbols.
     // Offset is calculated outside of the affine map.
     layoutMap = layoutMap.replaceDimsAndSymbols(
@@ -314,8 +315,8 @@ outlineAffineMap(ImplicitLocOpBuilder &rewriter,
   // to find it later
   affineMapFun->setAttr("upmem.generated_from", AffineMapAttr::get(*linearMap));
 
-  rewriter = ImplicitLocOpBuilder::atBlockBegin(rewriter.getLoc(),
-                                                affineMapFun.addEntryBlock());
+  rewriter = ImplicitLocOpBuilder::atBlockBegin(
+      rewriter.getLoc(), affineMapFun.addEntryBlock(rewriter));
   Value arg = affineMapFun.getArgument(0);
   // affine expects to deal with index type only
   arg = createOrFoldUnrealizedConversionCast(rewriter.getLoc(), rewriter,
@@ -375,7 +376,7 @@ static LogicalResult lowerScatterOrGather(Op op, typename Op::Adaptor adaptor,
   }
 
   Value bareHostBuf = adaptor.getHostBuffer();
-  if (adaptor.getHostBuffer().getType().template isa()) {
+  if (isa(adaptor.getHostBuffer().getType())) {
     // Here we compute the pointer to the start of the memref
     // converted memref
     Value basePtr =
@@ -556,8 +557,8 @@ struct ConvertUPMEMToLLVMPass
     const auto addUnrealizedCast = [](OpBuilder &builder, Type type,
                                       ValueRange inputs,
                                       Location loc) -> Value {
-      // if (type.isa() && inputs.size() == 1 &&
-      //     inputs[0].getType().isa()) {
+      // if (isa(type) && inputs.size() == 1 &&
+      //     isa(inputs[0].getType())) {
       //   return builder.create(loc, type, inputs)
       //       .getResult();
       // }
@@ -570,7 +571,6 @@ struct ConvertUPMEMToLLVMPass
 
     RewritePatternSet patterns(&getContext());
     populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
     populateUPMEMToLLVMConversionPatterns(converter, patterns);
-    populateReconcileUnrealizedCastsPatterns(patterns);
 
     ConversionTarget target(getContext());
     target.addIllegalDialect();
diff --git a/cinnamon/lib/Dialect/Cim/IR/CimOps.cpp b/cinnamon/lib/Dialect/Cim/IR/CimOps.cpp
index b7d3b4d..a96999c 100644
--- a/cinnamon/lib/Dialect/Cim/IR/CimOps.cpp
+++ b/cinnamon/lib/Dialect/Cim/IR/CimOps.cpp
@@ -37,14 +37,15 @@ void AcquireDeviceOp::getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
   setNameFn(getResult(), "cim_dev");
 }
 
-void AcquireCrossbarOp::getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
+void AcquireCrossbarOp::getAsmResultNames(
+    ::mlir::OpAsmSetValueNameFn setNameFn) {
   setNameFn(getResult(), "cim_cbr");
 }
 
 ::mlir::LogicalResult GemmOp::verify() {
-  auto lhs = getLhs().getType().cast();
-  auto rhs = getRhs().getType().cast();
-  auto result = getResult().getType().cast();
+  auto lhs = cast(getLhs().getType());
+  auto rhs = cast(getRhs().getType());
+  auto result = cast(getResult().getType());
 
   if (lhs.getElementType() != rhs.getElementType())
     return emitOpError("lhs and rhs must have the same element type");
@@ -67,9 +68,9 @@
 }
 
 ::mlir::LogicalResult GemvOp::verify() {
-  auto lhs = getLhs().getType().cast();
-  auto rhs = getRhs().getType().cast();
-  auto result = getResult().getType().cast();
+  auto lhs = cast(getLhs().getType());
+  auto rhs = cast(getRhs().getType());
+  auto result = cast(getResult().getType());
 
   if (lhs.getElementType() != rhs.getElementType())
     return emitOpError("lhs and rhs must have the same element type");
diff --git a/cinnamon/lib/Dialect/Cim/Transforms/SchedulingPasses.cpp b/cinnamon/lib/Dialect/Cim/Transforms/SchedulingPasses.cpp
index 640f033..e4168de 100644
--- a/cinnamon/lib/Dialect/Cim/Transforms/SchedulingPasses.cpp
+++ b/cinnamon/lib/Dialect/Cim/Transforms/SchedulingPasses.cpp
@@ -26,211 +26,220 @@ namespace mlir::cim {
 
-  static bool isCimOp(Operation &op) {
-    return op.getName().getStringRef().starts_with("cim.op");
-  }
-
-  static void scheduleCimOpOnCrossbar(Operation &op, Value crossbar) {
-    op.getOperand(0).replaceUsesWithIf(
-        crossbar, [&](OpOperand &use) { return use.getOwner() == &op; });
-  }
-
-  static Operation *insertBarrierForCimOpResult(PatternRewriter &rewriter, Operation &cimOp) {
-    auto future = cimOp.getResult(0);
-
-    auto shapedType = future.getType().cast();
-    auto barrierOp = rewriter.create(
-        cimOp.getLoc(),
-        RankedTensorType::get(shapedType.getShape(), shapedType.getElementType()),
-        future);
-    future.replaceAllUsesExcept(barrierOp.getResult(), barrierOp);
-    return barrierOp;
-  }
-
-  static bool cimScheduleCheckDynamicallyLegal(Operation *op) {
-    auto acquireDeviceOp = dyn_cast(op);
-
-    if (!acquireDeviceOp)
-      return true;
-
-    return acquireDeviceOp.getIsFullyScheduled();
-  };
-
-  static ReleaseDeviceOp getReleaseDeviceOp(AcquireDeviceOp op) {
-    for (auto *user : op.getResult().getUsers())
-      if (auto releaseDeviceOp = dyn_cast(user))
-        return releaseDeviceOp;
-    return nullptr;
-  }
-
-  static ReleaseCrossbarOp getReleaseCrossbarOp(AcquireCrossbarOp op) {
-    for (auto *user : op.getResult().getUsers())
-      if (auto releaseCrossbarOp = dyn_cast(user))
-        return releaseCrossbarOp;
-    return nullptr;
-  }
-
-  static std::vector
-  getDependentAcquireCrossbarOps(AcquireDeviceOp op) {
-    std::vector acquireCrossbarOps;
-    for (auto *user : op.getResult().getUsers())
-      if (auto acquireCrossbarOp = dyn_cast(user))
-        acquireCrossbarOps.push_back(acquireCrossbarOp);
-    return acquireCrossbarOps;
-  }
-
-  static std::vector
-  getDependentCimOps(AcquireCrossbarOp op) {
-    std::vector ops;
-    for (auto *user : op.getResult().getUsers())
-      if (isCimOp(*user))
-        ops.push_back(user);
-    return ops;
-  }
-
-  static bool
-  hasUsersOutsideAcquiredBlock(Operation *op, Value crossbarId) {
-    for (auto *user : op->getUsers()) {
-      if (isCimOp(*user) && user->getOperand(0) == crossbarId)
-        continue;
-
-      return true;
-    }
-
-    return false;
-  }
-
-  static std::pair, std::vector>
-  prepareForScheduling(AcquireDeviceOp acquireDeviceOp, PatternRewriter &rewriter) {
-    auto releaseDeviceOp = getReleaseDeviceOp(acquireDeviceOp);
-    auto acquireCrossbarOps = getDependentAcquireCrossbarOps(acquireDeviceOp);
-
-    // save only one of the potentially multiple crossbar ids
-    auto savedCrossbarOp = acquireCrossbarOps.back();
-    acquireCrossbarOps.pop_back();
-
-    // delete the rest, rebind their uses to saved crossbar id
-    for (auto acquireCrossbarOp : acquireCrossbarOps) {
-      auto releaseCrossbarOp = getReleaseCrossbarOp(acquireCrossbarOp);
-      acquireCrossbarOp.getResult().replaceAllUsesWith(savedCrossbarOp.getResult());
-      acquireCrossbarOp.erase();
-      releaseCrossbarOp.erase();
-    }
-
-    std::unordered_set discoveredBarriers;
-
-    auto cimOps = getDependentCimOps(savedCrossbarOp);
-    for (auto *cimOp : cimOps) {
-      for (auto operand : cimOp->getOperands()) {
-        // check if operand is a tensor created by a cim.barrier operation
-        auto *definingOp = operand.getDefiningOp();
-        if (!llvm::isa(operand.getType()) || !definingOp || !llvm::isa(definingOp))
-          continue;
-
-        discoveredBarriers.insert(definingOp);
-      }
-
-      for (auto user : cimOp->getUsers()) {
-        if (llvm::isa(user))
-          discoveredBarriers.insert(user);
-      }
-    }
-
-    std::vector crossbarIds;
-    std::vector roots;
-
-    // add saved crossbar id to crossbarIds
-    crossbarIds.push_back(savedCrossbarOp.getResult());
-
-    // recreate acquire_crossbar operations, until the number of available crossbars is reached
-    while (crossbarIds.size() != acquireDeviceOp.getAvailableCrossbarCount()) {
-      rewriter.setInsertionPointAfter(acquireDeviceOp.getOperation());
-      auto acquireCrossbarOp =
-          rewriter.create(acquireDeviceOp->getLoc(), acquireDeviceOp.getResult());
-      crossbarIds.push_back(acquireCrossbarOp.getResult());
-
-      rewriter.setInsertionPoint(releaseDeviceOp);
-      rewriter.create(releaseDeviceOp->getLoc(), acquireCrossbarOp.getResult());
-    }
-
-    // find all roots for scheduling
-    for (auto *barrier : discoveredBarriers) {
-      if (hasUsersOutsideAcquiredBlock(barrier, savedCrossbarOp.getResult()))
-        roots.push_back(barrier->getOperand(0));
-
-      // replace operand with its cim.future
-      barrier->getResult(0).replaceAllUsesWith(barrier->getOperand(0));
-      barrier->erase();
-    }
-
-    return {crossbarIds, roots};
-  }
-
-  template class Scheduler>
-  struct CimSchedulePattern : public RewritePattern {
-    CimSchedulePattern(MLIRContext *context)
-        : RewritePattern(MatchAnyOpTypeTag{}, 1, context) {}
-
-    LogicalResult matchAndRewrite(Operation *op,
-                                  PatternRewriter &rewriter) const override {
-      auto acquireDeviceOp = dyn_cast(op);
+static bool isCimOp(Operation &op) {
+  return op.getName().getStringRef().starts_with("cim.op");
+}
+
+static void scheduleCimOpOnCrossbar(Operation &op, Value crossbar) {
+  op.getOperand(0).replaceUsesWithIf(
+      crossbar, [&](OpOperand &use) { return use.getOwner() == &op; });
+}
+
+static Operation *insertBarrierForCimOpResult(PatternRewriter &rewriter,
+                                              Operation &cimOp) {
+  auto future = cimOp.getResult(0);
+
+  auto shapedType = cast(future.getType());
+  auto barrierOp = rewriter.create(
+      cimOp.getLoc(),
+      RankedTensorType::get(shapedType.getShape(), shapedType.getElementType()),
+      future);
+  future.replaceAllUsesExcept(barrierOp.getResult(), barrierOp);
+  return barrierOp;
+}
+
+static bool cimScheduleCheckDynamicallyLegal(Operation *op) {
+  auto acquireDeviceOp = dyn_cast(op);
+
+  if (!acquireDeviceOp)
+    return true;
+
+  return acquireDeviceOp.getIsFullyScheduled();
+};
+
+static ReleaseDeviceOp getReleaseDeviceOp(AcquireDeviceOp op) {
+  for (auto *user : op.getResult().getUsers())
+    if (auto releaseDeviceOp = dyn_cast(user))
+      return releaseDeviceOp;
+  return nullptr;
+}
+
+static ReleaseCrossbarOp getReleaseCrossbarOp(AcquireCrossbarOp op) {
+  for (auto *user : op.getResult().getUsers())
+    if (auto releaseCrossbarOp = dyn_cast(user))
+      return releaseCrossbarOp;
+  return nullptr;
+}
+
+static std::vector
+getDependentAcquireCrossbarOps(AcquireDeviceOp op) {
+  std::vector acquireCrossbarOps;
+  for (auto *user : op.getResult().getUsers())
+    if (auto acquireCrossbarOp = dyn_cast(user))
+      acquireCrossbarOps.push_back(acquireCrossbarOp);
+  return acquireCrossbarOps;
+}
+
+static std::vector getDependentCimOps(AcquireCrossbarOp op) {
+  std::vector ops;
+  for (auto *user : op.getResult().getUsers())
+    if (isCimOp(*user))
+      ops.push_back(user);
+  return ops;
+}
+
+static bool hasUsersOutsideAcquiredBlock(Operation *op, Value crossbarId) {
+  for (auto *user : op->getUsers()) {
+    if (isCimOp(*user) && user->getOperand(0) == crossbarId)
+      continue;
+
+    return true;
+  }
+
+  return false;
+}
+
+static std::pair, std::vector>
+prepareForScheduling(AcquireDeviceOp acquireDeviceOp,
+                     PatternRewriter &rewriter) {
+  auto releaseDeviceOp = getReleaseDeviceOp(acquireDeviceOp);
+  auto acquireCrossbarOps = getDependentAcquireCrossbarOps(acquireDeviceOp);
+
+  // save only one of the potentially multiple crossbar ids
+  auto savedCrossbarOp = acquireCrossbarOps.back();
+  acquireCrossbarOps.pop_back();
+
+  // delete the rest, rebind their uses to saved crossbar id
+  for (auto acquireCrossbarOp : acquireCrossbarOps) {
+    auto releaseCrossbarOp = getReleaseCrossbarOp(acquireCrossbarOp);
+    acquireCrossbarOp.getResult().replaceAllUsesWith(
+        savedCrossbarOp.getResult());
+    acquireCrossbarOp.erase();
+    releaseCrossbarOp.erase();
+  }
+
+  std::unordered_set discoveredBarriers;
+
+  auto cimOps = getDependentCimOps(savedCrossbarOp);
+  for (auto *cimOp : cimOps) {
+    for (auto operand : cimOp->getOperands()) {
+      // check if operand is a tensor created by a cim.barrier operation
+      auto *definingOp = operand.getDefiningOp();
+      if (!llvm::isa(operand.getType()) || !definingOp ||
+          !llvm::isa(definingOp))
+        continue;
+
+      discoveredBarriers.insert(definingOp);
+    }
+
+    for (auto user : cimOp->getUsers()) {
+      if (llvm::isa(user))
+        discoveredBarriers.insert(user);
+    }
+  }
+
+  std::vector crossbarIds;
+  std::vector roots;
+
+  // add saved crossbar id to crossbarIds
+  crossbarIds.push_back(savedCrossbarOp.getResult());
+
+  // recreate acquire_crossbar operations, until the number of available
+  // crossbars is reached
+  while (crossbarIds.size() != acquireDeviceOp.getAvailableCrossbarCount()) {
+    rewriter.setInsertionPointAfter(acquireDeviceOp.getOperation());
+    auto acquireCrossbarOp = rewriter.create(
+        acquireDeviceOp->getLoc(), acquireDeviceOp.getResult());
+    crossbarIds.push_back(acquireCrossbarOp.getResult());
+
+    rewriter.setInsertionPoint(releaseDeviceOp);
+    rewriter.create(releaseDeviceOp->getLoc(),
+                    acquireCrossbarOp.getResult());
+  }
+
+  // find all roots for scheduling
+  for (auto *barrier : discoveredBarriers) {
+    if (hasUsersOutsideAcquiredBlock(barrier, savedCrossbarOp.getResult()))
+      roots.push_back(barrier->getOperand(0));
+
+    // replace operand with its cim.future
+    barrier->getResult(0).replaceAllUsesWith(barrier->getOperand(0));
+    barrier->erase();
+  }
+
+  return {crossbarIds, roots};
+}
+
+template