Skip to content

Commit

Permalink
check stride padding of conv1d
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiang Bin committed Nov 11, 2024
1 parent d41b935 commit 4f9b9e9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/flag_gems/ops/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,12 @@ def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
bias is None or weight.shape[0] == bias.shape[0]
), "Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape"

if isinstance(stride, list):
if isinstance(stride, (list, tuple)):
stride_height, stride_width = stride
else:
stride_height = stride_width = stride

if isinstance(padding, list):
if isinstance(padding, (list, tuple)):
padding_height, padding_width = padding
else:
padding_height = padding_width = padding
Expand Down

0 comments on commit 4f9b9e9

Please sign in to comment.