Commit

fix test (#165)
MARD1NO authored Aug 13, 2024
1 parent 1effdfe commit 88df9bb
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tests/test_special_ops.py
@@ -220,21 +220,22 @@ def test_accuracy_resolve_neg(shape, dtype):
 @pytest.mark.parametrize("hiddensize", [128])
 @pytest.mark.parametrize("topk", [5])
 @pytest.mark.parametrize("largest", [True, False])
-@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_topk(
     batch_size,
     hiddensize,
     topk,
     largest,
     dtype,
 ):
-    x = torch.arange(batch_size * hiddensize, dtype=dtype, device="cuda").reshape(
-        batch_size, hiddensize
-    )
-    indices = torch.randperm(x.size(1))
-    x = x[:, indices]
+    # Note(Zhengzekang): arange is used here to generate an array of unique values.
+    # Because fp16 and bf16 have lower precision, values may round to the same
+    # number, which makes the topk indices ambiguous and the index comparison fail.
+    x = torch.arange(hiddensize, dtype=dtype, device="cuda")
+    x = x.repeat(batch_size).reshape(batch_size, hiddensize)
+
+    # Shuffle each row with its own permutation.
+    for bsz in range(batch_size):
+        col_indices = torch.randperm(x.size(1))
+        x[bsz, :] = x[bsz, col_indices]
 
     ref_value, ref_index = torch.topk(x, topk, largest=largest)
 
     with flag_gems.use_gems():
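For context, a minimal standalone sketch of the failure mode this commit fixes, assuming only stock PyTorch on CPU (FLOAT_DTYPES and flag_gems come from the test suite and are not needed here). fp16 represents integers exactly only up to 2048, so a long arange collapses neighboring values into duplicates, and torch.topk may then return a different but equally valid index than the reference; restricting each row's values to [0, hiddensize) avoids the collisions:

import torch

# fp16 represents integers exactly only up to 2048; past that the spacing
# between representable values exceeds 1, so a long arange yields duplicates.
x = torch.arange(4096, dtype=torch.float16)
assert x.to(torch.float32).unique().numel() < 4096  # duplicates appeared

# With duplicate values, torch.topk can pick a different but equally valid
# index than the reference, so an exact index comparison becomes flaky.

# The fixed test keeps every row inside [0, hiddensize), where all integers
# are exactly representable in fp16/bf16, and shuffles each row independently.
hiddensize, batch_size = 128, 4  # batch_size here is illustrative
row = torch.arange(hiddensize, dtype=torch.float16)
x = row.repeat(batch_size).reshape(batch_size, hiddensize)
for b in range(batch_size):
    x[b] = x[b, torch.randperm(hiddensize)]
    assert x[b].to(torch.float32).unique().numel() == hiddensize  # row stays unique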
