Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Add two clone methods about encoding to RankedTensorType. #127709

Merged
merged 1 commit into from
Feb 28, 2025

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Feb 18, 2025

There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:

  • dropEncoding(): Return a clone of this type without the encoding.
  • cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type.

There are clone methods for shape and element type, but not for
encodings. The revision adds two clone method to RankedTensorType:
- dropEncoding(): Return a clone of this type without the encoding.
- cloneWithEncoding(Attribute encoding): Return a clone of this type
  with the given new encoding and the same shape and element type as
  this type.

Signed-off-by: hanhanW <[email protected]>
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Feb 18, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: Han-Chung Wang (hanhanW)

Changes

There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:

  • dropEncoding(): Return a clone of this type without the encoding.
  • cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type.

Full diff: https://github.com/llvm/llvm-project/pull/127709.diff

2 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+11)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+14)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9..af474b3e3ec47 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     RankedTensorType clone(::mlir::Type elementType) {
       return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
     }
+
+    /// Return a clone of this type without the encoding.
+    RankedTensorType dropEncoding() {
+      return RankedTensorType::get(getShape(), getElementType());
+    }
+
+    /// Return a clone of this type with the given new encoding and the same
+    /// shape and element type as this type.
+    RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
+      return RankedTensorType::get(getShape(), getElementType(), encoding);
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index c2900b5aaeeeb..bc4066ed210e8 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
   ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
   view = mlir::cast<TensorWithString>(viewCreated);
   EXPECT_EQ(view.getName(), "bob");
+
+  // Verify encoding clone methods.
+  EXPECT_EQ(unitEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(unitAttr));
+  EXPECT_EQ(stringEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(stringAttr));
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
 }
 
 } // namespace

@llvmbot
Copy link
Member

llvmbot commented Feb 18, 2025

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType:

  • dropEncoding(): Return a clone of this type without the encoding.
  • cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type.

Full diff: https://github.com/llvm/llvm-project/pull/127709.diff

2 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+11)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+14)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9..af474b3e3ec47 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
     RankedTensorType clone(::mlir::Type elementType) {
       return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
     }
+
+    /// Return a clone of this type without the encoding.
+    RankedTensorType dropEncoding() {
+      return RankedTensorType::get(getShape(), getElementType());
+    }
+
+    /// Return a clone of this type with the given new encoding and the same
+    /// shape and element type as this type.
+    RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
+      return RankedTensorType::get(getShape(), getElementType(), encoding);
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index c2900b5aaeeeb..bc4066ed210e8 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
   ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
   view = mlir::cast<TensorWithString>(viewCreated);
   EXPECT_EQ(view.getName(), "bob");
+
+  // Verify encoding clone methods.
+  EXPECT_EQ(unitEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(unitAttr));
+  EXPECT_EQ(stringEncodingRankedTensorType,
+            cast<RankedTensorType>(noEncodingRankedTensorType)
+                .cloneWithEncoding(stringAttr));
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
+  EXPECT_EQ(
+      noEncodingRankedTensorType,
+      cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
 }
 
 } // namespace

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@hanhanW hanhanW merged commit 28d7671 into llvm:main Feb 28, 2025
12 checks passed
@hanhanW hanhanW deleted the add-methods-to-ranked-tensor-type branch February 28, 2025 01:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants