-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Add two clone methods about encoding to RankedTensorType. #127709
Conversation
There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType: - dropEncoding(): Return a clone of this type without the encoding. - cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type. Signed-off-by: hanhanW <[email protected]>
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-ods Author: Han-Chung Wang (hanhanW) ChangesThere are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:
Full diff: https://github.com/llvm/llvm-project/pull/127709.diff 2 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9..af474b3e3ec47 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
RankedTensorType clone(::mlir::Type elementType) {
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
}
+
+ /// Return a clone of this type without the encoding.
+ RankedTensorType dropEncoding() {
+ return RankedTensorType::get(getShape(), getElementType());
+ }
+
+ /// Return a clone of this type with the given new encoding and the same
+ /// shape and element type as this type.
+ RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
+ return RankedTensorType::get(getShape(), getElementType(), encoding);
+ }
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index c2900b5aaeeeb..bc4066ed210e8 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
view = mlir::cast<TensorWithString>(viewCreated);
EXPECT_EQ(view.getName(), "bob");
+
+ // Verify encoding clone methods.
+ EXPECT_EQ(unitEncodingRankedTensorType,
+ cast<RankedTensorType>(noEncodingRankedTensorType)
+ .cloneWithEncoding(unitAttr));
+ EXPECT_EQ(stringEncodingRankedTensorType,
+ cast<RankedTensorType>(noEncodingRankedTensorType)
+ .cloneWithEncoding(stringAttr));
+ EXPECT_EQ(
+ noEncodingRankedTensorType,
+ cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
+ EXPECT_EQ(
+ noEncodingRankedTensorType,
+ cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
}
} // namespace
|
@llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) ChangesThere are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:
Full diff: https://github.com/llvm/llvm-project/pull/127709.diff 2 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9..af474b3e3ec47 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
RankedTensorType clone(::mlir::Type elementType) {
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
}
+
+ /// Return a clone of this type without the encoding.
+ RankedTensorType dropEncoding() {
+ return RankedTensorType::get(getShape(), getElementType());
+ }
+
+ /// Return a clone of this type with the given new encoding and the same
+ /// shape and element type as this type.
+ RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
+ return RankedTensorType::get(getShape(), getElementType(), encoding);
+ }
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index c2900b5aaeeeb..bc4066ed210e8 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
view = mlir::cast<TensorWithString>(viewCreated);
EXPECT_EQ(view.getName(), "bob");
+
+ // Verify encoding clone methods.
+ EXPECT_EQ(unitEncodingRankedTensorType,
+ cast<RankedTensorType>(noEncodingRankedTensorType)
+ .cloneWithEncoding(unitAttr));
+ EXPECT_EQ(stringEncodingRankedTensorType,
+ cast<RankedTensorType>(noEncodingRankedTensorType)
+ .cloneWithEncoding(stringAttr));
+ EXPECT_EQ(
+ noEncodingRankedTensorType,
+ cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
+ EXPECT_EQ(
+ noEncodingRankedTensorType,
+ cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
}
} // namespace
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType: