Skip to content

Commit

Permalink
Ensure the padding exists solely in place of the conv padding attr
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Sep 11, 2024
1 parent 32b4ac3 commit 58b7793
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions python/tvm/relay/op/contrib/forge/forge_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,27 @@ def callback(self, pre, post, node_map):

class FuseConvAndPoolPadding(DFPatternCallback):
def __init__(self):
super().__init__(require_type=True)
super().__init__(require_type=True, rewrite_once=True)
self.act = wildcard()
self.weight = wildcard()
self.pad = is_op("nn.pad")(self.act, wildcard())
self.pattern = is_op("nn.conv2d")(self.pad, self.weight) | is_op("nn.max_pool2d")(self.pad)

def callback(self, pre, post, node_map):
act = node_map[self.act][0]
op = node_map[self.pattern][0].op
pad_width = node_map[self.pad][0].attrs.pad_width
conv_pool = node_map[self.pattern][0]
pad = node_map[self.pad][0]
if not all(conv_pool.attrs.padding[i] == 0 for i in range(len(conv_pool.attrs.padding))) \
or not isinstance(pad.args[1], tvm.relay.Constant) or not pad.args[1].data.shape == () \
or not int(pad.args[1].data.numpy()) == 0:
return

pad_width = pad.attrs.pad_width
padding = list(pad_width[-2]) + list(pad_width[-3]) # left, right, top, bottom

op_attrs = {**node_map[self.pattern][0].attrs}
op_attrs = {**conv_pool.attrs}
op_attrs["padding"] = padding

if op.name == "nn.conv2d":
if conv_pool.op.name == "nn.conv2d":
weight = node_map[self.weight][0]
return tvm.relay.op.nn.conv2d(
act,
Expand Down

0 comments on commit 58b7793

Please sign in to comment.