Skip to content

Commit

Permalink
more fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 8, 2024
1 parent 38e6039 commit c2d3c1f
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,23 @@ struct SliceConcat final : OpRewritePattern<mlir::stablehlo::SliceOp> {
for (auto v : concat.getInputs()) {
auto ty = v.getType().cast<RankedTensorType>();
auto nextdim = ty.getShape()[dim];
if (op.getStartIndices()[dim] < curdim) {
if (op.getStartIndices()[dim] >= curdim + nextdim) {
curdim += nextdim;
continue;
}
if (op.getLimitIndices()[dim] >= curdim) {
if (op.getLimitIndices()[dim] <= curdim) {
curdim += nextdim;
continue;
}
SmallVector<int64_t> nstart(op.getStartIndices().begin(), op.getStartIndices().end());
SmallVector<int64_t> nend(op.getStartIndices().begin(), op.getStartIndices().end());
SmallVector<int64_t> nend(op.getLimitIndices().begin(), op.getLimitIndices().end());
nstart[dim] -= curdim;
if (nstart[dim] < 0) nstart[dim] = 0;
nend[dim] -= curdim;
if (nend[dim] > nextdim) nend[dim] = nextdim;
auto subslice = rewriter.create<stablehlo::SliceOp>(op.getLoc(), v, nstart, nend, op.getStrides());
postConcat.push_back(subslice);
curdim += nextdim;
}
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, postConcat, dim);
return success();
Expand Down Expand Up @@ -271,6 +272,29 @@ struct ConcatConstProp final : OpRewritePattern<mlir::stablehlo::ConcatenateOp>
if (!type)
return failure();

if (op->getNumOperands() == 1) {
rewriter.replaceOp(op, op->getOperand(0));
return success();
}

{
SmallVector<Value> subconcat;
bool changed = false;
for (auto v : op->getOperands()) {
if (auto c2 = v.getDefiningOp<stablehlo::ConcatenateOp>())
if (c2.getDimension() == op.getDimension()) {
for (auto v2 : c2->getOperands())
subconcat.push_back(v2);
changed = true;
continue;
}
subconcat.push_back(v);
}
if (changed) {
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(op, subconcat, op.getDimension());
return success();
}
}

SmallVector<DenseElementsAttr> constants;
constants.assign(op->getNumOperands(), DenseElementsAttr());
Expand Down

0 comments on commit c2d3c1f

Please sign in to comment.