Skip to content

Commit

Permalink
MLIR bazel build and test (#1618)
Browse files Browse the repository at this point in the history
* MLIR bazel build and test

* Fix MLIR memory bug

* fix

* print errs

* fix
  • Loading branch information
wsmoses authored Jan 24, 2024
1 parent 0d8c77c commit 304e21b
Show file tree
Hide file tree
Showing 125 changed files with 451 additions and 275 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/enzyme-bazel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,21 @@ jobs:
timeout-minutes: 500
steps:

- name: Prep
run: |
python -m pip install lit
- uses: actions/checkout@v4
- uses: actions/checkout@v4
with:
repository: 'llvm/llvm-project'
path: 'llvm-project'

- name: cmake
- name: Build
run: |
cd enzyme
bazel build :EnzymeStatic :enzymemlir-opt
- name: Test
run: |
cd enzyme
bazel build :EnzymeStatic
bazel test --test_output=errors ...
2 changes: 1 addition & 1 deletion .github/workflows/enzyme-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- uses: actions/checkout@v3
with:
repository: 'llvm/llvm-project'
ref: '5ed11e767c0c39a3bc8e035588e7a383849d46a8'
ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb'
path: 'llvm-project'

- name: Get MLIR commit hash
Expand Down
257 changes: 257 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
load("@llvm-project//llvm:tblgen.bzl", "gentbl")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path")
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")

licenses(["notice"])

Expand Down Expand Up @@ -191,6 +194,7 @@ cc_library(
"@llvm-project//llvm:TransformUtils",
"@llvm-project//llvm:config",
],
alwayslink = 1
)

cc_binary(
Expand All @@ -213,3 +217,256 @@ genrule(
cmd = "cp $< $@",
output_to_bindir = 1,
)

td_library(
name = "EnzymeDialectTdFiles",
srcs = [
"Enzyme/MLIR/Dialect/Dialect.td",
],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",
"@llvm-project//mlir:FunctionInterfacesTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:LoopLikeInterfaceTdFiles",
]
)

gentbl_cc_library(
name = "EnzymeOpsIncGen",
tbl_outs = [
(
["-gen-op-decls"],
"Enzyme/MLIR/Dialect/EnzymeOps.h.inc",
),
(
["-gen-op-defs"],
"Enzyme/MLIR/Dialect/EnzymeOps.cpp.inc",
),
(
[
"-gen-dialect-decls",
"-dialect=enzyme",
],
"Enzyme/MLIR/Dialect/EnzymeOpsDialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=enzyme",
],
"Enzyme/MLIR/Dialect/EnzymeOpsDialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td",
deps = [":EnzymeDialectTdFiles"],
)

td_library(
name = "EnzymePassesTdFiles",
srcs = [
],
deps = [
"@llvm-project//mlir:PassBaseTdFiles",
]
)

gentbl_cc_library(
name = "EnzymePassesIncGen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=enzyme",
],
"Enzyme/MLIR/Passes/Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Passes/Passes.td",
deps = [":EnzymePassesTdFiles"],
)

gentbl_cc_library(
name = "EnzymeTypesIncGen",
tbl_outs = [
(
["-gen-typedef-decls"],
"Enzyme/MLIR/Dialect/EnzymeOpsTypes.h.inc",
),
(
["-gen-typedef-defs"],
"Enzyme/MLIR/Dialect/EnzymeOpsTypes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td",
deps = [":EnzymeDialectTdFiles"],
)

gentbl_cc_library(
name = "EnzymeEnumsIncGen",
tbl_outs = [
(
["-gen-enum-decls"],
"Enzyme/MLIR/Dialect/EnzymeEnums.h.inc",
),
(
["-gen-enum-defs"],
"Enzyme/MLIR/Dialect/EnzymeEnums.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td",
deps = [":EnzymeDialectTdFiles"],
)

gentbl_cc_library(
name = "EnzymeAttributesIncGen",
tbl_outs = [
(
["-gen-attrdef-decls"],
"Enzyme/MLIR/Dialect/EnzymeAttributes.h.inc",
),
(
["-gen-attrdef-defs"],
"Enzyme/MLIR/Dialect/EnzymeAttributes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/EnzymeOps.td",
deps = [":EnzymeDialectTdFiles"],
)


gentbl_cc_library(
name = "EnzymeTypeInterfacesIncGen",
tbl_outs = [
(
["--gen-type-interface-decls"],
"Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h.inc",
),
(
["--gen-type-interface-defs"],
"Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td",
deps = [":EnzymeDialectTdFiles"],
)

gentbl_cc_library(
name = "EnzymeOpInterfacesIncGen",
tbl_outs = [
(
["--gen-op-interface-decls"],
"Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h.inc",
),
(
["--gen-op-interface-defs"],
"Enzyme/MLIR/Interfaces/AutoDiffOpInterface.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td",
deps = [":EnzymeDialectTdFiles"],
)

cc_library(
name = "EnzymeMLIR",
srcs = glob([
"Enzyme/MLIR/Dialect/*.cpp",
"Enzyme/MLIR/Passes/*.cpp",
"Enzyme/MLIR/Interfaces/*.cpp",
"Enzyme/MLIR/Analysis/*.cpp",
"Enzyme/MLIR/Implementations/*.cpp",
]),
hdrs = glob([
"Enzyme/MLIR/Dialect/*.h",
"Enzyme/MLIR/Passes/*.h",
"Enzyme/MLIR/Interfaces/*.h",
"Enzyme/MLIR/Analysis/*.h",
"Enzyme/MLIR/Implementations/*.h",
"Enzyme/Utils.h",
"Enzyme/TypeAnalysis/*.h"
]),
includes = ["Enzyme/MLIR", "Enzyme"],
visibility = ["//visibility:public"],
deps = [
":EnzymeOpsIncGen",
":EnzymePassesIncGen",
":EnzymeTypesIncGen",
":EnzymeEnumsIncGen",
":EnzymeAttributesIncGen",
":EnzymeTypeInterfacesIncGen",
":EnzymeOpInterfacesIncGen",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:AsyncDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
],
)

cc_binary(
name = "enzymemlir-opt",
srcs = ["Enzyme/MLIR/enzymemlir-opt.cpp"],
visibility = ["//visibility:public"],
includes = ["Enzyme/MLIR"],
deps = [
":EnzymeMLIR",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:AllPassesAndDialects",
],
)

# Generates lit config input file by applying path placeholder substitutions
# similar to the configure_lit_site_cfg CMake macro.
expand_template(
name = "lit_site_cfg_py",
testonly = True,
out = "test/lit.site.cfg.py",
substitutions = {
"@LLVM_VERSION_MAJOR@": "18",
"@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.",
"@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"),
"@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"),
"@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"),
"@ENZYME_SOURCE_DIR@": "",
"@ENZYME_BINARY_DIR@": "",
"@TARGET_TRIPLE@": "",
"@TARGETS_TO_BUILD@": "ALL",
"@LLVM_SHLIBEXT@": ".so",
},
template = "test/lit.site.cfg.py.in",
visibility = ["//visibility:private"],
)

[
lit_test(
name = "%s.test" % src,
srcs = [src],
data = [
":test/lit.cfg.py",
":test/lit.site.cfg.py",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:count",
"@llvm-project//llvm:not",
"@llvm-project//llvm:lli",
"@llvm-project//llvm:opt",
"@llvm-project//clang:builtin_headers_gen",
":enzyme-clang",
":enzyme-clang++",
":enzymemlir-opt"
] + glob(["test/**/*.h"])
)
for src in glob(["test/**/*.mlir"])
]
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct GenericOpInterfaceReverse
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto linalgOp = cast<linalg::LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
assert(linalgOp.hasPureBufferSemantics() &&
"Linalg op with tensor semantics not yet supported");

linalg::LinalgOp newOp =
Expand Down Expand Up @@ -278,4 +278,4 @@ void mlir::enzyme::registerLinalgDialectAutoDiffInterface(
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(context);
});
}
}
28 changes: 16 additions & 12 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ void MEnzymeLogic::handlePredecessors(
} else {
SmallVector<Block *> blocks;
SmallVector<APInt> indices;
SmallVector<ValueRange> arguments;
SmallVector<SmallVector<Value>> arguments;
SmallVector<Value> defaultArguments;
Block *defaultBlock;
int i = 1;
for (Block *predecessor : oBB->getPredecessors()) {
Block *defaultBlock = nullptr;
for (auto pair : llvm::enumerate(oBB->getPredecessors())) {
auto predecessor = pair.value();
auto idx = pair.index();
Block *predecessorRevMode =
gutils->mapReverseModeBlocks.lookupOrNull(predecessor);

Expand All @@ -250,10 +251,10 @@ void MEnzymeLogic::handlePredecessors(
}
}
}
if (predecessor != *(oBB->getPredecessors().begin())) {
if (idx != 0) {
blocks.push_back(predecessorRevMode);
indices.push_back(APInt(32, i++));
arguments.push_back(operands);
indices.push_back(APInt(32, idx - 1));
arguments.emplace_back(std::move(operands));
} else {
defaultBlock = predecessorRevMode;
defaultArguments = operands;
Expand All @@ -275,15 +276,19 @@ void MEnzymeLogic::handlePredecessors(
oBB->getPredecessors().end()) {
// If there is only one block we can directly create a branch for
// simplicity sake
revBuilder.create<cf::BranchOp>(loc, defaultBlock, defaultArguments);
auto bop =
revBuilder.create<cf::BranchOp>(loc, defaultBlock, defaultArguments);
} else {
Value cache = gutils->insertInit(gutils->getIndexCacheType());
Value flag =
revBuilder.create<enzyme::PopOp>(loc, gutils->getIndexType(), cache);

revBuilder.create<cf::SwitchOp>(
SmallVector<ValueRange> argumentRanges;
for (const auto &a : arguments)
argumentRanges.emplace_back(a);
auto bop = revBuilder.create<cf::SwitchOp>(
loc, flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices),
ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
ArrayRef<Block *>(blocks), argumentRanges);

Value origin = newBB->addArgument(gutils->getIndexType(), loc);

Expand Down Expand Up @@ -356,7 +361,6 @@ void MEnzymeLogic::differentiate(
Block *oBB = *it;
Block *newBB = gutils->getNewFromOriginal(oBB);
Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB);

mapInvertArguments(oBB, reverseBB, gutils);
handleReturns(oBB, newBB, reverseBB, gutils, parentRegion);
visitChildren(oBB, reverseBB, gutils);
Expand Down Expand Up @@ -401,4 +405,4 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

delete gutils;
return nf;
}
}
Loading

0 comments on commit 304e21b

Please sign in to comment.