Skip to content

Commit

Permalink
[Encoding][NFC] Moving Encoding attr/enum to Encoding[Types|Attrs].* (i…
Browse files Browse the repository at this point in the history
…ree-org#18711)

The revision keeps `EncodingBase.td` simple. It follows the IREE core
dialect style, which moves the declarations to `EncodingTypes.h` and
implementation to `EncodingAttrs.cpp`.

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Oct 8, 2024
1 parent e8ff07e commit 4636257
Show file tree
Hide file tree
Showing 10 changed files with 377 additions and 320 deletions.
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ iree_td_library(
name = "td_files",
srcs = enforce_glob(
[
"EncodingAttrs.td",
"EncodingBase.td",
"EncodingOps.td",
],
Expand All @@ -39,6 +40,7 @@ iree_td_library(
iree_compiler_cc_library(
name = "IR",
srcs = [
"EncodingAttrs.cpp",
"EncodingAttrs.cpp.inc",
"EncodingDialect.cpp",
"EncodingDialect.cpp.inc",
Expand All @@ -54,6 +56,7 @@ iree_compiler_cc_library(
"EncodingEnums.h.inc",
"EncodingOps.h",
"EncodingOps.h.inc",
"EncodingTypes.h",
"EncodingTypes.h.inc",
],
deps = [
Expand Down Expand Up @@ -101,7 +104,7 @@ iree_gentbl_cc_library(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "EncodingBase.td",
td_file = "EncodingAttrs.td",
deps = [":td_files"],
)

Expand Down Expand Up @@ -169,7 +172,7 @@ iree_gentbl_cc_library(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "EncodingBase.td",
td_file = "EncodingAttrs.td",
deps = [":td_files"],
)

Expand Down
6 changes: 4 additions & 2 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ iree_cc_library(
"EncodingEnums.h.inc"
"EncodingOps.h"
"EncodingOps.h.inc"
"EncodingTypes.h"
"EncodingTypes.h.inc"
SRCS
"EncodingAttrs.cpp"
"EncodingAttrs.cpp.inc"
"EncodingDialect.cpp"
"EncodingDialect.cpp.inc"
Expand Down Expand Up @@ -63,7 +65,7 @@ iree_tablegen_library(
NAME
EncodingEnumsGen
TD_FILE
"EncodingBase.td"
"EncodingAttrs.td"
OUTS
--gen-enum-decls EncodingEnums.h.inc
--gen-enum-defs EncodingEnums.cpp.inc
Expand All @@ -85,7 +87,7 @@ iree_tablegen_library(
NAME
EncodingTypesGen
TD_FILE
"EncodingBase.td"
"EncodingAttrs.td"
OUTS
--gen-attrdef-decls --attrdefs-dialect=iree_encoding EncodingAttrs.h.inc
--gen-attrdef-defs --attrdefs-dialect=iree_encoding EncodingAttrs.cpp.inc
Expand Down
160 changes: 160 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir::iree_compiler::IREE::Encoding {

EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
EncodingOpType opType, ArrayRef<Type> elemTypes,
ArrayRef<AffineMap> maps,
std::optional<AffineMap> bcastMap,
ArrayRef<int64_t> roundDimsTo) {
Builder b(ctx);
auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType);
auto roundDimsToAttr = roundDimsTo.empty()
? DenseI64ArrayAttr()
: b.getDenseI64ArrayAttr(roundDimsTo);
auto bcastMapAttr = bcastMap.has_value()
? AffineMapAttr::get(bcastMap.value())
: AffineMapAttr();
return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr,
b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps),
bcastMapAttr, roundDimsToAttr);
}

AffineMap EncodingAttr::getMapForOperandIndex() {
auto index = getOperandIndex().getValue().getZExtValue();
switch (index) {
case MATMUL_LHS:
case MATMUL_RHS:
case MATMUL_RESULT: {
auto indexingMap =
llvm::cast<AffineMapAttr>(getUserIndexingMaps()[index]).getAffineMap();
if (auto bcastMap = getBcastMap()) {
indexingMap = bcastMap.getAffineMap().compose(indexingMap);
}
return indexingMap;
}
default:
return AffineMap();
}
}

std::optional<unsigned> EncodingAttr::mapDimToOperandIndex(int64_t dimPos) {
return getMapForOperandIndex().getResultPosition(
getAffineDimExpr(dimPos, getContext()));
}

MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp,
int narrowThreshold) {
linalg::ContractionDimensions cDims =
linalg::inferContractionDims(linalgOp).value();
auto map = linalgOp.getIndexingMapsArray().back();
auto outType = llvm::cast<ShapedType>(linalgOp.getDpsInits()[0].getType());
auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t {
return outType.getDimSize(
map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext()))
.value());
};
// M or N can be empty instead of having an explicit dim size of 1 for matvec
// and vecmat, so set to 1 if empty.
int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]);
int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]);

MatmulNarrowDim narrowM, narrowN;
if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) {
narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize};
}
if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) {
narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize};
}

return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN;
}

ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() {
auto roundDimsTo = getRoundDimsTo();
if (!roundDimsTo) {
return {};
}
return llvm::cast<DenseI64ArrayAttr>(roundDimsTo).asArrayRef();
}

SmallVector<Type> EncodingAttr::getElementTypesArray() {
return llvm::map_to_vector(getElementTypes().getValue(), [](Attribute a) {
return llvm::cast<TypeAttr>(a).getValue();
});
}

EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
return get(bcastMap.getContext(), getOperandIndex(), getOpType(),
getElementTypes(), getUserIndexingMaps(),
AffineMapAttr::get(bcastMap), getRoundDimsTo());
}

MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
if (encoding.getOpType().getValue() != EncodingOpType::matmul) {
return {};
}
ArrayRef<int64_t> roundDimsTo = encoding.getRoundDimsToArray();
if (roundDimsTo.empty()) {
return {};
}
int m = roundDimsTo[0];
int n = roundDimsTo[1];
if (m < n) {
return {MatmulNarrowDim::Dim::M, m};
}
if (n < m) {
return {MatmulNarrowDim::Dim::N, n};
}
return {};
}

EncodingAttr getEncodingAttr(RankedTensorType type) {
return dyn_cast_or_null<EncodingAttr>(type.getEncoding());
}

FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding) {
auto indexingMapsAttr = encoding.getUserIndexingMaps();
SmallVector<AffineMap> indexingMaps = llvm::map_to_vector(
indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap {
return cast<AffineMapAttr>(m).getAffineMap();
});
return linalg::inferContractionDims(indexingMaps);
}

std::string stringifyOperandIndex(IntegerAttr valueAttr) {
auto value = valueAttr.getValue().getZExtValue();
switch (value) {
case MATMUL_LHS:
return "LHS";
case MATMUL_RHS:
return "RHS";
case MATMUL_RESULT:
return "RESULT";
default:
assert(false && "invalid index");
return "";
}
}

} // namespace mlir::iree_compiler::IREE::Encoding
104 changes: 104 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_DIALECT_ENCODING_ATTRS
#define IREE_DIALECT_ENCODING_ATTRS

include "iree/compiler/Dialect/Encoding/IR/EncodingBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

//===---------------------------------------------------------------------===//
// Data layout encoding attributes
//===---------------------------------------------------------------------===//

class IREEEncoding_Attr<string name, list<Trait> traits = []>
: AttrDef<IREEEncoding_Dialect, name, traits>;

class IREEEncoding_I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, summary, cases> {
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding";
let genSpecializedAttr = 0;
}

class IREEEncoding_EnumAttr<EnumAttrInfo enumInfo, string name = "">
: EnumAttr<IREEEncoding_Dialect, enumInfo, name>;

// Enums for tagging operand operation in an EncodingAttr
def MATMUL : I32EnumAttrCase<"matmul", 0>;
def CONV : I32EnumAttrCase<"conv", 1>;

def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType",
"Tracks the type of operation of the operand.", [
MATMUL,
CONV,
]>;

def EncodingOpTypeAttr:
IREEEncoding_EnumAttr<EncodingOpType, "optype">;

def EncodingAttr :
IREEEncoding_Attr<"Encoding"> {
let mnemonic = "encoding";
let summary = [{information to decide how to data-tile a tensor}];
let description = [{
This attribute describes the change in the layout for
a given tensor to execute subsequent operations on
the tiled layout. The encoding serves as a way to
represent the change in the way the data is laid out in
memory without changing the logical rank/extent of
the tensor itself. When required, the encoding
can be used to explicitly manifest the layout change
through operations like pack/unpack.
}];

let assemblyFormat = "`<` struct(params) `>`";

let parameters = (ins
AttrParameter<"IntegerAttr", "this tensor operand's index in the parameter list">:$operand_index,
AttrParameter<"EncodingOpTypeAttr", "operand type">:$op_type,
AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types,
OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps,
OptionalParameter<"AffineMapAttr", "Indexing map that represents the broadcasting dims in the producer">:$bcast_map,
// TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now.
OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to
);

let builders = [
AttrBuilder<(ins "int64_t":$operandIndex,
"EncodingOpType":$opType,
"ArrayRef<Type>":$elemTypes,
CArg<"ArrayRef<AffineMap>", "{}">:$maps,
CArg<"std::optional<AffineMap>", "{}">:$bcastMap,
CArg<"ArrayRef<int64_t>", "{}">:$roundDimsTo)>
];

let extraClassDeclaration = [{
/// Returns the bcast_map composed with the user_indexing_map for the
/// operand_index. The dimensions of the returned map are those of the
/// data-tiled op's iteration space, and the results of the map are in
/// the domain of the encoded tensor type.
AffineMap getMapForOperandIndex();

/// Given the dim position of the encoding `user_indexing_maps`, returns the
/// matching index of the given encoding's tensor, using getMapForOperandIndex
/// bcast_map and user_indexing_map.
std::optional<unsigned> mapDimToOperandIndex(int64_t dimPos);

/// Returns an integer array with values in `round_dims_to`.
ArrayRef<int64_t> getRoundDimsToArray();

/// Returns a vector with values in `element_types`.
SmallVector<Type> getElementTypesArray();

/// Clones an encoding with a new bcast_map
EncodingAttr clone(AffineMap bcastMap);
}];

let genVerifyDecl = 0;
}

#endif // IREE_DIALECT_ENCODING_ATTRS
Loading

0 comments on commit 4636257

Please sign in to comment.