Batched autodiff #2181

Merged: 22 commits, Dec 27, 2024
16 changes: 16 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> {

}

def BroadcastOp : Enzyme_Op<"broadcast"> {
let description = [{
Broadcast the operand by adding extra dimensions to the front, with sizes given by the `shape` attribute. For scalar operands, a ranked tensor is created.

NOTE: Only works for scalar and *ranked* tensor operands for now.
}];

let arguments = (ins AnyType:$input, DenseI64ArrayAttr:$shape);
let results = (outs AnyRankedTensor:$output);

let builders = [
OpBuilder<(ins "Value":$input, "ArrayRef<int64_t>":$shape)>
];
}

#endif // ENZYME_OPS
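For illustration, this is how a width-2 broadcast appears in the generic form used by the lit tests below (%x and %t are hypothetical values):

%bx = "enzyme.broadcast"(%x) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
%bt = "enzyme.broadcast"(%t) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>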
15 changes: 15 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

@@ -191,3 +192,17 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
ArrayRef<int64_t> shape) {
auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
auto resultTy = input.getType();
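// Walk the shape in reverse so each getShadowType call prepends one
// dimension; e.g. an f64 input with shape [2, 3] becomes tensor<2x3xf64>.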
for (auto s : llvm::reverse(shape)) {
resultTy = resultTy.cast<AutoDiffTypeInterface>().getShadowType(s);
}
build(builder, result, resultTy, input, shapeAttr);
}
@@ -17,6 +17,7 @@
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

@@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface(
arith::ConstantOp::attachInterface<ArithConstantOpBatchInterface>(*context);
});
}

void mlir::enzyme::registerTensorDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) {
registerInterfaces(context);
});
}
@@ -45,8 +45,11 @@ class FloatTypeInterface
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
if (width > 1) {
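// Batched shadow of a scalar, e.g. tensor<2xf64> for f64 at width 2.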
return RankedTensorType::get({width}, self);
} else {
return self;
}
}

bool isMutable(Type self) const { return false; }
@@ -106,7 +109,14 @@ class TensorTypeInterface
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
if (width != 1) {
auto tenType = self.cast<TensorType>();
auto shape = tenType.getShape();
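// Prepend the batch width, e.g. tensor<10xf64> becomes tensor<2x10xf64> at width 2.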
SmallVector<int64_t, 4> newShape;
newShape.push_back(width);
newShape.append(shape.begin(), shape.end());
return RankedTensorType::get(newShape, tenType.getElementType());
}
return self;
}

@@ -74,7 +74,8 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst,
newVals.push_back(gutils->invertPointerM(op, builder));
} else {
Type retTy =
arg.getType().cast<AutoDiffTypeInterface>().getShadowType();
arg.getType().cast<AutoDiffTypeInterface>().getShadowType(
gutils->width);
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, op.getLoc());
newVals.push_back(toret);
@@ -146,7 +147,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler(
if (auto iface =
dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
if (!iface.isMutable()) {
Type retTy = iface.getShadowType();
Type retTy = iface.getShadowType(gutils->width);
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, operand.get().getLoc());
newOperands.push_back(toret);
@@ -346,7 +347,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
<< result.getType() << "\n";
return failure();
}
newOpResultTypes.push_back(typeIface.getShadowType());
newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
}

SmallVector<Value> newOperands;
@@ -432,4 +433,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces(
enzyme::registerCFDialectAutoDiffInterface(registry);
enzyme::registerLinalgDialectAutoDiffInterface(registry);
enzyme::registerFuncDialectAutoDiffInterface(registry);
enzyme::registerTensorDialectAutoDiffInterface(registry);
}
@@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
void registerMathDialectAutoDiffInterface(DialectRegistry &registry);
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry);
void registerTensorDialectAutoDiffInterface(DialectRegistry &registry);

void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);

6 changes: 4 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
@@ -245,9 +245,11 @@ FunctionOpInterface CloneFunctionWithReturns(
mlir::Value val = blk.getArgument(i);
mlir::Value dval;
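// Shadow arguments take the batched shadow type, e.g. tensor<2xf64> next to an f64 primal at width 2.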
if (i == ArgActivity.size() - 1)
dval = blk.addArgument(val.getType(), val.getLoc());
dval = blk.addArgument(getShadowType(val.getType(), width),
val.getLoc());
else
dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(),
dval = blk.insertArgument(blk.args_begin() + i + 1,
getShadowType(val.getType(), width),
val.getLoc());
ptrInputs.map(oval, dval);
}
3 changes: 2 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
@@ -108,7 +108,8 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v,
return invertedPointers.lookupOrNull(v);

if (isConstantValue(v)) {
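// Constants get a null shadow of the *shadow* type, e.g. dense<0.000000e+00> : tensor<2xf64> at width 2.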
if (auto iface = v.getType().dyn_cast<AutoDiffTypeInterface>()) {
if (auto iface =
getShadowType(v.getType()).dyn_cast<AutoDiffTypeInterface>()) {
OpBuilder::InsertionGuard guard(Builder2);
if (auto op = v.getDefiningOp())
Builder2.setInsertionPoint(getNewFromOriginal(op));
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
MLIRFuncDialect
MLIRFuncTransforms
MLIRGPUDialect
MLIRTensorDialect
MLIRIR
MLIRLLVMDialect
MLIRMathDialect
5 changes: 5 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.h
@@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "Dialect/Dialect.h"

@@ -80,6 +81,10 @@ namespace affine {
class AffineDialect;
} // end namespace affine

namespace tensor {
class TensorDialect;
} // end namespace tensor

namespace LLVM {
class LLVMDialect;
} // end namespace LLVM
3 changes: 2 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -16,7 +16,8 @@ def DifferentiatePass : Pass<"enzyme"> {
let dependentDialects = [
"arith::ArithDialect",
"complex::ComplexDialect",
"cf::ControlFlowDialect"
"cf::ControlFlowDialect",
"tensor::TensorDialect",
];
let constructor = "mlir::enzyme::createDifferentiatePass()";
}
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
@@ -67,6 +67,7 @@ int main(int argc, char **argv) {
registry.insert<mlir::omp::OpenMPDialect>();
registry.insert<mlir::math::MathDialect>();
registry.insert<mlir::linalg::LinalgDialect>();
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<DLTIDialect>();

registry.insert<mlir::enzyme::EnzymeDialect>();
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_branch.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64, %y : f64) -> f64 {
%c = arith.cmpf ult, %x, %y : f64
cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64)

^blk2(%r : f64):
return %r : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64
// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>)
// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }
33 changes: 33 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_for.mlir
@@ -0,0 +1,33 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64) -> f64 {
%cst = arith.constant 10.000000e+00 : f64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (f64) {
%n = arith.addf %arg2, %x : f64
scf.yield %n : f64
}
return %r : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64>
// CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, tensor<2xf64>) {
// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64
// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, tensor<2xf64>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[i0]]#1 : tensor<2xf64>
// CHECK-NEXT: }
43 changes: 43 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_if.mlir
@@ -0,0 +1,43 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64, %c : i1) -> f64 {
%c2 = arith.constant 2.000000e+00 : f64
%c10 = arith.constant 10.000000e+00 : f64
%r:2 = scf.if %c -> (f64, f64) {
%mul = arith.mulf %x, %x : f64
scf.yield %mul, %c2 : f64, f64
} else {
%add = arith.addf %x, %x : f64
scf.yield %add, %c10 : f64, f64
}
%res = arith.mulf %r#0, %r#1 : f64
return %res : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>, %c : i1) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_const>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>, i1) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: i1) -> tensor<2xf64> {
// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64
// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64
// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, tensor<2xf64>, f64) {
// CHECK-NEXT: %[[t4:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[t5:.+]] = arith.mulf %[[arg1]], %[[t4]] : tensor<2xf64>
// CHECK-NEXT: %[[t6:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[t7:.+]] = arith.mulf %[[arg1]], %[[t6]] : tensor<2xf64>
// CHECK-NEXT: %[[t8:.+]] = arith.addf %[[t5]], %[[t7]] : tensor<2xf64>
// CHECK-NEXT: %[[t9:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: scf.yield %[[t9]], %[[t8]], %[[cst2]] : f64, tensor<2xf64>, f64
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[e4:.+]] = arith.addf %[[arg1]], %[[arg1]] : tensor<2xf64>
// CHECK-NEXT: %[[e5:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: scf.yield %[[e5]], %[[e4]], %[[cst10]] : f64, tensor<2xf64>, f64
// CHECK-NEXT: }
// CHECK-NEXT: %[[r1:.+]] = "enzyme.broadcast"(%[[r0]]#2) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#1, %[[r1]] : tensor<2xf64>
// CHECK-NEXT: %[[r3:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64
// CHECK-NEXT: return %[[r2]] : tensor<2xf64>
// CHECK-NEXT: }
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : f64) -> f64 {
%y = arith.mulf %x, %x : f64
return %y : f64
}
func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>)
return %r : tensor<2xf64>
}
}

// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64>
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }
26 changes: 26 additions & 0 deletions enzyme/test/MLIR/ForwardMode/batched_tensor.mlir
@@ -0,0 +1,26 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @square(%x : tensor<10xf64>) -> tensor<10xf64> {
%y = arith.mulf %x, %x : tensor<10xf64>
return %y : tensor<10xf64>
}
func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>)
return %r : tensor<2x10xf64>
}
}

// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64>
// CHECK-NEXT: }
// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> {
// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64>
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (tensor<10xf64>) -> tensor<2x10xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64>
// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64>
// CHECK-NEXT: }
13 changes: 12 additions & 1 deletion enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
@@ -275,8 +275,19 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
os << ord;
}
if (!vecValue && !startsWith(ord, "local")) {
if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives))
if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) {
os << ")";
if (intrinsic == MLIRDerivatives) {
os << ";\n";
os << "if (gutils->width != 1) {\n"
<< " " << argName << "_" << (idx - 1)
<< " = builder.create<enzyme::BroadcastOp>(\n"
<< " op.getLoc(),\n"
<< " " << argName << "_" << (idx - 1) << ",\n"
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< "}";
}
}
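// At width != 1 the generated code wraps the scalar primal operand in an
// enzyme.broadcast so it can be combined elementwise with the batched
// shadow; see the broadcast ops in the batched_* tests above.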

if (lookup && intrinsic != MLIRDerivatives)
os << ", " << builder << ")";