Skip to content

Commit

Permalink
use getShadowType in BroadcastOp builder
Browse files Browse the repository at this point in the history
Co-authored-by: Billy Moses <[email protected]>
  • Loading branch information
jumerckx and wsmoses committed Dec 25, 2024
1 parent fa190a4 commit 1802949
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -199,17 +200,9 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
ArrayRef<int64_t> shape) {
auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
RankedTensorType output;
// TODO: support things other than scalars and ranked tensors, maybe reuse
// getShadowType here?
if (auto tensorType = input.getType().dyn_cast<TensorType>()) {
auto originalShape = tensorType.getShape();
SmallVector<int64_t, 4> newShape;
newShape.append(shape.begin(), shape.end());
newShape.append(originalShape.begin(), originalShape.end());
output = RankedTensorType::get(newShape, tensorType.getElementType());
} else {
output = RankedTensorType::get(shape, input.getType());
auto resultTy = input.getType();
for (auto s : llvm::reverse(shape)) {
resultTy = resultTy.cast<AutoDiffTypeInterface>().getShadowType(s);
}
build(builder, result, output, input, shapeAttr);
build(builder, result, resultTy, input, shapeAttr);
}

0 comments on commit 1802949

Please sign in to comment.