Skip to content

Commit

Permalink
pipeline fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 19, 2024
1 parent f94c4b8 commit 78711bd
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 10 deletions.
32 changes: 32 additions & 0 deletions src/enzyme_ad/jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,20 @@
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "lhlo/IR/lhlo_ops.h"
#include "mhlo/transforms/passes.h"
#include "mlir/InitAllPasses.h"
#include "xla/mlir/backends/cpu/transforms/passes.h"
#include "xla/mlir/math/transforms/passes.h"
#include "xla/mlir/memref/transforms/passes.h"
#include "xla/mlir/runtime/transforms/passes.h"
#include "xla/mlir_hlo/lhlo/transforms/passes.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"

#include "deallocation/transforms/passes.h"
#include "lhlo/transforms/passes.h"
#include "lhlo_gpu/IR/lhlo_gpu_ops.h"
// #include "transforms/passes.h"

#include "compile_with_xla.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
Expand Down Expand Up @@ -992,6 +1005,25 @@ PYBIND11_MODULE(enzyme_call, m) {

mlir::registerAllPasses();

mlir::mhlo::registerAllMhloPasses();
xla::cpu::registerCpuTransformsPasses();
mlir::hlo::registerLMHLOTransformsPasses();
xla::runtime::registerRuntimeTransformsPasses();
xla::registerMathTransformsPasses();
xla::registerMemrefTransformsPasses();

mlir::registerShapePasses();
mlir::registerConvertShapeToStandardPass();
mlir::registerConvertShapeConstraintsPass();
mlir::memref::registerResolveShapedTypeResultDims();
mlir::registerLinalgPasses();
mlir::registerReconcileUnrealizedCastsPass();
mlir::registerConversionPasses();
mlir::bufferization::registerBufferizationPasses();
mlir::registerAsyncPasses();
mlir::arith::registerArithPasses();
mlir::memref::registerMemRefPasses();

pybind11::enum_<Language>(m, "Language")
.value("CPP", Language::CPP)
.value("LLVM", Language::LLVM)
Expand Down
186 changes: 176 additions & 10 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,158 @@
LANG_LLVM = enzyme_call.Language.LLVM
LANG_MHLO = enzyme_call.Language.MHLO


def xla_runtime(options):
return True


def pass_pipeline(options):
return "any(inline{default-pipeline=canonicalize max-iterations=4 },expand-hlo-tuples{entry-function=main},func.func(mhlo-flatten-tuple),xla-legalize-abi,func.func(mhlo-test-lower-general-dot),func.func(mhlo-broadcast-propagation),cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(xla-sparse-custom-call-to-pack),func.func(legalize-sparse-ops{legalize-to-custom-calls=false}),func.func(chlo-legalize-to-hlo{expand-compositions=true legalize-broadcasts=true}),func.func(mhlo-sparse-rewriting),func.func(mhlo-legalize-control-flow),func.func(mhlo-legalize-dot-general-to-dot),hlo-legalize-to-arithmetic,func.func(xla-legalize-library-ops),func.func(mhlo-expand-ops-simplifier),func.func(hlo-canonicalize-scatter),func.func(hlo-canonicalize-dot),func.func(group-reduction-dimensions{prefer-columns-reductions=true}),func.func(hlo-legalize-to-linalg{enable-primitive-ops=false}),func.func(lower-index-cast),convert-to-signless,func.func(shape-simplification),func.func(shape-to-shape-lowering),convert-shape-to-std,func.func(convert-shape-constraints),cse,resolve-shaped-type-result-dims,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(linalg-fuse-elementwise-ops),reconcile-unrealized-casts,convert-tensor-to-linalg,func.func(detensorize-scf-ops),func.func(linalg-detensorize{aggressive-mode=true}),eliminate-empty-tensors,func.func(empty-tensor-to-alloc-tensor),canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(linalg-generalize-named-ops),eliminate-empty-tensors,sparsification-and-bufferization,sparse-storage-specifier-to-llvm,func.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}),func.func(finalizing-bufferize),func.func(xla-rewrite-realloc-to-alloc),func.func(vectorize-copy),func.func(naive-copy-removal),func.func(convert-linalg-to-loops),cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},buffer-results-to-out-params,func.func(promote-buffers-to-stack{max-alloc-size-in-bytes=1024 max-rank-of-allocated-memref=1}),func.func(buffer-deallocation),convert-bufferization-to-memref,func.func(xla-remove-copies-to-out-params),canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(convert-complex-to-standard),cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},func.func(convert-vector-to-scf{full-unroll=false lower-tensors=false target-rank=1}),func.func(xla-legalize-i1-vector-transfers),func.func(xla-convert-memref-element-cast-to-llvm),async-func-to-async-runtime,xla-rt-export-functions,xla-cpu-to-cpu-runtime,xla-rt-convert-custom-calls,xla-rt-convert-asserts,inline{default-pipeline=canonicalize max-iterations=4 },canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},cse,func.func(xla-math-approximation{oplist=all}),func.func(convert-linalg-to-parallel-loops),canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},async-to-async-runtime,xla-rt-move-allocas-to-entry-block,async-runtime-policy-based-ref-counting,func.func(arith-expand{include-bf16=false}),func.func(memref-expand),func.func(expand-strided-metadata),lower-affine,func.func(xla-memref-aligned-allocations{alignment=0}),xla-rt-to-llvm,convert-async-to-llvm,generic-host-to-llvm{enable-avx2=false},reconcile-unrealized-casts,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},cse)"
return """
inline{default-pipeline=canonicalize max-iterations=4},
expand-hlo-tuples{entry-function=main},
func.func(mhlo-flatten-tuple),
xla-legalize-abi,
func.func(mhlo-test-lower-general-dot),
func.func(mhlo-broadcast-propagation),
cse,
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
func.func(xla-sparse-custom-call-to-pack),
func.func(legalize-sparse-ops{legalize-to-custom-calls=false}),
func.func(chlo-legalize-to-hlo{
expand-compositions=true legalize-broadcasts=true}),
func.func(mhlo-sparse-rewriting),
func.func(mhlo-legalize-control-flow),
func.func(mhlo-legalize-dot-general-to-dot),
hlo-legalize-to-arithmetic,
func.func(xla-legalize-library-ops),
func.func(mhlo-expand-ops-simplifier),
func.func(hlo-canonicalize-scatter),
func.func(hlo-canonicalize-dot),
func.func(group-reduction-dimensions{prefer-columns-reductions=true}),
func.func(hlo-legalize-to-linalg{enable-primitive-ops=false}),
func.func(lower-index-cast),
convert-to-signless,
func.func(shape-simplification),
func.func(shape-to-shape-lowering),
convert-shape-to-std,
func.func(convert-shape-constraints),
cse,
resolve-shaped-type-result-dims,
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
func.func(linalg-fuse-elementwise-ops),
reconcile-unrealized-casts,
convert-tensor-to-linalg,
func.func(detensorize-scf-ops),
func.func(linalg-detensorize{aggressive-mode=true}),
eliminate-empty-tensors,
func.func(empty-tensor-to-alloc-tensor),
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
func.func(linalg-generalize-named-ops),
eliminate-empty-tensors,
sparsification-and-bufferization,
sparse-storage-specifier-to-llvm,
func.func(canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true}),
func.func(finalizing-bufferize),
func.func(xla-rewrite-realloc-to-alloc),
func.func(vectorize-copy),
func.func(naive-copy-removal),
func.func(convert-linalg-to-loops),
cse,
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
buffer-results-to-out-params,
func.func(promote-buffers-to-stack{
max-alloc-size-in-bytes=1024
max-rank-of-allocated-memref=1}),
func.func(buffer-deallocation),
convert-bufferization-to-memref,
func.func(xla-remove-copies-to-out-params),
cse,
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
func.func(convert-complex-to-standard),
cse,
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
func.func(convert-vector-to-scf{
full-unroll=false
lower-tensors=false
target-rank=1}),
func.func(xla-legalize-i1-vector-transfers),
func.func(xla-convert-memref-element-cast-to-llvm),
async-func-to-async-runtime,
xla-rt-export-functions,
xla-cpu-to-cpu-runtime,
xla-rt-convert-custom-calls,
xla-rt-convert-asserts,
inline{default-pipeline=canonicalize max-iterations=4},
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
cse,
func.func(xla-math-approximation{oplist=all}),
func.func(convert-linalg-to-parallel-loops),
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
async-to-async-runtime,
xla-rt-move-allocas-to-entry-block,
async-runtime-policy-based-ref-counting,
func.func(arith-expand{include-bf16=false}),
func.func(memref-expand),
func.func(expand-strided-metadata),
lower-affine,
func.func(xla-memref-aligned-allocations{alignment=0}),
xla-rt-to-llvm,
convert-async-to-llvm,
generic-host-to-llvm{enable-avx2=false},
reconcile-unrealized-casts,
canonicalize{
max-iterations=10
max-num-rewrites=-1
region-simplify=true
test-convergence=false
top-down=true},
cse"""


def resource_dir():
import os
Expand Down Expand Up @@ -192,7 +339,14 @@ def _enzyme_aug_abstract_eval(
argv = argv + ("-resource-dir", resource_dir()) + cflags()

tapeSize, tmpSize = enzyme_call.tape_and_tmp_size(
source, fn, out_shapes, in_shapes, argv, lang, xla_runtime(pipeline_options), pass_pipeline(pipeline_options)
source,
fn,
out_shapes,
in_shapes,
argv,
lang,
xla_runtime(pipeline_options),
pass_pipeline(pipeline_options),
)
res = tuple(prev_out_shapes) + (
jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)),
Expand Down Expand Up @@ -277,7 +431,7 @@ def _enzyme_primal_lowering(
enzyme_call.ABI.Primal,
lang,
xla_runtime(pipeline_options),
pass_pipeline(pipeline_options)
pass_pipeline(pipeline_options),
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -340,7 +494,7 @@ def _enzyme_fwd_lowering(
enzyme_call.ABI.Forward,
lang,
xla_runtime(pipeline_options),
pass_pipeline(pipeline_options)
pass_pipeline(pipeline_options),
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -402,7 +556,7 @@ def _enzyme_aug_lowering(
enzyme_call.ABI.Augmented,
lang,
xla_runtime(pipeline_options),
pass_pipeline(pipeline_options)
pass_pipeline(pipeline_options),
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -471,7 +625,7 @@ def _enzyme_rev_lowering(
enzyme_call.ABI.Reverse,
lang,
xla_runtime(pipeline_options),
pass_pipeline(pipeline_options)
pass_pipeline(pipeline_options),
)
identifier_attr = jax_mlir.dense_int_elements([identifier])
identifier_op = stablehlo.ConstantOp(identifier_attr)
Expand Down Expand Up @@ -518,10 +672,16 @@ def ffi_call(
fn: str = "f",
argv: tuple[str] = (),
lang: int = LANG_CPP,
pipeline_options = None
pipeline_options=None
):
return _enzyme_primal_p.bind(
*args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang, pipeline_options=pipeline_options
*args,
source=source,
fn=fn,
argv=argv,
out_shapes=out_shapes,
lang=lang,
pipeline_options=pipeline_options
)


Expand All @@ -531,10 +691,16 @@ def cpp_call(
source: str,
fn: str = "f",
argv: tuple[str] = (),
pipeline_options = None
pipeline_options=None
):
return ffi_call(
*args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=LANG_CPP, pipeline_options=pipeline_options
*args,
source=source,
fn=fn,
argv=argv,
out_shapes=out_shapes,
lang=LANG_CPP,
pipeline_options=pipeline_options
)


Expand Down

0 comments on commit 78711bd

Please sign in to comment.