diff --git a/mlir/lib/Transform/AIRMiscPasses.cpp b/mlir/lib/Transform/AIRMiscPasses.cpp index 0fec5ddc1..9fee92861 100644 --- a/mlir/lib/Transform/AIRMiscPasses.cpp +++ b/mlir/lib/Transform/AIRMiscPasses.cpp @@ -1047,56 +1047,16 @@ FailureOr tileChannelOpByFactor( auto newGetOp = rewriter.create( loc, tys, deps, newChanOp.getSymName(), newIndices, originalChanOp.getMemref(), newOffsets, newWraps, newStrides); - newGetOp->setAttr("id", - mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), - originalChanOp.getId())); + newGetOp->setAttrs(originalChanOp->getDiscardableAttrDictionary()); tokens.push_back(newGetOp.getAsyncToken()); opToSplitInfoMap[newGetOp] = splitInfoVec[i]; - newGetOp->setAttr( - "split_dim", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), splitDimOnOffsets)); - if (splitInfoAffineMap) - newGetOp->setAttr("affine_map", - mlir::AffineMapAttr::get(splitInfoAffineMap)); - if (splitInfoSplitOffset) - newGetOp->setAttr("split_offset", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), - *splitInfoSplitOffset)); - if (splitInfoSplitSize) - newGetOp->setAttr("split_size", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), - *splitInfoSplitSize)); - if (splitInfoSplitStrideFactor) - newGetOp->setAttr("split_stride_factor", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), - *splitInfoSplitStrideFactor)); } else { auto newPutOp = rewriter.create( loc, tys, deps, newChanOp.getSymName(), newIndices, originalChanOp.getMemref(), newOffsets, newWraps, newStrides); - newPutOp->setAttr("id", - mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), - originalChanOp.getId())); + newPutOp->setAttrs(originalChanOp->getDiscardableAttrDictionary()); tokens.push_back(newPutOp.getAsyncToken()); opToSplitInfoMap[newPutOp] = splitInfoVec[i]; - newPutOp->setAttr( - "split_dim", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), splitDimOnOffsets)); - if (splitInfoAffineMap) - newPutOp->setAttr("affine_map", - mlir::AffineMapAttr::get(splitInfoAffineMap)); - if (splitInfoSplitOffset) - newPutOp->setAttr("split_offset", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), - *splitInfoSplitOffset)); - if (splitInfoSplitSize) - newPutOp->setAttr("split_size", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), - *splitInfoSplitSize)); - if (splitInfoSplitStrideFactor) - newPutOp->setAttr("split_stride_factor", - mlir::IntegerAttr::get(IntegerType::get(ctx, 32), - *splitInfoSplitStrideFactor)); } } auto newWaitAll = rewriter.create( @@ -1207,26 +1167,18 @@ void AIRSplitL2MemrefForBufferConstraintPass::partitionMemref( return; push_back_if_unique(keys, offset_key); chanOpPartitions[offset_key].push_back(op); - op->setAttr("partition_key", - IntegerAttr::get(IntegerType::get(ctx, 32), offset_key)); }; for (auto op : puts) { - op->setAttr("partitioning", BoolAttr::get(ctx, true)); - if (!opToSplitInfoMap.count(op)) { - op->setAttr("opNotInOpToSplitInfoMap", BoolAttr::get(ctx, true)); + if (!opToSplitInfoMap.count(op)) continue; - } auto &[splitInfoDimOnOffsets, splitAffineMap, splitOffset, splitSize, splitStride] = opToSplitInfoMap[op]; getChanOpPartitionsMap(chanOpPartitions, keys, splitInfoDimOnOffsets, op); } for (auto op : gets) { - op->setAttr("partitioning", BoolAttr::get(ctx, true)); - if (!opToSplitInfoMap.count(op)) { - op->setAttr("opNotInOpToSplitInfoMap", BoolAttr::get(ctx, true)); + if (!opToSplitInfoMap.count(op)) continue; - } auto &[splitInfoDimOnOffsets, splitAffineMap, splitOffset, splitSize, splitStride] = opToSplitInfoMap[op]; getChanOpPartitionsMap(chanOpPartitions, keys, splitInfoDimOnOffsets, op); @@ -1495,15 +1447,6 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( int tilingFactor = std::max(getChanCount(MM2SChannels), getChanCount(S2MMChannels)); - // Keep debug log in alloc op's attributes. TODO: clean up. - allocOp->setAttr("split", BoolAttr::get(func.getContext(), true)); - allocOp->setAttr("tilingFactor", - IntegerAttr::get(IntegerType::get(ctx, 32), tilingFactor)); - if (getChanCount(MM2SChannels) > 1) { - allocOp->setAttr("split_type", StringAttr::get(ctx, "MM2SChannels")); - } else { - allocOp->setAttr("split_type", StringAttr::get(ctx, "S2MMChannels")); - } llvm::MapVector> infoEntryMap; std::optional splitDimOffset = std::nullopt; std::optional splitDimSize = std::nullopt; @@ -1529,18 +1472,6 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( "memref splitting analysis failed to get the split dimension."); return failure(); } - if (allocOp->hasAttr("split_dim")) - assert(allocOp->getAttrOfType("split_dim").getInt() == - *splitDim && - "L2 memref tiled inconsistently across multiple data access " - "patterns. Cannot infer L2 memref tiling strategy."); - else { - if (*splitDim >= 0) - allocOp->setAttr( - "split_dim", - IntegerAttr::get(IntegerType::get(ctx, 32), *splitDim)); - assert(*splitDim >= 0 && "failed to get split dimension"); - } // Methods to get root offset/size/stride from air.channel's operands, where // root is either a constant, or a loop's induction variable. @@ -1554,13 +1485,6 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( }; auto getRootSize = [&](Value offsetVal, Value sizeVal) { std::optional rootSize = std::nullopt; - // if (auto constSize = getConstantIntValue(sizeVal)){ - // // splitDimSize = *constSize; - // // putgets[i]->setAttr( - // // "split_dim_size", - // // IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimSize)); - // } - // else if (auto forOp = getScfForFromVal(offsetVal)) { if (auto trip_count = air::getStaticScfForTripCountAsInt(forOp)) rootSize = *getConstantIntValue(sizeVal) * (*trip_count); @@ -1582,29 +1506,14 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( auto offsetDimOpt = air::getOffsetDimFromMemrefDim(*splitDim, putgets[i].getStrides(), air::getTensorShape(memref.getType())); - if (offsetDimOpt) - putgets[i]->setAttr( - "split_dim", - IntegerAttr::get(IntegerType::get(ctx, 32), *offsetDimOpt)); // Infer offset at splitDim. if (auto rootOffset = - getRootOffset(putgets[i].getOffsets()[*offsetDimOpt])) { + getRootOffset(putgets[i].getOffsets()[*offsetDimOpt])) splitDimOffset = *rootOffset; - putgets[i]->setAttr( - "split_dim_offset", - IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimOffset)); - } // Infer size at splitDim. if (auto rootSize = getRootSize(putgets[i].getOffsets()[*offsetDimOpt], - putgets[i].getSizes()[*offsetDimOpt])) { + putgets[i].getSizes()[*offsetDimOpt])) splitDimSize = *rootSize; - allocOp->setAttr( - "split_dim_size", - IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimSize)); - putgets[i]->setAttr( - "split_dim_size", - IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimSize)); - } // Infer stride (factor) at splitDim. If the root comes from an scf.for // loop, and if the loop has non-unit step size, then that multiplier // should be applied to other split channe put/get ops. @@ -1612,9 +1521,6 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( getRootStrideFactor(putgets[i].getOffsets()[*offsetDimOpt], putgets[i].getStrides()[*offsetDimOpt])) { splitDimStrideFactor = *rootStrideFactor; - putgets[i]->setAttr( - "split_dim_stride_factor", - IntegerAttr::get(IntegerType::get(ctx, 32), *splitDimStrideFactor)); // Cancel out the non-unit step size on the for loop, to get contiguous // access pattern on memrefs after split. if (auto forOp = @@ -1625,11 +1531,8 @@ AIRSplitL2MemrefForBufferConstraintPass::getTargetMemrefAllocs( } AffineMap applyMap; auto apply = getAffineMapOnMemrefSplitDim(putgets[i], *offsetDimOpt); - if (apply) { + if (apply) applyMap = apply.getAffineMap(); - allocOp->setAttr("affine_map", AffineMapAttr::get(applyMap)); - putgets[i]->setAttr("affine_map", AffineMapAttr::get(applyMap)); - } infoEntryTy newEntry = {*offsetDimOpt, applyMap, splitDimOffset, splitDimSize, splitDimStrideFactor}; @@ -1942,10 +1845,6 @@ void AIRSplitL2MemrefForBufferConstraintPass::runOnOperation() { signalPassFailure(); erased.insert(par); } - for (auto &[old, news] : parUnrollMap) { - for (auto newOp : news) - newOp->setAttr("unrolled", BoolAttr::get(ctx, true)); - } // Update map after loop unrolling. for (auto &[oldOp, splitInfo] : opToSplitInfoMap) { Operation *o = oldOp;