Adding ttir.repeat op in MLIR
sdjordjevicTT committed Jan 23, 2025
1 parent 0c49bea commit 3cfbbcc
Showing 41 changed files with 240 additions and 24 deletions.
34 changes: 34 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -803,6 +803,40 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
let hasVerifier = 1;
}

def TTIR_RepeatOp : TTIR_DPSOp<"repeat"> {
let summary = "Repeat operation.";
let description = [{
The `repeat` operation creates a new tensor by replicating the input tensor's elements
along specified dimensions. The number of repetitions for each dimension is defined by
the `repeat_dimensions` attribute, which must have as many entries as the input tensor has dimensions.

Parameters:
- `input`: The input tensor.
- `repeat_dimensions`: The number of times to repeat the tensor along each dimension.

### Example IR Usage:
```mlir
// Input tensor of shape (2, 3)
%input = ... : tensor<2x3xf32>

// Repeat each dimension twice: (2, 3) -> (4, 6)
%output = tensor.empty() : tensor<4x6xf32>
%repeated = "ttir.repeat"(%input, %output) {repeat_dimensions = array<i32: 2, 2>} : (tensor<2x3xf32>, tensor<4x6xf32>) -> tensor<4x6xf32>
```
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
DenseI32ArrayAttr:$repeat_dimensions);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_RepeatInterleaveOp : TTIR_DPSOp<"repeat_interleave"> {
let summary = "Repeat interleave op.";
let description = [{
8 changes: 6 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -774,11 +774,15 @@ def TTNN_ReshapeOp : TTNN_Op<"reshape"> {
def TTNN_RepeatOp : TTNN_Op<"repeat"> {
let summary = "Repeat op.";
let description = [{
Repeat the input tensor according to number of times specified in repeat_dimensions.
Returns a new tensor filled with repetitions of the input tensor, according to the number of times specified in repeat_dims.

Parameters:
- `input_tensor` (ttnn.Tensor): the input tensor.
- `repeat_dims` (list of integers): The number of repetitions along each dimension.
}];

let arguments = (ins AnyRankedTensor:$input,
I32ArrayAttr:$shape);
I32ArrayAttr:$repeat_dims);

let results = (outs AnyRankedTensor:$result);
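For illustration, a minimal sketch of how the op reads in IR (operand names and shapes are hypothetical; the attribute follows the `I32ArrayAttr` form used in the tests below):

```mlir
// Repeat a (2, 3) tensor twice along dimension 1: (2, 3) -> (2, 6).
%result = "ttnn.repeat"(%input) {repeat_dims = [1 : i32, 2 : i32]} : (tensor<2x3xf32>) -> tensor<2x6xf32>
```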

2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -229,7 +229,7 @@ table ReshapeOp {
table RepeatOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
shape: [uint32];
repeat_dims: [uint32];
}

table SliceOp {
18 changes: 18 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -660,6 +660,23 @@ class BroadcastOpConversionPattern
}
};

class RepeatOpConversionPattern : public OpConversionPattern<ttir::RepeatOp> {
using OpConversionPattern<ttir::RepeatOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(ttir::RepeatOp op, ttir::RepeatOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto repeatDimensionsAttr = adaptor.getRepeatDimensionsAttr();

rewriter.replaceOpWithNewOp<ttnn::RepeatOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), rewriter.getI32ArrayAttr(repeatDimensionsAttr));

return success();
}
};
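As a sketch of the rewrite this pattern performs (names and shapes are illustrative, not from the commit): the DPS output operand of `ttir.repeat` is dropped, and the `DenseI32ArrayAttr` is rewrapped as an `I32ArrayAttr`:

```mlir
// Before lowering (TTIR, destination-passing style):
%0 = tensor.empty() : tensor<4x3xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32: 2, 1>} : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>

// After lowering (TTNN):
%1 = "ttnn.repeat"(%arg0) {repeat_dims = [2 : i32, 1 : i32]} : (tensor<2x3xf32>) -> tensor<4x3xf32>
```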

class UnsqueezeOpConversionPattern
: public OpConversionPattern<ttir::UnsqueezeOp> {
public:
@@ -1335,6 +1352,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
EmbeddingBackwardOpConversionPattern,
RepeatOpConversionPattern,
RepeatInterleaveOpConversionPattern,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
44 changes: 44 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -1614,6 +1614,50 @@ ::mlir::LogicalResult mlir::tt::ttir::AllocOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// RepeatOp
//===----------------------------------------------------------------------===//

// RepeatOp verification
::mlir::LogicalResult mlir::tt::ttir::RepeatOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
llvm::ArrayRef<int32_t> repeatDimensions = getRepeatDimensions();

// Input tensor and repeat dimensions attribute must have the same rank
if (inputType.getRank() != static_cast<int64_t>(repeatDimensions.size())) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the number of repeat dimensions "
<< repeatDimensions.size() << ".";
}

// Input and output tensors must have the same rank
if (inputType.getRank() != outputType.getRank()) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the output tensor rank "
<< outputType.getRank() << ".";
}

// Verify output shape based on input shape and repeat dimensions attribute
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::ArrayRef<int64_t> outputShape = outputType.getShape();

for (size_t i = 0; i < inputShape.size(); i++) {
int64_t expectedDimValue = inputShape[i] * repeatDimensions[i];
if (expectedDimValue != outputShape[i]) {
return emitOpError() << "Input tensor shape ("
<< ttmlir::utils::join(inputShape, ",")
<< ") at index " << i
<< " does not repeat to output ("
<< ttmlir::utils::join(outputShape, ",")
<< ") using repeat value " << repeatDimensions[i]
<< ".";
}
}

return success();
}
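Concretely, this verifier accepts IR shaped like the following sketch (shapes chosen purely for illustration), since each output dimension equals the corresponding input dimension times its repeat value:

```mlir
// (2, 3) repeated by (2, 3) gives (4, 9): 2*2 = 4 and 3*3 = 9.
%0 = tensor.empty() : tensor<4x9xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32: 2, 3>} : (tensor<2x3xf32>, tensor<4x9xf32>) -> tensor<4x9xf32>
```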

//===----------------------------------------------------------------------===//
// RepeatInterleaveOp
//===----------------------------------------------------------------------===//
28 changes: 22 additions & 6 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -412,19 +412,35 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() {
::mlir::LogicalResult mlir::tt::ttnn::RepeatOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getResult().getType();
auto repeatDims = getRepeatDims();

auto shape = getShape();
// Verify that the input tensor and repeat_dims attribute have the same rank
if (inputType.getRank() != static_cast<int64_t>(repeatDims.size())) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the number of repeat dimensions "
<< repeatDims.size() << ".";
}

// Verify that the input and output tensors have the same rank
if (inputType.getRank() != outputType.getRank()) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the output tensor rank "
<< outputType.getRank() << ".";
}

// Verify expected output shape
auto inputShape = inputType.getShape();
auto outputShape = outputType.getShape();

for (size_t i = 0; i < shape.size(); i++) {
uint32_t dimValue = mlir::cast<IntegerAttr>(shape[i]).getInt();
for (size_t i = 0; i < repeatDims.size(); i++) {
uint32_t dimValue = mlir::cast<IntegerAttr>(repeatDims[i]).getInt();
if (inputShape[i] * dimValue != outputShape[i]) {
return emitOpError() << "Input tensor shape ("
<< ttmlir::utils::join(inputShape, ",") << ") index "
<< i << " does not repeat to output ("
<< ttmlir::utils::join(inputShape, ",")
<< ") at index " << i
<< " does not repeat to output ("
<< ttmlir::utils::join(outputShape, ",")
<< ") using repeat value " << dimValue;
<< ") using repeat value " << dimValue << ".";
}
}

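For a quick illustration of the first diagnostic (shapes are hypothetical; the negative tests below exercise the same paths):

```mlir
// repeat_dims has one entry but the input tensor has rank 2, so this fails with:
// 'ttnn.repeat' op Input tensor rank 2 doesn't match the number of repeat dimensions 1.
%0 = "ttnn.repeat"(%arg0) {repeat_dims = [2 : i32]} : (tensor<32x32xf32>) -> tensor<32x64xf32>
```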
4 changes: 2 additions & 2 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -874,8 +874,8 @@ ::flatbuffers::Offset<::tt::target::ttnn::RepeatOp>
createRepeatOp(FlatbufferObjectCache &cache, RepeatOp op) {
auto in =
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto shape =
arrayAttrToFlatbuffer<mlir::IntegerAttr, uint32_t>(cache, op.getShape());
auto shape = arrayAttrToFlatbuffer<mlir::IntegerAttr, uint32_t>(
cache, op.getRepeatDims());
auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);

8 changes: 4 additions & 4 deletions runtime/lib/ttnn/operations/data_movement/repeat.cpp
@@ -11,10 +11,10 @@ void run(const ::tt::target::ttnn::RepeatOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
DEBUG_ASSERT(in.is_allocated());
const auto *fbShape = op->shape();
const std::vector<uint32_t> dims(fbShape->begin(), fbShape->end());
::ttnn::Shape shape(dims);
::ttnn::Tensor out = ::ttnn::repeat(in, shape);
const auto *fbShape = op->repeat_dims();
const std::vector<uint32_t> repeatDims(fbShape->begin(), fbShape->end());
::ttnn::Shape repeatDimsShape(repeatDims);
::ttnn::Tensor out = ::ttnn::repeat(in, repeatDimsShape);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::data_movement
@@ -0,0 +1,36 @@
// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
// Negative tests for repeat operation

// Verify that verification fails if the input tensor and the repeat_dimensions attribute don't have the same rank
module {
func.func @repeat_not_valid_repeat_dimension_attribute(%arg0: tensor<32x32xf32>) -> tensor<32x64xf32> {
// CHECK: 'ttir.repeat' op Input tensor rank 2 doesn't match the number of repeat dimensions 1.
%0 = tensor.empty() : tensor<32x64xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32 : 2>} : (tensor<32x32xf32>, tensor<32x64xf32>) -> tensor<32x64xf32>
return %1 : tensor<32x64xf32>
}
}

// -----

// Verify that verification fails if the input and output tensors don't have the same rank
module {
func.func @repeat_not_valid_input_output(%arg0: tensor<32x32xf32>) -> tensor<1x32x64xf32> {
// CHECK: 'ttir.repeat' op Input tensor rank 2 doesn't match the output tensor rank 3.
%0 = tensor.empty() : tensor<1x32x64xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32 : 1, 2>} : (tensor<32x32xf32>, tensor<1x32x64xf32>) -> tensor<1x32x64xf32>
return %1 : tensor<1x32x64xf32>
}
}

// -----

// Verify that verification fails if the output tensor shape doesn't match the input shape scaled by repeat_dimensions
module {
func.func @repeat_not_valid_output_shape(%arg0: tensor<32x32xf32>) -> tensor<32x128xf32> {
// CHECK: 'ttir.repeat' op Input tensor shape (32,32) at index 1 does not repeat to output (32,128) using repeat value 2.
%0 = tensor.empty() : tensor<32x128xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32 : 1, 2>} : (tensor<32x32xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
return %1 : tensor<32x128xf32>
}
}
@@ -0,0 +1,33 @@
// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
// Negative tests for repeat operation

// Verify that verification fails if the input tensor and the repeat_dims attribute don't have the same rank
module {
func.func @repeat_not_valid_repeat_dimension_attribute(%arg0: tensor<32x32xf32>) -> tensor<32x64xf32> {
// CHECK: 'ttnn.repeat' op Input tensor rank 2 doesn't match the number of repeat dimensions 1.
%0 = "ttnn.repeat"(%arg0) {repeat_dims = [2 : i32]} : (tensor<32x32xf32>) -> tensor<32x64xf32>
return %0 : tensor<32x64xf32>
}
}

// -----

// Verify that verification fails if the input and output tensors don't have the same rank
module {
func.func @repeat_not_valid_input_output(%arg0: tensor<32x32xf32>) -> tensor<1x32x64xf32> {
// CHECK: 'ttnn.repeat' op Input tensor rank 2 doesn't match the output tensor rank 3.
%0 = "ttnn.repeat"(%arg0) {repeat_dims = [1 : i32, 2 : i32]} : (tensor<32x32xf32>) -> tensor<1x32x64xf32>
return %0 : tensor<1x32x64xf32>
}
}

// -----

// Verify that verification fails if the output tensor shape doesn't match the input shape scaled by repeat_dims
module {
func.func @repeat_not_valid_output_shape(%arg0: tensor<32x32xf32>) -> tensor<32x128xf32> {
// CHECK: 'ttnn.repeat' op Input tensor shape (32,32) at index 1 does not repeat to output (32,128) using repeat value 2.
%0 = "ttnn.repeat"(%arg0) {repeat_dims = [1 : i32, 2 : i32]} : (tensor<32x32xf32>) -> tensor<32x128xf32>
return %0 : tensor<32x128xf32>
}
}
@@ -0,0 +1,19 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s

module {
func.func @repeat_on_one_dim(%arg0: tensor<1x32x32xf32>) -> tensor<32x32x32xf32> {
// CHECK: "ttnn.repeat"
// CHECK-SAME: repeat_dims = [32 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<32x32x32xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32: 32, 1, 1>} : (tensor<1x32x32xf32>, tensor<32x32x32xf32>) -> tensor<32x32x32xf32>
return %1 : tensor<32x32x32xf32>
}

func.func @repeat_on_all_dims(%arg0: tensor<1x1x32xf32>) -> tensor<32x32x64xf32> {
// CHECK: "ttnn.repeat"
// CHECK-SAME: repeat_dims = [32 : i32, 32 : i32, 2 : i32]
%0 = tensor.empty() : tensor<32x32x64xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32: 32, 32, 2>} : (tensor<1x1x32xf32>, tensor<32x32x64xf32>) -> tensor<32x32x64xf32>
return %1 : tensor<32x32x64xf32>
}
}
@@ -2,7 +2,7 @@
module {
func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 16 : i32, 1 : i32]
// CHECK-SAME: repeat_dims = [1 : i32, 16 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i32 : 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%2 = tensor.empty() : tensor<1x16x32xf32>
@@ -15,7 +15,7 @@ module {
func.func public @main(%arg0: tensor<1xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [512 : i32, 512 : i32]
// CHECK-SAME: repeat_dims = [512 : i32, 512 : i32]
%0 = tensor.empty() : tensor<1x1xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 1 : i32]}> : (tensor<1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>
%2 = tensor.empty() : tensor<512x512xf32>
@@ -30,7 +30,7 @@
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
// CHECK-SAME: repeat_dims = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x23x40x128xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%2 = tensor.empty() : tensor<1x1x1x128xf32>
@@ -46,7 +46,7 @@
module {
func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
// CHECK-SAME: repeat_dims = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x6x2xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
%2 = tensor.empty() : tensor<1x6x1x2xf32>
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTNN/implicit_broadcast.mlir
@@ -17,7 +17,7 @@ func.func @main(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<784x12
// CHECK-NOT: "ttnn.repeat"
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [784 : i32, 1 : i32]
// CHECK-SAME: repeat_dims = [784 : i32, 1 : i32]
// CHECK: %{{[0-9]+}} = "ttnn.add"
%0 = tensor.empty() : tensor<1x128xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32>
@@ -35,7 +35,7 @@

module { func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 16 : i32, 1 : i32]
// CHECK-SAME: repeat_dims = [1 : i32, 16 : i32, 1 : i32]
// CHECK: %{{[0-9]+}} = "ttnn.multiply"(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
// CHECK: %{{[0-9]+}} = "ttnn.bitwise_and"([[VAL0]], %{{[0-9]+}}, %{{[0-9]+}})
%0 = tensor.empty() : tensor<1x16x32xf32>
12 changes: 12 additions & 0 deletions test/ttmlir/Silicon/TTNN/data_movement/repeat/simple_repeat.mlir
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module {
func.func @repeat(%arg0: tensor<1x32x32xf32>) -> tensor<32x32x32xf32> {
// CHECK: "ttnn.repeat"
// CHECK-SAME: repeat_dims = [32 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<32x32x32xf32>
%1 = "ttir.repeat"(%arg0, %0) {repeat_dimensions = array<i32: 32, 1, 1>} : (tensor<1x32x32xf32>, tensor<32x32x32xf32>) -> tensor<32x32x32xf32>
return %1 : tensor<32x32x32xf32>
}
}
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/implicit_broadcast.mlir
@@ -19,7 +19,7 @@ func.func @main(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<784x12
// CHECK-NOT: "ttnn.repeat"
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [784 : i32, 1 : i32]
// CHECK-SAME: repeat_dims = [784 : i32, 1 : i32]
// CHECK: %{{[0-9]+}} = "ttnn.add"
%0 = tensor.empty() : tensor<1x128xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32>