Skip to content

Commit

Permalink
Add an mlir-cuda-runner tool.
Browse files Browse the repository at this point in the history
This tool allows executing MLIR IR snippets written in the GPU dialect
on a CUDA-capable GPU. For this to work, a working CUDA install is required
and the build has to be configured with MLIR_CUDA_RUNNER_ENABLED set to 1.

PiperOrigin-RevId: 256551415
  • Loading branch information
Stephan Herhut authored and jpienaar committed Jul 4, 2019
1 parent bf1d5b2 commit 315de6a
Show file tree
Hide file tree
Showing 15 changed files with 428 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ else()
set(MLIR_CUDA_CONVERSIONS_ENABLED 0)
endif()

set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")

include_directories( "include")
include_directories( ${MLIR_INCLUDE_DIR})

Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule,
return {};
}
targetMachine.reset(
target->createTargetMachine(triple.str(), "sm_75", "+ptx60", {}, {}));
target->createTargetMachine(triple.str(), "sm_35", "+ptx60", {}, {}));
}

// Set the data layout of the llvm module to match what the ptx target needs.
Expand Down
23 changes: 14 additions & 9 deletions lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,32 @@ namespace {
constexpr const char *kCubinAnnotation = "nvvm.cubin";
constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter";
constexpr const char *kCubinGetterSuffix = "_cubin";
constexpr const char *kMallocHelperName = "mcuMalloc";
constexpr const char *kMallocHelperName = "malloc";

/// A pass generating getter functions for all cubin blobs annotated on
/// functions via the nvvm.cubin attribute.
///
/// The functions allocate memory using a mcuMalloc helper function with the
/// signature void *mcuMalloc(int32_t size). This function has to be provided
/// by the actual runner that executes the generated code.
/// The functions allocate memory using the system malloc call with signature
/// void *malloc(size_t size). This function has to be provided by the actual
/// runner that executes the generated code.
///
/// This is a stop-gap measure until MLIR supports global constants.
class GpuGenerateCubinAccessorsPass
: public ModulePass<GpuGenerateCubinAccessorsPass> {
private:
LLVM::LLVMType getIndexType() {
unsigned bits =
llvmDialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
return LLVM::LLVMType::getIntNTy(llvmDialect, bits);
}

Function getMallocHelper(Location loc, Builder &builder) {
Function result = getModule().getNamedFunction(kMallocHelperName);
if (!result) {
result = Function::create(
loc, kMallocHelperName,
builder.getFunctionType(
ArrayRef<Type>{LLVM::LLVMType::getInt32Ty(llvmDialect)},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
builder.getFunctionType(ArrayRef<Type>{getIndexType()},
LLVM::LLVMType::getInt8PtrTy(llvmDialect)));
getModule().push_back(result);
}
return result;
Expand All @@ -84,8 +89,8 @@ class GpuGenerateCubinAccessorsPass
OpBuilder ob(result.getBody());
ob.createBlock();
auto sizeConstant = ob.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt32Ty(llvmDialect),
builder.getI32IntegerAttr(blob.getValue().size()));
loc, getIndexType(),
builder.getIntegerAttr(builder.getIndexType(), blob.getValue().size()));
auto memory =
ob.create<LLVM::CallOp>(
loc, ArrayRef<Type>{LLVM::LLVMType::getInt8PtrTy(llvmDialect)},
Expand Down
10 changes: 10 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ llvm_canonicalize_cmake_booleans(
# for linalg integration tests.
set(MLIR_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})

# Passed to lit.site.cfg.py.in to set up the path where to find the libraries
# for the mlir cuda runner tests.
set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})

configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
Expand Down Expand Up @@ -48,6 +52,12 @@ if(LLVM_BUILD_EXAMPLES)
)
endif()

if(MLIR_CUDA_RUNNER_ENABLED)
list(APPEND MLIR_TEST_DEPENDS
mlir-cuda-runner
)
endif()

add_lit_testsuite(check-mlir "Running the MLIR regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${MLIR_TEST_DEPENDS}
Expand Down
6 changes: 3 additions & 3 deletions test/Conversion/GPUToCUDA/insert-cubin-getter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ func @kernel(!llvm.float, !llvm<"float*">)
// CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN", nvvm.cubingetter = @kernel_cubin}
attributes {gpu.kernel, nvvm.cubin = "CUBIN"}

// CHECK: func @mcuMalloc(!llvm.i32) -> !llvm<"i8*">
// CHECK: func @malloc(!llvm.i64) -> !llvm<"i8*">
// CHECK: func @kernel_cubin() -> !llvm<"i8*">
// CHECK-NEXT: %0 = llvm.constant(5 : i32) : !llvm.i32
// CHECK-NEXT: %1 = llvm.call @mcuMalloc(%0) : (!llvm.i32) -> !llvm<"i8*">
// CHECK-NEXT: %0 = llvm.constant(5 : index) : !llvm.i64
// CHECK-NEXT: %1 = llvm.call @malloc(%0) : (!llvm.i64) -> !llvm<"i8*">
// CHECK-NEXT: %2 = llvm.constant(0 : i32) : !llvm.i32
// CHECK-NEXT: %3 = llvm.getelementptr %1[%2] : (!llvm<"i8*">, !llvm.i32) -> !llvm<"i8*">
// CHECK-NEXT: %4 = llvm.constant(67 : i8) : !llvm.i8
Expand Down
1 change: 1 addition & 0 deletions test/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
ToolSubst('toy-ch4', unresolved='ignore'),
ToolSubst('toy-ch5', unresolved='ignore'),
ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore')
])

llvm_config.add_tool_substitutions(tools, tool_dirs)
2 changes: 2 additions & 0 deletions test/lit.site.cfg.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.linalg_test_lib_dir = "@MLIR_LINALG_INTEGRATION_TEST_LIB_DIR@"
config.build_examples = @LLVM_BUILD_EXAMPLES@
config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@"
config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@

# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
Expand Down
30 changes: 30 additions & 0 deletions test/mlir-cuda-runner/gpu-to-cubin.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext | FileCheck %s

// Launches a GPU kernel over a 1x1x1 grid of blocks whose x block size is the
// dynamic size of %arg1 (y and z block sizes are 1). Inside the kernel, %arg0
// is stored into %arg1 at the x thread index.
func @other_func(%arg0 : f32, %arg1 : memref<?xf32>) {
  // Constant 1, used for every grid dimension and the y/z block dimensions.
  %cst = constant 1 : index
  // Dynamic extent of the memref; used as the x block dimension.
  %cst2 = dim %arg1, 0 : memref<?xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
             threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst)
             args(%kernel_arg0 = %arg0, %kernel_arg1 = %arg1) : f32, memref<?xf32> {
    store %kernel_arg0, %kernel_arg1[%tx] : memref<?xf32>
    gpu.return
  }
  return
}

// CHECK: [1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00]
// Allocates a 5-element f32 buffer, registers it with CUDA through
// @mcuMemHostRegister, and prints it before and after @other_func writes 1.0
// into it on the GPU.
// NOTE(review): the buffer is not initialized before the first print, so only
// the second printed line is expected to match the FileCheck pattern above.
func @main() {
  %arg0 = alloc() : memref<5xf32>
  // Flags value passed through to @mcuMemHostRegister.
  %20 = constant 0 : i32
  %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
  call @mcuMemHostRegister(%22, %20) : (memref<?xf32>, i32) -> ()
  call @mcuPrintFloat(%22) : (memref<?xf32>) -> ()
  %24 = constant 1.0 : f32
  call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
  call @mcuPrintFloat(%22) : (memref<?xf32>) -> ()
  return
}

// External helper declarations. The RUN line passes the
// libcuda-runtime-wrappers shared library via --shared-libs; that library
// defines functions with these names.
func @mcuMemHostRegister(%ptr : memref<?xf32>, %flags : i32)
func @mcuPrintFloat(%ptr : memref<?xf32>)
2 changes: 2 additions & 0 deletions test/mlir-cuda-runner/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Mark every test in this directory as unsupported unless the site config
# reports that the CUDA runner is enabled (config.enable_cuda_runner).
if not config.enable_cuda_runner:
    config.unsupported = True
1 change: 1 addition & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(mlir-cuda-runner)
add_subdirectory(mlir-cpu-runner)
add_subdirectory(mlir-opt)
add_subdirectory(mlir-tblgen)
Expand Down
12 changes: 11 additions & 1 deletion tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "mlir/Support/FileUtilities.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
Expand Down Expand Up @@ -249,7 +250,12 @@ static Error compileAndExecuteSingleFloatReturnFunction(
return Error::success();
}

int run(int argc, char **argv) {
// Entry point for all CPU runners. Expects the common argc/argv arguments for
// standard C++ main functions and an mlirTransformer.
// The latter is applied after parsing the input into MLIR IR and before passing
// the MLIR module to the ExecutionEngine.
int run(int argc, char **argv,
llvm::function_ref<LogicalResult(mlir::Module)> mlirTransformer) {
llvm::PrettyStackTraceProgram x(argc, argv);
llvm::InitLLVM y(argc, argv);

Expand Down Expand Up @@ -292,6 +298,10 @@ int run(int argc, char **argv) {
return 1;
}

if (mlirTransformer)
if (failed(mlirTransformer(m.get())))
return EXIT_FAILURE;

auto transformer =
mlir::makeLLVMPassesTransformer(passes, optLevel, optPosition);
auto error = mainFuncType.getValue() == "f32"
Expand Down
14 changes: 12 additions & 2 deletions tools/mlir-cpu-runner/mlir-cpu-runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
// latter.
//
//===----------------------------------------------------------------------===//
extern int run(int argc, char **argv);

int main(int argc, char **argv) { return run(argc, argv); }
#include "llvm/ADT/STLExtras.h"

namespace mlir {
class ModuleOp;
struct LogicalResult;
} // namespace mlir

// TODO(herhut) Factor out into an include file and proper library.
extern int run(int argc, char **argv,
llvm::function_ref<mlir::LogicalResult(mlir::ModuleOp)>);

int main(int argc, char **argv) { return run(argc, argv, nullptr); }
74 changes: 74 additions & 0 deletions tools/mlir-cuda-runner/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Sources in this directory are only built when MLIR_CUDA_RUNNER_ENABLED is
# set, so they are declared optional.
set(LLVM_OPTIONAL_SOURCES
  cuda-runtime-wrappers.cpp
  mlir-cuda-runner.cpp
  )

if(MLIR_CUDA_RUNNER_ENABLED)
  if (NOT ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD))
    message(SEND_ERROR
      "Building the mlir cuda runner requires the NVPTX backend")
  endif()

  # Configure CUDA runner support. Using check_language first allows us to give
  # a custom error message.
  include(CheckLanguage)
  check_language(CUDA)
  if (CMAKE_CUDA_COMPILER)
    enable_language(CUDA)
  else()
    message(SEND_ERROR
      "Building the mlir cuda runner requires a working CUDA install")
  endif()

  # We need the libcuda.so library. Report a missing library with a custom
  # message, mirroring the compiler check above, instead of relying on CMake's
  # generic NOTFOUND diagnostic.
  find_library(CUDA_RUNTIME_LIBRARY cuda)
  if (NOT CUDA_RUNTIME_LIBRARY)
    message(SEND_ERROR
      "Building the mlir cuda runner requires the CUDA driver library (libcuda)")
  endif()

  add_llvm_library(cuda-runtime-wrappers SHARED
    cuda-runtime-wrappers.cpp
    )
  # Only directories belong here. LLVMSupport is a library target and is
  # consumed through target_link_libraries below.
  target_include_directories(cuda-runtime-wrappers
    PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
    )
  target_link_libraries(cuda-runtime-wrappers
    LLVMSupport
    ${CUDA_RUNTIME_LIBRARY}
    )

  # Libraries handed to whole_archive_link below in addition to being linked
  # normally.
  set(FULL_LINK_LIBS
    MLIRAffineOps
    MLIRGPU
    MLIRGPUtoCUDATransforms
    MLIRGPUtoNVVMTransforms
    MLIRLLVMIR
    MLIRStandardOps
    MLIRStandardToLLVM
    MLIRTargetLLVMIR
    MLIRTransforms
    MLIRTranslation
    )
  set(LIBS
    MLIRIR
    MLIRParser
    MLIREDSC
    MLIRAnalysis
    MLIRCPURunnerLib
    MLIRExecutionEngine
    MLIRSupport
    LLVMCore
    LLVMSupport
    ${CUDA_RUNTIME_LIBRARY}
    )
  add_llvm_executable(mlir-cuda-runner
    mlir-cuda-runner.cpp
    )
  # Build the wrapper library together with the runner; the runner tests pass
  # it via --shared-libs.
  add_dependencies(mlir-cuda-runner cuda-runtime-wrappers)
  target_include_directories(mlir-cuda-runner
    PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
    )
  llvm_update_compile_flags(mlir-cuda-runner)
  whole_archive_link(mlir-cuda-runner ${FULL_LINK_LIBS})
  target_link_libraries(mlir-cuda-runner PRIVATE ${FULL_LINK_LIBS} ${LIBS})

endif()
107 changes: 107 additions & 0 deletions tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
//===- cuda-runtime-wrappers.cpp - MLIR CUDA runner wrapper library -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Implements C wrappers around the CUDA library for easy linking in ORC jit.
// Also adds some debugging helpers that are helpful when writing MLIR code to
// run on GPUs.
//
//===----------------------------------------------------------------------===//

#include <assert.h>
#include <memory.h>

#include "llvm/Support/raw_ostream.h"

#include "cuda.h"

namespace {
/// Writes a diagnostic naming `where` to llvm::errs() when `result` is not
/// CUDA_SUCCESS. In every case the result code is handed back to the caller
/// as an int32_t.
int32_t reportErrorIfAny(CUresult result, const char *where) {
  // Fast path: nothing to report.
  if (result == CUDA_SUCCESS)
    return result;
  llvm::errs() << "CUDA failed with " << result << " in " << where << "\n";
  return result;
}
} // anonymous namespace

extern "C" int32_t mcuModuleLoad(void **module, void *data) {
int32_t err = reportErrorIfAny(
cuModuleLoadData(reinterpret_cast<CUmodule *>(module), data),
"ModuleLoad");
return err;
}

extern "C" int32_t mcuModuleGetFunction(void **function, void *module,
const char *name) {
return reportErrorIfAny(
cuModuleGetFunction(reinterpret_cast<CUfunction *>(function),
reinterpret_cast<CUmodule>(module), name),
"GetFunction");
}

// The wrapper uses intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
extern "C" int32_t mcuLaunchKernel(void *function, intptr_t gridX,
intptr_t gridY, intptr_t gridZ,
intptr_t blockX, intptr_t blockY,
intptr_t blockZ, int32_t smem, void *stream,
void **params, void **extra) {
return reportErrorIfAny(
cuLaunchKernel(reinterpret_cast<CUfunction>(function), gridX, gridY,
gridZ, blockX, blockY, blockZ, smem,
reinterpret_cast<CUstream>(stream), params, extra),
"LaunchKernel");
}

/// Creates a CUDA stream with CU_STREAM_DEFAULT flags and returns it as an
/// opaque pointer. A creation failure is reported via reportErrorIfAny.
extern "C" void *mcuGetStreamHelper() {
  // Start from a null handle so that a failed cuStreamCreate cannot make this
  // function return an indeterminate (uninitialized) pointer.
  CUstream stream = nullptr;
  reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate");
  return stream;
}

extern "C" int32_t mcuStreamSynchronize(void *stream) {
return reportErrorIfAny(
cuStreamSynchronize(reinterpret_cast<CUstream>(stream)), "StreamSync");
}

/// Helper functions for writing mlir example code

// A struct that corresponds to how MLIR represents unknown-length 1d memrefs.
// NOTE(review): field order and types presumably have to match the descriptor
// layout produced by the MLIR lowering — confirm before changing either.
struct memref_t {
  float *values;   // Pointer to the first element.
  intptr_t length; // mcuPrintFloat treats this as the number of floats.
};

// Registers the host memory backing the given memref with CUDA. Helpful until
// we have transfer functions implemented.
// cuMemHostRegister takes the size of the range in bytes, whereas the memref
// length counts float elements, so the length is scaled by sizeof(float).
// Errors are reported via reportErrorIfAny; nothing is returned.
extern "C" void mcuMemHostRegister(const memref_t arg, int32_t flags) {
  reportErrorIfAny(
      cuMemHostRegister(arg.values, arg.length * sizeof(float), flags),
      "MemHostRegister");
}

/// Prints the given float array to stdout (llvm::outs()) as a bracketed,
/// comma-separated list followed by a newline. A memref whose length is zero
/// (or, defensively, negative) prints "[]".
extern "C" void mcuPrintFloat(const memref_t arg) {
  if (arg.length <= 0) {
    llvm::outs() << "[]\n";
    return;
  }
  llvm::outs() << "[" << arg.values[0];
  // The index uses intptr_t to match arg.length. An int index would be
  // compared against a wider type and could overflow on large memrefs.
  for (intptr_t pos = 1; pos < arg.length; ++pos)
    llvm::outs() << ", " << arg.values[pos];
  llvm::outs() << "]\n";
}
Loading

0 comments on commit 315de6a

Please sign in to comment.