Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 23, 2024
1 parent aa4e598 commit fa190a4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
16 changes: 9 additions & 7 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,18 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input, ArrayRef<int64_t> shape) {
void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
ArrayRef<int64_t> shape) {
auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
RankedTensorType output;
// TODO: support things other than scalars and ranked tensors, maybe reuse getShadowType here?
// TODO: support things other than scalars and ranked tensors, maybe reuse
// getShadowType here?
if (auto tensorType = input.getType().dyn_cast<TensorType>()) {
auto originalShape = tensorType.getShape();
SmallVector<int64_t, 4> newShape;
newShape.append(shape.begin(), shape.end());
newShape.append(originalShape.begin(), originalShape.end());
output = RankedTensorType::get(newShape, tensorType.getElementType());
auto originalShape = tensorType.getShape();
SmallVector<int64_t, 4> newShape;
newShape.append(shape.begin(), shape.end());
newShape.append(originalShape.begin(), originalShape.end());
output = RankedTensorType::get(newShape, tensorType.getElementType());
} else {
output = RankedTensorType::get(shape, input.getType());
}
Expand Down
6 changes: 4 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,11 @@ FunctionOpInterface CloneFunctionWithReturns(
mlir::Value val = blk.getArgument(i);
mlir::Value dval;
if (i == ArgActivity.size() - 1)
dval = blk.addArgument(getShadowType(val.getType(), width), val.getLoc());
dval = blk.addArgument(getShadowType(val.getType(), width),
val.getLoc());
else
dval = blk.insertArgument(blk.args_begin() + i + 1, getShadowType(val.getType(), width),
dval = blk.insertArgument(blk.args_begin() + i + 1,
getShadowType(val.getType(), width),
val.getLoc());
ptrInputs.map(oval, dval);
}
Expand Down
11 changes: 6 additions & 5 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,12 @@ SmallVector<bool, 1> prepareArgs(const Twine &curIndent, raw_ostream &os,
if (intrinsic == MLIRDerivatives) {
os << ";\n";
os << "if (gutils->width != 1) {\n"
<< " " << argName << "_" << (idx - 1) << " = builder.create<enzyme::BroadcastOp>(\n"
<< " op.getLoc(),\n"
<< " " << argName << "_" << (idx - 1) << ",\n"
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< "}";
<< " " << argName << "_" << (idx - 1)
<< " = builder.create<enzyme::BroadcastOp>(\n"
<< " op.getLoc(),\n"
<< " " << argName << "_" << (idx - 1) << ",\n"
<< " llvm::SmallVector<int64_t>({gutils->width}));\n"
<< "}";
}
}

Expand Down

0 comments on commit fa190a4

Please sign in to comment.