Fix potential out-of-bound access in int8_mm.py (#1751)
* fix potential out-of-bound access

* remove unused EVEN_K

* refactor fix with triton.heuristics

* restore EVEN_K as an input

* fix typo

* fix another typo

* ruff reformatted
mark14wu authored Feb 25, 2025
1 parent 38e36de commit 98c4e2e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchao/prototype/quantized_training/int8_mm.py
@@ -54,6 +54,7 @@


@triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"])
+@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0})
@triton.jit
def _scaled_int8_mm_kernel(
A_ptr,
@@ -176,7 +177,6 @@ def scaled_int8_mm_cuda(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tens
*A.stride(),
*B.stride(),
*C.stride(),
-EVEN_K=K % 2 == 0,
COL_SCALE_SCALAR=col_scale.numel() == 1,
)
return C
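
The kernel body is not part of this diff, so the following is only a sketch of why the old flag could be unsafe. It assumes the common Triton matmul pattern in which `EVEN_K=True` lets the K loop load full `BLOCK_K`-wide tiles without a bounds mask. The helper `out_of_bounds_reads` is hypothetical and written in plain Python, with no Triton or GPU required. Since `BLOCK_K` is picked by `triton.autotune`, the host-side call cannot know it at launch time, which is presumably why the check moved into `triton.heuristics`.

```python
def out_of_bounds_reads(K: int, block_k: int, even_k: bool) -> int:
    """Count K-offsets a tiled loop would read past the end of the K dimension.

    Hypothetical model (an assumption, not the actual kernel): when even_k
    is True every BLOCK_K-wide tile is loaded unmasked; otherwise offsets
    >= K are masked off and never touched.
    """
    num_tiles = -(-K // block_k)  # ceiling division
    oob = 0
    for tile in range(num_tiles):
        for offset in range(tile * block_k, (tile + 1) * block_k):
            if offset >= K and even_k:
                oob += 1  # unmasked load past the end of the row
    return oob


K, BLOCK_K = 96, 64  # K is even but not a multiple of BLOCK_K

old_flag = K % 2 == 0        # host-side check removed by this commit
new_flag = K % BLOCK_K == 0  # check now derived per autotune config

print(old_flag, out_of_bounds_reads(K, BLOCK_K, old_flag))  # True 32
print(new_flag, out_of_bounds_reads(K, BLOCK_K, new_flag))  # False 0
```

Under this model, any even `K` that is not a multiple of the selected `BLOCK_K` would skip masking with the old flag, while the heuristic enables the unmasked path only when the tiles cover `K` exactly.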
