Add MLIR lowering for cumsum op (#1186)
### Ticket
Fix #891, Fix #503

### Problem description
The cumsum op had no MLIR lowering, so graphs containing it failed with "Found Unsupported operations while lowering from TTForge to TTIR in forward graph".

### What's changed
- Added a mapping for cumsum in lower_to_mlir
- Added TargetType handling for the Int64 type in lower_to_mlir
- Added E2E tests for the cumsum op
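
For reference, cumsum computes a running total along a single dimension. A minimal PyTorch sketch of the semantics being lowered (values are illustrative):

```python
import torch

# Each element becomes the running total of everything up to and
# including it along the chosen dimension.
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
print(torch.cumsum(x, dim=1))
# tensor([[ 1.,  3.,  6.],
#         [ 4.,  9., 15.]])
```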
ashokkumarkannan1 authored Feb 8, 2025
1 parent fb10a81 commit 7286f5b
Showing 5 changed files with 29 additions and 20 deletions.
4 changes: 4 additions & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -55,6 +55,7 @@ enum class TargetType
 {
     SourceType,
     UInt32,
+    Int64,
 };
 
 struct AttributeRemap
@@ -104,6 +105,7 @@ class AttributeMapper
     {
         add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));
         add_op_mapping("reduce_avg", "dim", AttributeRemap("dim_arg"));
+        add_op_mapping("cumsum", "dim", AttributeRemap(std::nullopt, TargetType::Int64));
 
         // Add more default mappings here
     }
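
A rough Python model of the AttributeMapper mechanism above (an illustrative sketch, not the real implementation): each (op, attribute) pair can rename the attribute, retag its target type, or both.

```python
# Illustrative model of the C++ AttributeMapper; None means "keep as is".
ATTR_REMAP = {
    ("repeat_interleave", "repeats"): (None, "UInt32"),  # retype only
    ("reduce_avg", "dim"): ("dim_arg", None),            # rename only
    ("cumsum", "dim"): (None, "Int64"),                  # retype only (this PR)
}

def remap(op_name, attr_name):
    new_name, target = ATTR_REMAP.get((op_name, attr_name), (None, None))
    return new_name or attr_name, target or "SourceType"

print(remap("cumsum", "dim"))  # ('dim', 'Int64')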
@@ -234,6 +236,7 @@ class MLIRGenerator
             case TargetType::UInt32:
                 TT_ASSERT(std::get<int>(value) >= 0, "Value must be an >= 0 for conversion to uint32");
                 return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
+            case TargetType::Int64: return builder_.getI64IntegerAttr(static_cast<int64_t>(std::get<int>(value)));
             default:
                 // If type not handled, throw an exception
                 throw std::runtime_error("Unhandled target type conversion");
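
The new branch uses getI64IntegerAttr, i.e. a signed 64-bit attribute. That matters because the UInt32 path asserts the value is non-negative, while a cumsum `dim` can in principle be negative (e.g. -1 for the last axis) — presumably why a signed target type was added. A toy contrast of the two conversion rules:

```python
# Toy model of the two conversions above: the unsigned path rejects
# negatives, the signed 64-bit path preserves them.
def to_uint32(v: int) -> int:
    assert v >= 0, "Value must be >= 0 for conversion to uint32"
    return v

def to_int64(v: int) -> int:
    return v  # signed target: negative dims like -1 survive

print(to_int64(-1))  # -1; to_uint32(-1) would raise AssertionError
```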
@@ -608,6 +611,7 @@ class MLIRGenerator
         lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;
         lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
         lowering_handler_map["cosine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CosOp>;
+        lowering_handler_map["cumsum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CumSumOp>;
         lowering_handler_map["embedding_bw"] =
             &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingBackwardOp>;
         lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
6 changes: 2 additions & 4 deletions forge/forge/op/eltwise_unary.py
@@ -453,7 +453,7 @@ def Tanh(name: str, operandA: Tensor) -> Tensor:
     return op("tanh", name, operandA).get_tensor()
 
 
-def CumSum(name: str, operandA: Tensor, axis: int, exclusive: bool = False) -> Tensor:
+def CumSum(name: str, operandA: Tensor, dim: int) -> Tensor:
 
     """
     Cumulative sum operation.
@@ -483,9 +483,7 @@ def CumSum(name: str, operandA: Tensor, axis: int, exclusive: bool = False) -> T
         Forge tensor
     """
 
-    assert not exclusive, "Currently not supported"
-
-    return op("cumsum", name, operandA, axis=axis, exclusive=exclusive).get_tensor()
+    return op("cumsum", name, operandA, dim=dim).get_tensor()
 
 
 def LogicalNot(name: str, operandA: Tensor) -> Tensor:
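
Behavior-wise, nothing is lost by dropping `exclusive`: the old code asserted it was always False. A quick PyTorch illustration of inclusive cumsum (what CumSum computes) versus the never-supported exclusive variant:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])

# Inclusive cumsum -- what forge's CumSum computes:
print(torch.cumsum(x, dim=0))      # tensor([1., 3., 6.])

# Exclusive cumsum -- what the removed flag would have meant; the old
# code asserted exclusive=False, so this was never reachable anyway:
print(torch.cumsum(x, dim=0) - x)  # tensor([0., 1., 3.])
```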
9 changes: 4 additions & 5 deletions forge/forge/op/eval/forge/cumulativesum.py
@@ -19,17 +19,16 @@
 
 class CumulativeSum(PyEltwiseUnaryOp):
     @classmethod
-    def create(cls, axis, exclusive=False):
+    def create(cls, dim):
         self = cls("cumsum")
-        self.axis = axis
-        self.exclusive = exclusive
+        self.dim = dim[0]
         return self
 
     def eval(self, tensors):
         assert len(tensors) == 1, "Cumulative Sum should have one input"
         shape = tensors[0].shape
         original_types = [o.dtype for o in tensors]
-        ret = torch.cumsum(tensors[0], dim=self.axis)
+        ret = torch.cumsum(tensors[0], dim=self.dim)
 
         if ret.dtype != original_types[0]:
             ret = ret.type(original_types[0])
@@ -44,7 +43,7 @@ def shape(self, tensor_shapes):
     def backward(self, ac, operand, inputs, output, grad):
         assert len(inputs) == 1, "Cumulative Sum should have one input"
         assert operand == 0, "Invalid operand index"
-        dim = self.axis
+        dim = self.dim
         assert dim == 0, "Unsupported dim different then 0 for cumulative sum backward pass"
         if dim == 0:
             return ac.op(Nop.create(), (grad,))
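
Note that `backward` forwards the gradient unchanged (a Nop) and only for dim 0. The general gradient of cumsum is a reversed cumulative sum of the upstream gradient, which collapses to the identity only when the dimension has size 1; a quick PyTorch check of the general rule:

```python
import torch

# d(cumsum(x))/dx applied to an upstream grad g is a reversed
# cumulative sum of g along the same dim.
x = torch.randn(5, requires_grad=True)
g = torch.randn(5)
torch.cumsum(x, dim=0).backward(g)

expected = torch.flip(torch.cumsum(torch.flip(g, [0]), dim=0), [0])
assert torch.allclose(x.grad, expected)

# When the dim has size 1, the reversed cumsum is the identity, which
# is when forwarding the grad through a Nop is exact.
```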
11 changes: 1 addition & 10 deletions forge/forge/tvm_to_python.py
@@ -661,19 +661,10 @@ def populate_cumsum_args(graph, nid, compiler_cfg):
     axis = node["attrs"]["axis"][0][0]
     args.append(
         (
-            "axis",
+            "dim",
             f"{axis}",
         )
     )
 
-    exclusive = node["attrs"]["exclusive"][0][0]
-    args.append(
-        (
-            "exclusive",
-            f"{exclusive}",
-        )
-    )
-
     return args
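
For illustration, a self-contained mock of the updated helper (the nested-list node layout follows the node["attrs"]["axis"][0][0] access pattern above; the attribute values are made up):

```python
def populate_cumsum_args(node):
    # Mirrors the updated helper: only TVM's "axis" attribute is read,
    # and it is emitted under the new name "dim".
    args = []
    axis = node["attrs"]["axis"][0][0]
    args.append(("dim", f"{axis}"))
    return args

# Hypothetical TVM node dict; "exclusive" is present but now ignored.
node = {"attrs": {"axis": [["1"]], "exclusive": [["0"]]}}
print(populate_cumsum_args(node))  # [('dim', '1')]
```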
19 changes: 18 additions & 1 deletion forge/test/mlir/operators/eltwise_unary/test_eltwise_unary.py
@@ -348,10 +348,27 @@ def forward(self, x):
 @pytest.mark.parametrize(
     "shape, dim",
     [
         ((56), 0),
         ((1, 128), 1),
+        pytest.param(
+            (1, 64, 76),
+            2,
+            marks=pytest.mark.xfail(reason="ValueError: Data mismatch -> AutomaticValueChecker (compare_with_golden)"),
+        ),
+        pytest.param(
+            (1, 64, 76, 96),
+            3,
+            marks=pytest.mark.xfail(reason="ValueError: Data mismatch -> AutomaticValueChecker (compare_with_golden)"),
+        ),
+        pytest.param(
+            (1, 64, 86, 100, 120),
+            4,
+            marks=pytest.mark.xfail(
+                reason=" RuntimeError: (dim >= 0 && dim <= 3),info: dim should be 0 - 3, but got: 4"
+            ),
+        ),
     ],
 )
-@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
 @pytest.mark.push
 def test_cumsum(shape, dim):
     class CumSum(nn.Module):
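
The diff truncates the test body; a minimal sketch of what the module under test plausibly looks like (the forward body is an assumption — only the class name appears in the diff):

```python
import torch
import torch.nn as nn

class CumSum(nn.Module):
    # Assumed shape of the module under test: a thin wrapper around
    # torch.cumsum over the parametrized dim.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.cumsum(x, dim=self.dim)

model = CumSum(dim=1)
print(model(torch.ones(1, 128)).shape)  # torch.Size([1, 128])
```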
