Implement out kwarg overloads for custom ops
matthewdouglas committed Mar 7, 2025
1 parent 23eba7a commit 2d5b2cc
Showing 5 changed files with 342 additions and 126 deletions.
129 changes: 107 additions & 22 deletions bitsandbytes/_ops.py
@@ -44,30 +44,42 @@ def _(
    bias: Optional[torch.Tensor] = None,
    dtype=torch.float16,
) -> torch.Tensor:
-    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
+    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
+    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
    return out


-# Define op
-# TODO: mutable output arg as alias of return can be challenging;
-# consider a separate op without aliased return:
-#   int8_linear_matmul_out(
-#     Tensor A, Tensor B, Tensor out, ScalarType dtype=int32
-#   ) -> ()
-# return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
torch.library.define(
    "bitsandbytes::int8_linear_matmul",
-    "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32) -> Tensor",
+    "(Tensor A, Tensor B) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_matmul")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
+def _(A: torch.Tensor, B: torch.Tensor):
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
    shapeC = (*A.shape[:-1], B.shape[0])
-    if out is None:
-        return torch.empty(shapeC, device=A.device, dtype=dtype)
-    return out
+    return torch.empty(shapeC, device=A.device, dtype=torch.int32)
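
A quick note on what the fake registration above provides: with the library imported, the op also runs on meta tensors, so shapes and dtypes propagate without any real data. A minimal sketch (hypothetical sizes, not part of this diff):

```python
import torch
import bitsandbytes  # assumed import; registers the bitsandbytes::* ops and fakes

A = torch.empty(4, 8, dtype=torch.int8, device="meta")
B = torch.empty(16, 8, dtype=torch.int8, device="meta")

# Dispatch on meta tensors routes to the fake above, which computes
# shapeC = (*A.shape[:-1], B.shape[0]) without doing any arithmetic.
C = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
assert C.shape == (4, 16) and C.dtype == torch.int32 and C.device.type == "meta"
```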


+# More info on `out` overloads:
+# https://github.com/pytorch/pytorch/issues/125044
+torch.library.define(
+    "bitsandbytes::int8_linear_matmul.out",
+    "(Tensor A, Tensor B, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_matmul.out")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+    shapeC = (*A.shape[:-1], B.shape[0])
+
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
+    torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
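
In these schemas, `Tensor! out` marks the argument as mutated in place, and `-> ()` means the overload returns nothing (see the pytorch/pytorch#125044 link above). A minimal sketch of the calling convention this enables, assuming the CPU kernels added later in this commit are registered (tensors are hypothetical):

```python
import torch
import bitsandbytes  # assumed import; registers the ops and their CPU kernels

A = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
B = torch.randint(-128, 127, (16, 8), dtype=torch.int8)

# Functional overload: allocates a fresh int32 result on every call.
C = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)

# `.out` overload: writes into one preallocated buffer, e.g. reused
# across iterations of a hot loop, and returns None.
buf = torch.empty(4, 16, dtype=torch.int32)
for _ in range(3):
    torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out=buf)

assert torch.equal(C, buf)
```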


torch.library.define(
@@ -107,7 +119,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):

torch.library.define(
    "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
)


@@ -117,7 +129,6 @@ def _(
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype=torch.float16,
-    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    torch._check(A.dtype == torch.int32, lambda: "A must be int32")
@@ -126,17 +137,13 @@

torch.library.define(
    "bitsandbytes::int8_double_quant",
-    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
+    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_double_quant")
def _(
    A: torch.Tensor,
-    col_stats: Optional[torch.Tensor] = None,
-    row_stats: Optional[torch.Tensor] = None,
-    out_col: Optional[torch.Tensor] = None,
-    out_row: Optional[torch.Tensor] = None,
    threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    out_row = torch.empty_like(A, dtype=torch.int8)
@@ -156,12 +163,39 @@ def _(

@register_fake("bitsandbytes::dequantize_4bit")
def _(
-    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    return torch.empty(shape, dtype=dtype, device=A.device)


+torch.library.define(
+    "bitsandbytes::dequantize_4bit.out",
+    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_4bit.out")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    torch._check_is_size(blocksize)
+    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")


torch.library.define(
    "bitsandbytes::quantize_4bit",
    "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
@@ -194,6 +228,23 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
    return torch.empty_like(A, dtype=dtype)


+torch.library.define(
+    "bitsandbytes::dequantize_blockwise.out",
+    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_blockwise.out")
+def _(
+    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+):
+    torch._check_is_size(blocksize)
+    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
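
To illustrate the fake above (hypothetical sizes; meta device, so only metadata is touched): blockwise dequantization preserves the input shape, which lets every `torch._check` run without any real data:

```python
import torch
import bitsandbytes  # assumed import; registers the ops and fakes

A = torch.empty(64, 64, dtype=torch.uint8, device="meta")     # 4096 quantized values
absmax = torch.empty(16, dtype=torch.float32, device="meta")  # 4096 / blocksize 256
code = torch.empty(256, dtype=torch.float32, device="meta")   # 8-bit codebook
out = torch.empty(64, 64, dtype=torch.float16, device="meta")

# Satisfies every check: uint8 input, matching shape, device, and dtype.
# A mismatched buffer would raise at this call, even during tracing.
torch.ops.bitsandbytes.dequantize_blockwise.out(A, absmax, code, 256, torch.float16, out=out)
```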


torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")


@@ -229,3 +280,37 @@ def _(
    )
    shape = (*A.shape[:-1], shapeB[0])
    return torch.empty(shape, device=A.device, dtype=A.dtype)


+torch.library.define(
+    "bitsandbytes::gemv_4bit.out",
+    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::gemv_4bit.out")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+    out: torch.Tensor,
+) -> None:
+    torch._check_is_size(blocksize)
+    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
+    torch._check(
+        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+    )
+    torch._check(
+        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+    )
+    torch._check(
+        out.shape == (*A.shape[:-1], shapeB[0]),
+        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+    )
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
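
The same trick works for the vector path (hypothetical sizes; `B` holds 4-bit values packed two per uint8, so a logical shapeB of [256, 128] is backed by a 256×64 uint8 tensor):

```python
import torch
import bitsandbytes  # assumed import; registers the ops and fakes

A = torch.empty(1, 1, 128, dtype=torch.float16, device="meta")  # a single input vector
B = torch.empty(256, 64, dtype=torch.uint8, device="meta")      # packed 4-bit weights
absmax = torch.empty(512, dtype=torch.float32, device="meta")   # 256*128 / blocksize 64
code = torch.empty(16, dtype=torch.float32, device="meta")      # 4-bit codebook
out = torch.empty(1, 1, 256, dtype=torch.float16, device="meta")

# All checks pass: A is a vector with leading dims of 1, dtypes are supported,
# and out matches (*A.shape[:-1], shapeB[0]) == (1, 1, 256).
torch.ops.bitsandbytes.gemv_4bit.out(A, B, [256, 128], absmax, code, 64, out=out)
```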
12 changes: 11 additions & 1 deletion bitsandbytes/backends/cpu/ops.py
@@ -10,7 +10,17 @@


@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
-def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
+def _(A: torch.Tensor, B: torch.Tensor):
+    return _int8_linear_matmul_impl(A, B)
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "cpu")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+    torch._check(out.dtype == torch.int32)
+    _int8_linear_matmul_impl(A, B, out)
+
+
+def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
    # Naive implementation: perform matmul in fp32
    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
    if out is not None:
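
The remainder of the helper is collapsed in this view. Registering both overloads as thin wrappers over one `_int8_linear_matmul_impl` keeps the two paths from drifting apart; below is a self-contained re-sketch of the helper, with the collapsed tail filled in as an assumption (not the diff's actual text):

```python
from typing import Optional

import torch

def _int8_linear_matmul_impl_sketch(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
    # Naive CPU path: compute in fp32, then narrow to int32.
    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
    if out is not None:
        # Assumed tail: fill the caller's buffer in place.
        return out.copy_(result)
    return result
```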