Skip to content

Commit

Permalink
Remove unused getitems in op by op flow [#105]
Browse files Browse the repository at this point in the history
  • Loading branch information
brataTT committed Jan 7, 2025
1 parent 3bdfb8b commit 78ceecd
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/torch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,29 @@ def forward(self, x):
)


def test_unused_output():
class Basic_var_only(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
var, mean = torch.var_mean(x)
return var

class Basic_mean_only(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
var, mean = torch.var_mean(x)
return mean

for module in [Basic_var_only, Basic_mean_only]:
cc = CompilerConfig()
cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP
verify_module(module(), input_shapes=[(256, 256)], compiler_config=cc)


@pytest.mark.parametrize(
("input_range", "input_shapes", "input_type"),
[
Expand Down
27 changes: 27 additions & 0 deletions tests/torch/test_maxpool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
import pytest

import tt_torch
from tt_torch.tools.verify import verify_module
from tt_torch.tools.utils import CompilerConfig, CompileDepth


def test_maxpool2d():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)

cc = CompilerConfig()
cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
verify_module(
Basic(),
inputs=[torch.randn(1, 1, 224, 224).to(torch.bfloat16)],
compiler_config=cc,
)
12 changes: 12 additions & 0 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,15 @@ def compile_op(self, node, *inputs, **kwargs):
):
getitem_nodes = []
graph_node.meta["val"] = node.meta["val"]

for idx, tensor_meta in enumerate(node.meta["tensor_meta"]):
# filter out unused outputs that do not exist in the reduced graph
users = self.gm.graph.find_nodes(
op="call_function", target=operator.getitem
)
if not any(user_node.args == (node, idx) for user_node in users):
continue

getitem_node = graph.call_function(
operator.getitem, args=(graph_node, idx)
)
Expand All @@ -199,6 +207,10 @@ def compile_op(self, node, *inputs, **kwargs):
out = graph.output((graph_node,))
if "tensor_meta" not in node.meta:
raise ValueError(f"Node {node} does not have tensor_meta")
if len(node.users) != len(graph_node.users):
raise ValueError(
f"Op Node {node} has different number of users({len(graph_node.users)}) from global graph({len(node.users)})"
)

op.compilation_status = OpCompilationStatus.CREATED_GRAPH
out.meta["tensor_meta"] = node.meta["tensor_meta"]
Expand Down

0 comments on commit 78ceecd

Please sign in to comment.