Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 8, 2024
1 parent f308e4e commit 10e312d
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,33 +173,40 @@ struct AddPad final : OpRewritePattern<mlir::stablehlo::AddOp> {
SmallVector<size_t> idxs;
for (auto &&[low, high, dim] : llvm::zip(lhs.getEdgePaddingLow(), lhs.getEdgePaddingHigh(), type.getShape())) {
padidx++;
if (low == 0 && high == dim) continue;
idxs.push_back(padidx-1);
if (low == 0 && high == 0) continue;
idxs.push_back(padidx);
}

if (idxs.size() == 0) {
if (idxs.size() == 1) {
auto idx = idxs[0];

SmallVector<int64_t> strides(type.getShape().size(), 1);
SmallVector<int64_t> starts(type.getShape().size(), 0);
SmallVector<int64_t> limits(type.getShape().begin(), type.getShape().end());

starts[idx] = lhs.getEdgePaddingLow()[idx];
limits[idx] = type.getShape()[idx] - lhs.getEdgePaddingLow()[idx];

auto midSlice = rewriter.create<stablehlo::SliceOp>(op.getLoc(), rhs, starts, limits, strides);
SmallVector<Value, 1> vals;

if (lhs.getEdgePaddingLow()[idx] != 0) {
starts[idx] = 0;
limits[idx] = lhs.getEdgePaddingLow()[idx];
auto prevSlice = rewriter.create<stablehlo::SliceOp>(op.getLoc(), rhs, starts, limits, strides);
vals.push_back(prevSlice);
}

starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingLow()[idx];
limits[idx] = 0;
auto postSlice = rewriter.create<stablehlo::SliceOp>(op.getLoc(), rhs, starts, limits, strides);
starts[idx] = lhs.getEdgePaddingLow()[idx];
limits[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx];

auto midSlice = rewriter.create<stablehlo::SliceOp>(op.getLoc(), rhs, starts, limits, strides);
auto mid = rewriter.create<stablehlo::AddOp>(op.getLoc(), midSlice, lhs.getOperand());
vals.push_back(mid);

if (lhs.getEdgePaddingHigh()[idx] != 0) {
starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx];
limits[idx] = 0;
auto postSlice = rewriter.create<stablehlo::SliceOp>(op.getLoc(), rhs, starts, limits, strides);
vals.push_back(postSlice);
}

Value vals[3] = {prevSlice, mid, postSlice};
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, vals, idx);
return success();
}
Expand Down

0 comments on commit 10e312d

Please sign in to comment.