From 6e28e4b809eef8209a935f239243b5d3985f3329 Mon Sep 17 00:00:00 2001 From: brataTT Date: Tue, 7 Jan 2025 15:18:01 +0000 Subject: [PATCH] Remove unused getitems in op by op flow [#105] --- tests/torch/test_basic.py | 23 +++++++++++++++++++++++ tests/torch/test_maxpool2d.py | 27 +++++++++++++++++++++++++++ tt_torch/dynamo/backend.py | 14 ++++++++++++++ 3 files changed, 64 insertions(+) create mode 100644 tests/torch/test_maxpool2d.py diff --git a/tests/torch/test_basic.py b/tests/torch/test_basic.py index 835310f4..527bfde1 100644 --- a/tests/torch/test_basic.py +++ b/tests/torch/test_basic.py @@ -412,6 +412,29 @@ def forward(self, x): ) +def test_unused_output(): + class Basic_var_only(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + var, mean = torch.var_mean(x) + return var + + class Basic_mean_only(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + var, mean = torch.var_mean(x) + return mean + + for module in [Basic_var_only, Basic_mean_only]: + cc = CompilerConfig() + cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP + verify_module(module(), input_shapes=[(256, 256)], compiler_config=cc) + + @pytest.mark.parametrize( ("input_range", "input_shapes", "input_type"), [ diff --git a/tests/torch/test_maxpool2d.py b/tests/torch/test_maxpool2d.py new file mode 100644 index 00000000..636fb744 --- /dev/null +++ b/tests/torch/test_maxpool2d.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +import torch +from torch import nn +import pytest + +import tt_torch +from tt_torch.tools.verify import verify_module +from tt_torch.tools.utils import CompilerConfig, CompileDepth + + +def test_maxpool2d(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2) + + cc = CompilerConfig() + cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP + verify_module( + Basic(), + inputs=[torch.randn(1, 1, 224, 224).to(torch.bfloat16)], + compiler_config=cc, + ) diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 49cf5b95..0261ecd1 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -188,7 +188,17 @@ def compile_op(self, node, *inputs, **kwargs): ): getitem_nodes = [] graph_node.meta["val"] = node.meta["val"] + for idx, tensor_meta in enumerate(node.meta["tensor_meta"]): + # filter out unused outputs that do not exist in the reduced graph + users = self.gm.graph.find_nodes( + op="call_function", target=operator.getitem + ) + if not any( + user_node.args == (node, idx) for user_node in users + ): + continue + getitem_node = graph.call_function( operator.getitem, args=(graph_node, idx) ) @@ -199,6 +209,10 @@ def compile_op(self, node, *inputs, **kwargs): out = graph.output((graph_node,)) if "tensor_meta" not in node.meta: raise ValueError(f"Node {node} does not have tensor_meta") + if len(node.users) != len(graph_node.users): + raise ValueError( + f"Op Node {node} has different number of users({len(graph_node.users)}) from global graph({len(node.users)})" + ) op.compilation_status = OpCompilationStatus.CREATED_GRAPH out.meta["tensor_meta"] = node.meta["tensor_meta"]