diff --git a/CMakeLists.txt b/CMakeLists.txt index 02aae658c664..f03d283785f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,8 @@ else() set(MLIR_CUDA_CONVERSIONS_ENABLED 0) endif() +set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner") + include_directories( "include") include_directories( ${MLIR_INCLUDE_DIR}) diff --git a/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 246bd549f43f..da896afb0908 100644 --- a/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -124,7 +124,7 @@ OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule, return {}; } targetMachine.reset( - target->createTargetMachine(triple.str(), "sm_75", "+ptx60", {}, {})); + target->createTargetMachine(triple.str(), "sm_35", "+ptx60", {}, {})); } // Set the data layout of the llvm module to match what the ptx target needs. diff --git a/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index 550491b3cd0e..6de304333899 100644 --- a/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -40,27 +40,32 @@ namespace { constexpr const char *kCubinAnnotation = "nvvm.cubin"; constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter"; constexpr const char *kCubinGetterSuffix = "_cubin"; -constexpr const char *kMallocHelperName = "mcuMalloc"; +constexpr const char *kMallocHelperName = "malloc"; /// A pass generating getter functions for all cubin blobs annotated on /// functions via the nvvm.cubin attribute. /// -/// The functions allocate memory using a mcuMalloc helper function with the -/// signature void *mcuMalloc(int32_t size). This function has to be provided -/// by the actual runner that executes the generated code. +/// The functions allocate memory using the system malloc call with signature +/// void *malloc(size_t size). This function has to be provided by the actual +/// runner that executes the generated code. /// /// This is a stop-gap measure until MLIR supports global constants. class GpuGenerateCubinAccessorsPass : public ModulePass { private: + LLVM::LLVMType getIndexType() { + unsigned bits = + llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits(); + return LLVM::LLVMType::getIntNTy(llvmDialect, bits); + } + Function getMallocHelper(Location loc, Builder &builder) { Function result = getModule().getNamedFunction(kMallocHelperName); if (!result) { result = Function::create( loc, kMallocHelperName, - builder.getFunctionType( - ArrayRef{LLVM::LLVMType::getInt32Ty(llvmDialect)}, - LLVM::LLVMType::getInt8PtrTy(llvmDialect))); + builder.getFunctionType(ArrayRef{getIndexType()}, + LLVM::LLVMType::getInt8PtrTy(llvmDialect))); getModule().push_back(result); } return result; @@ -84,8 +89,8 @@ class GpuGenerateCubinAccessorsPass OpBuilder ob(result.getBody()); ob.createBlock(); auto sizeConstant = ob.create( - loc, LLVM::LLVMType::getInt32Ty(llvmDialect), - builder.getI32IntegerAttr(blob.getValue().size())); + loc, getIndexType(), + builder.getIntegerAttr(builder.getIndexType(), blob.getValue().size())); auto memory = ob.create( loc, ArrayRef{LLVM::LLVMType::getInt8PtrTy(llvmDialect)}, diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a57505d5c5c1..2e102395e83c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,6 +11,10 @@ llvm_canonicalize_cmake_booleans( # for linalg integration tests. set(MLIR_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +# Passed to lit.site.cfg.py.in to set up the path where to find the libraries +# for the mlir cuda runner tests. +set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) + configure_lit_site_cfg( ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py @@ -48,6 +52,12 @@ if(LLVM_BUILD_EXAMPLES) ) endif() +if(MLIR_CUDA_RUNNER_ENABLED) + list(APPEND MLIR_TEST_DEPENDS + mlir-cuda-runner + ) +endif() + add_lit_testsuite(check-mlir "Running the MLIR regression tests" ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${MLIR_TEST_DEPENDS} diff --git a/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir b/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir index 19090ce87b59..c8814e2cb7cf 100644 --- a/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir +++ b/test/Conversion/GPUToCUDA/insert-cubin-getter.mlir @@ -4,10 +4,10 @@ func @kernel(!llvm.float, !llvm<"float*">) // CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN", nvvm.cubingetter = @kernel_cubin} attributes {gpu.kernel, nvvm.cubin = "CUBIN"} -// CHECK: func @mcuMalloc(!llvm.i32) -> !llvm<"i8*"> +// CHECK: func @malloc(!llvm.i64) -> !llvm<"i8*"> // CHECK: func @kernel_cubin() -> !llvm<"i8*"> -// CHECK-NEXT: %0 = llvm.constant(5 : i32) : !llvm.i32 -// CHECK-NEXT: %1 = llvm.call @mcuMalloc(%0) : (!llvm.i32) -> !llvm<"i8*"> +// CHECK-NEXT: %0 = llvm.constant(5 : index) : !llvm.i64 +// CHECK-NEXT: %1 = llvm.call @malloc(%0) : (!llvm.i64) -> !llvm<"i8*"> // CHECK-NEXT: %2 = llvm.constant(0 : i32) : !llvm.i32 // CHECK-NEXT: %3 = llvm.getelementptr %1[%2] : (!llvm<"i8*">, !llvm.i32) -> !llvm<"i8*"> // CHECK-NEXT: %4 = llvm.constant(67 : i8) : !llvm.i8 diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 3e6dfc3c538f..cf938946289d 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -67,6 +67,7 @@ ToolSubst('toy-ch4', unresolved='ignore'), ToolSubst('toy-ch5', unresolved='ignore'), ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'), + ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore') ]) llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 2388ef0a5abf..830b65fdd3b3 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -33,6 +33,8 @@ config.mlir_tools_dir = "@MLIR_TOOLS_DIR@" config.linalg_test_lib_dir = "@MLIR_LINALG_INTEGRATION_TEST_LIB_DIR@" config.build_examples = @LLVM_BUILD_EXAMPLES@ config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@ +config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@" +config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@ # Support substitution of the tools_dir with user parameters. This is # used when we can't determine the tool dir at configuration time. diff --git a/test/mlir-cuda-runner/gpu-to-cubin.mlir b/test/mlir-cuda-runner/gpu-to-cubin.mlir new file mode 100644 index 000000000000..6610337b17db --- /dev/null +++ b/test/mlir-cuda-runner/gpu-to-cubin.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext | FileCheck %s + +func @other_func(%arg0 : f32, %arg1 : memref) { + %cst = constant 1 : index + %cst2 = dim %arg1, 0 : memref + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst) + threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst) + args(%kernel_arg0 = %arg0, %kernel_arg1 = %arg1) : f32, memref { + store %kernel_arg0, %kernel_arg1[%tx] : memref + gpu.return + } + return +} + +// CHECK: [1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00] +func @main() { + %arg0 = alloc() : memref<5xf32> + %20 = constant 0 : i32 + %21 = constant 5 : i32 + %22 = memref_cast %arg0 : memref<5xf32> to memref + call @mcuMemHostRegister(%22, %20) : (memref, i32) -> () + call @mcuPrintFloat(%22) : (memref) -> () + %24 = constant 1.0 : f32 + call @other_func(%24, %22) : (f32, memref) -> () + call @mcuPrintFloat(%22) : (memref) -> () + return +} + +func @mcuMemHostRegister(%ptr : memref, %flags : i32) +func @mcuPrintFloat(%ptr : memref) diff --git a/test/mlir-cuda-runner/lit.local.cfg b/test/mlir-cuda-runner/lit.local.cfg new file mode 100644 index 000000000000..b063ddda7e1d --- /dev/null +++ b/test/mlir-cuda-runner/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_cuda_runner: + config.unsupported = True \ No newline at end of file diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index a2e4c6fe5c9f..2566dd872883 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(mlir-cuda-runner) add_subdirectory(mlir-cpu-runner) add_subdirectory(mlir-opt) add_subdirectory(mlir-tblgen) diff --git a/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp b/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp index 1b1fbccdf629..86e673b13627 100644 --- a/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp +++ b/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp @@ -34,6 +34,7 @@ #include "mlir/Support/FileUtilities.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassNameParser.h" @@ -249,7 +250,12 @@ static Error compileAndExecuteSingleFloatReturnFunction( return Error::success(); } -int run(int argc, char **argv) { +// Entry point for all CPU runners. Expects the common argc/argv arguments for +// standard C++ main functions and an mlirTransformer. +// The latter is applied after parsing the input into MLIR IR and before passing +// the MLIR module to the ExecutionEngine. +int run(int argc, char **argv, + llvm::function_ref mlirTransformer) { llvm::PrettyStackTraceProgram x(argc, argv); llvm::InitLLVM y(argc, argv); @@ -292,6 +298,10 @@ int run(int argc, char **argv) { return 1; } + if (mlirTransformer) + if (failed(mlirTransformer(m.get()))) + return EXIT_FAILURE; + auto transformer = mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition); auto error = mainFuncType.getValue() == "f32" diff --git a/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/tools/mlir-cpu-runner/mlir-cpu-runner.cpp index 43f4eab31d81..e7ac071534fc 100644 --- a/tools/mlir-cpu-runner/mlir-cpu-runner.cpp +++ b/tools/mlir-cpu-runner/mlir-cpu-runner.cpp @@ -20,6 +20,16 @@ // latter. // //===----------------------------------------------------------------------===// -extern int run(int argc, char **argv); -int main(int argc, char **argv) { return run(argc, argv); } +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +class ModuleOp; +struct LogicalResult; +} // namespace mlir + +// TODO(herhut) Factor out into an include file and proper library. +extern int run(int argc, char **argv, + llvm::function_ref); + +int main(int argc, char **argv) { return run(argc, argv, nullptr); } diff --git a/tools/mlir-cuda-runner/CMakeLists.txt b/tools/mlir-cuda-runner/CMakeLists.txt new file mode 100644 index 000000000000..826076bb6e54 --- /dev/null +++ b/tools/mlir-cuda-runner/CMakeLists.txt @@ -0,0 +1,74 @@ +set(LLVM_OPTIONAL_SOURCES + cuda-runtime-wrappers.cpp + mlir-cuda-runner.cpp + ) + +if(MLIR_CUDA_RUNNER_ENABLED) + if (NOT ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)) + message(SEND_ERROR + "Building the mlir cuda runner requires the NVPTX backend") + endif() + + # Configure CUDA runner support. Using check_language first allows us to give + # a custom error message. + include(CheckLanguage) + check_language(CUDA) + if (CMAKE_CUDA_COMPILER) + enable_language(CUDA) + else() + message(SEND_ERROR + "Building the mlir cuda runner requires a working CUDA install") + endif() + + # We need the libcuda.so library. + find_library(CUDA_RUNTIME_LIBRARY cuda) + + add_llvm_library(cuda-runtime-wrappers SHARED + cuda-runtime-wrappers.cpp + ) + target_include_directories(cuda-runtime-wrappers + PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + LLVMSupport + ) + target_link_libraries(cuda-runtime-wrappers + LLVMSupport + ${CUDA_RUNTIME_LIBRARY} + ) + + set(FULL_LINK_LIBS + MLIRAffineOps + MLIRGPU + MLIRGPUtoCUDATransforms + MLIRGPUtoNVVMTransforms + MLIRLLVMIR + MLIRStandardOps + MLIRStandardToLLVM + MLIRTargetLLVMIR + MLIRTransforms + MLIRTranslation + ) + set(LIBS + MLIRIR + MLIRParser + MLIREDSC + MLIRAnalysis + MLIRCPURunnerLib + MLIRExecutionEngine + MLIRSupport + LLVMCore + LLVMSupport + ${CUDA_RUNTIME_LIBRARY} + ) + add_llvm_executable(mlir-cuda-runner + mlir-cuda-runner.cpp + ) + add_dependencies(mlir-cuda-runner cuda-runtime-wrappers) + target_include_directories(mlir-cuda-runner + PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ) + llvm_update_compile_flags(mlir-cuda-runner) + whole_archive_link(mlir-cuda-runner ${FULL_LINK_LIBS}) + target_link_libraries(mlir-cuda-runner PRIVATE ${FULL_LINK_LIBS} ${LIBS}) + + +endif() diff --git a/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp new file mode 100644 index 000000000000..795f04e0020f --- /dev/null +++ b/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp @@ -0,0 +1,107 @@ +//===- cuda-runtime-wrappers.cpp - MLIR CUDA runner wrapper library -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Implements C wrappers around the CUDA library for easy linking in ORC jit. +// Also adds some debugging helpers that are helpful when writing MLIR code to +// run on GPUs. +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "llvm/Support/raw_ostream.h" + +#include "cuda.h" + +namespace { +int32_t reportErrorIfAny(CUresult result, const char *where) { + if (result != CUDA_SUCCESS) { + llvm::errs() << "CUDA failed with " << result << " in " << where << "\n"; + } + return result; +} +} // anonymous namespace + +extern "C" int32_t mcuModuleLoad(void **module, void *data) { + int32_t err = reportErrorIfAny( + cuModuleLoadData(reinterpret_cast(module), data), + "ModuleLoad"); + return err; +} + +extern "C" int32_t mcuModuleGetFunction(void **function, void *module, + const char *name) { + return reportErrorIfAny( + cuModuleGetFunction(reinterpret_cast(function), + reinterpret_cast(module), name), + "GetFunction"); +} + +// The wrapper uses intptr_t instead of CUDA's unsigned int to match +// the type of MLIR's index type. This avoids the need for casts in the +// generated MLIR code. +extern "C" int32_t mcuLaunchKernel(void *function, intptr_t gridX, + intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, void *stream, + void **params, void **extra) { + return reportErrorIfAny( + cuLaunchKernel(reinterpret_cast(function), gridX, gridY, + gridZ, blockX, blockY, blockZ, smem, + reinterpret_cast(stream), params, extra), + "LaunchKernel"); +} + +extern "C" void *mcuGetStreamHelper() { + CUstream stream; + reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate"); + return stream; +} + +extern "C" int32_t mcuStreamSynchronize(void *stream) { + return reportErrorIfAny( + cuStreamSynchronize(reinterpret_cast(stream)), "StreamSync"); +} + +/// Helper functions for writing mlir example code + +// A struct that corresponds to how MLIR represents unknown-length 1d memrefs. +struct memref_t { + float *values; + intptr_t length; +}; + +// Allows to register a pointer with the CUDA runtime. Helpful until +// we have transfer functions implemented. +extern "C" void mcuMemHostRegister(const memref_t arg, int32_t flags) { + reportErrorIfAny(cuMemHostRegister(arg.values, arg.length, flags), + "MemHostRegister"); +} + +/// Prints the given float array to stderr. +extern "C" void mcuPrintFloat(const memref_t arg) { + if (arg.length == 0) { + llvm::outs() << "[]\n"; + return; + } + llvm::outs() << "[" << arg.values[0]; + for (int pos = 1; pos < arg.length; pos++) { + llvm::outs() << ", " << arg.values[pos]; + } + llvm::outs() << "]\n"; +} diff --git a/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/tools/mlir-cuda-runner/mlir-cuda-runner.cpp new file mode 100644 index 000000000000..fd66bf9dfbd1 --- /dev/null +++ b/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -0,0 +1,158 @@ +//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is a command line utility that executes an MLIR file on the GPU by +// translating MLIR to NVVM/LVVM IR before JIT-compiling and executing the +// latter. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/STLExtras.h" + +#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/GPU/GPUDialect.h" +#include "mlir/GPU/Passes.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "cuda.h" + +using namespace mlir; + +// TODO(herhut) Factor out into an include file and proper library. +extern int run(int argc, char **argv, + llvm::function_ref); + +inline void emit_cuda_error(const llvm::Twine &message, const char *buffer, + CUresult error, Function &function) { + function.emitError(message.concat(" failed with error code ") + .concat(llvm::Twine{error}) + .concat("[") + .concat(buffer) + .concat("]")); +} + +#define RETURN_ON_CUDA_ERROR(expr, msg) \ + { \ + auto _cuda_error = (expr); \ + if (_cuda_error != CUDA_SUCCESS) { \ + emit_cuda_error(msg, jitErrorBuffer, _cuda_error, function); \ + return {}; \ + } \ + } + +OwnedCubin compilePtxToCubin(const std::string ptx, Function &function) { + char jitErrorBuffer[4096] = {0}; + + RETURN_ON_CUDA_ERROR(cuInit(0), "cuInit"); + + // Linking requires a device context. + CUdevice device; + RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0), "cuDeviceGet"); + CUcontext context; + RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device), "cuCtxCreate"); + CUlinkState linkState; + + CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES}; + void *jitOptionsVals[] = {jitErrorBuffer, + reinterpret_cast(sizeof(jitErrorBuffer))}; + + RETURN_ON_CUDA_ERROR(cuLinkCreate(2, /* number of jit options */ + jitOptions, /* jit options */ + jitOptionsVals, /* jit option values */ + &linkState), + "cuLinkCreate"); + + RETURN_ON_CUDA_ERROR( + cuLinkAddData(linkState, CUjitInputType::CU_JIT_INPUT_PTX, + const_cast(static_cast(ptx.c_str())), + ptx.length(), function.getName().data(), /* kernel name */ + 0, /* number of jit options */ + nullptr, /* jit options */ + nullptr /* jit option values */ + ), + "cuLinkAddData"); + + void *cubinData; + size_t cubinSize; + RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize), + "cuLinkComplete"); + + char *cubinAsChar = static_cast(cubinData); + OwnedCubin result = llvm::make_unique>( + cubinAsChar, cubinAsChar + cubinSize); + + // This will also destroy the cubin data. + RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState), "cuLinkDestroy"); + + return result; +} + +namespace { +struct GPULaunchFuncOpLowering : public LLVMOpLowering { +public: + explicit GPULaunchFuncOpLowering(LLVMTypeConverter &lowering_) + : LLVMOpLowering(gpu::LaunchFuncOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_) {} + + // Convert the kernel arguments to an LLVM type, preserve the rest. + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { + auto launchOp = dyn_cast(rewriter.clone(*op)); + + for (auto operand : llvm::enumerate(operands)) + launchOp.setOperand(operand.index(), operand.value()); + + return rewriter.replaceOp(op, llvm::None), this->matchSuccess(); + } +}; +} // end anonymous namespace + +static LogicalResult runMLIRPasses(Module m) { + // As we gradually lower, the IR is inconsistent between passes. So do not + // verify inbetween. + PassManager pm(/*verifyPasses=*/false); + + pm.addPass(createGpuKernelOutliningPass()); + pm.addPass(createConvertToLLVMIRPass([](LLVMTypeConverter &converter, + OwningRewritePatternList &patterns) { + populateStdToLLVMConversionPatterns(converter, patterns); + patterns.push_back(llvm::make_unique(converter)); + })); + pm.addPass(createLowerGpuOpsToNVVMOpsPass()); + pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin)); + pm.addPass(createGenerateCubinAccessorPass()); + pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass()); + + if (failed(pm.run(m))) + return failure(); + + if (failed(m.verify())) + return failure(); + + return success(); +} + +int main(int argc, char **argv) { return run(argc, argv, &runMLIRPasses); }