nn.MaxPool2D gets lowered to %result0, %result1 = torch.aten.max_pool2d_with_indices. %result0 is the actual output of the maxpool, and %result1 is the indices of the max elements. If %result1 is not used (and it usually isn't), then compiling the full graph converts this op to a single-result form, which we can lower to TTIR.
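The two forms can be seen directly in PyTorch, independent of tt-torch: the functional op returns a single tensor, while the decomposed aten op returns a (values, indices) pair whose first element is the same pooled output.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)

# Single-result form: just the pooled tensor.
pooled = F.max_pool2d(x, kernel_size=2, stride=2)

# Decomposed aten op: returns (values, indices).
values, indices = torch.ops.aten.max_pool2d_with_indices(x, [2, 2], [2, 2])

assert torch.equal(pooled, values)  # %result0 is the pooled output
# indices (%result1) is an int64 tensor of argmax positions, usually unused
```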
However, when compiling op-by-op we construct the single-op graph to return all results of whatever torch op we're looking at, which causes the maxpool to keep the two-result max_pool2d_with_indices form.
```python
import torch
from torch import nn

# CompilerConfig, CompileDepth, verify_module come from tt-torch's test utilities


def test_maxpool2d():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)

    cc = CompilerConfig()
    cc.compile_depth = CompileDepth.COMPILE_OP_BY_OP  # Comment this line out to compile full graph
    verify_module(Basic(), inputs=[torch.randn(1, 1, 224, 224).to(torch.bfloat16)], compiler_config=cc)
```
This cannot be lowered to TTIR.
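One conceptual fix is to have the single-op graph return only the results that are actually consumed, matching what the full-graph compile produces. A minimal sketch of the idea (not tt-torch code; `OnlyFirstResult` is a hypothetical wrapper):

```python
import torch
from torch import nn


class OnlyFirstResult(nn.Module):
    """Hypothetical wrapper: run the multi-result aten op but expose only
    the first result, so the single-op graph has one output."""

    def forward(self, x):
        values, _indices = torch.ops.aten.max_pool2d_with_indices(x, [2, 2], [2, 2])
        return values  # drop the unused indices result


x = torch.randn(1, 1, 224, 224)
out = OnlyFirstResult()(x)
out.shape  # torch.Size([1, 1, 112, 112]), same as the single-result op
```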
You can use the test above to try it out.
FYI: @AleksKnezevic