diff --git a/BUILD b/BUILD index 53b49b307..d8aae1d56 100644 --- a/BUILD +++ b/BUILD @@ -62,6 +62,7 @@ cc_binary( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:Transforms", + "//src/enzyme_ad/jax:RaisingTransformOps", "//src/enzyme_ad/jax:TransformOps", "//src/enzyme_ad/jax:XLADerivatives", "@stablehlo//:chlo_ops", diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 42d0a2470..f9fd9d206 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -179,9 +179,9 @@ gentbl_cc_library( ) cc_library( - name = "TransformOps", - srcs = glob(["TransformOps/*.cpp"]), - hdrs = glob(["TransformOps/*.h"]), + name = "RaisingTransformOps", + srcs = ["TransformOps/RaisingTransformOps.cpp"], + hdrs = ["TransformOps/RaisingTransformOps.h"], deps = [ "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -192,6 +192,23 @@ cc_library( ":RaisingTransformOpsIncGen", ":RaisingTransformOpsImplIncGen", ":RaisingTransformPatternsIncGen", + ], +) + +cc_library( + name = "TransformOps", + srcs = [ + "TransformOps/TransformOps.cpp", + "TransformOps/GenerateApplyPatterns.cpp", + ], + hdrs = ["TransformOps/TransformOps.h"], + deps = [ + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgTransformOps", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformDialect", + "@llvm-project//mlir:TransformDialectInterfaces", ":TransformOpsIncGen", ":TransformOpsImplIncGen", ":XLADerivatives", @@ -373,6 +390,7 @@ cc_library( ":EnzymeXLAOpsIncGen", ":EnzymeXLAPassesIncGen", ":EnzymeHLOPatternsIncGen", + ":RaisingTransformOps", "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:GPUPipelines", "@llvm-project//llvm:Core", diff --git a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp index b7d4393f3..1505aa04d 100644 --- a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp +++ b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp @@ -17,7 +17,6 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" #define GET_OP_CLASSES #include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.cpp.inc" diff --git a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h index 48509be20..2027640c7 100644 --- a/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h +++ b/src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h @@ -9,7 +9,6 @@ #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" -#include "src/enzyme_ad/jax/TransformOps/OpInterfaces.h.inc" #define GET_OP_CLASSES #include "src/enzyme_ad/jax/TransformOps/RaisingTransformOps.h.inc"