
PyTorch Custom Operator Integration #1544

Open · wants to merge 18 commits into main

Conversation

@matthewdouglas (Member) commented on Feb 27, 2025

Overview

This PR introduces the initial scaffolding to integrate PyTorch Custom Operators as the primary mechanism for dispatching to device-specific operator implementations.

As outlined in the related RFC #1545, the intent is that this will supersede the previous backend registration interface that was developed on the multi-backend-refactor branch. The baseline CUDA operators are established in this PR, and the implementation for additional backends is to be ported over to this new interface.

Why Custom Ops?

  • Registering operators with torch.library allows us to take advantage of the existing device dispatch mechanisms in PyTorch.
  • We can treat calls to functionality in our CUDA kernels, or other low-level backend implementations, as opaque for improved torch.compile support.
  • We can provide naive implementations of operators with only PyTorch code as a fallback option.
  • This helps to simplify development for additional backends, while taking an idiomatic, modern PyTorch approach (a minimal registration sketch follows this list).
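
As a minimal sketch of that registration pattern (the demo::scale_rows operator, its schema, and the fallback body below are hypothetical stand-ins; the PR registers its real operators under the bitsandbytes:: namespace):

import torch

# Hypothetical namespace and operator, used only to illustrate the dispatch pattern.
lib = torch.library.Library("demo", "DEF")
lib.define("scale_rows(Tensor A, Tensor stats) -> Tensor")

# Naive pure-PyTorch fallback, registered for the composite key so it is used on
# any device that has no specialized kernel.
def scale_rows_fallback(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    return A * stats.unsqueeze(-1)

lib.impl("scale_rows", scale_rows_fallback, "CompositeExplicitAutograd")

# A backend-specific kernel (e.g. one calling into a CUDA extension) would be
# registered the same way under its dispatch key:
# lib.impl("scale_rows", scale_rows_cuda, "CUDA")

out = torch.ops.demo.scale_rows(torch.randn(4, 8), torch.randn(4))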

Operator Definitions

We broadly categorize operator functionality into three feature groups, though there can be some overlap.

LLM.int8()

Inference requirements

  • int8_vectorwise_quant(A: Tensor, threshold: float = 0.0) -> (Tensor, Tensor, Tensor?)
    • Implements the LLM.int8() quantization algorithm with the specified threshold.
    • Returns an int8 quantized tensor, a float32 tensor containing the scaling stats, and an optional int32 tensor containing a list of column indices with outliers present.
  • int8_linear_dequant(A: Tensor, B: Tensor, row_stats: Tensor, col_stats: Tensor, bias: Tensor?, dtype=torch.float16) -> Tensor (name may change)
    • By default, this is a composition of the two operators below; it can be implemented as one fused operator or as the two separately (see the composition sketch after this list).
      • int8_linear_matmul(A: Tensor, B: Tensor) -> Tensor
        • Performs an 8-bit integer matrix multiplication between two int8 matrices.
        • Returns an int32 matrix: A @ B.T
      • int8_mm_dequant(A: Tensor, row_stats: Tensor, col_stats: Tensor, dtype=torch.float16, bias: Tensor?) -> Tensor
        • Dequantizes the result of a quantized 8-bit matrix multiplication with an optional fused bias.
        • The result is returned in the specified dtype, which is always torch.float16 for the current CUDA implementation.
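
As a rough sketch, these operators compose for an int8 inference matmul as follows (assuming a torch.ops.bitsandbytes namespace and the signatures above; the fp16 decomposition of outlier columns above the threshold is omitted):

from typing import Optional

import torch

def int8_forward(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Row-wise absmax quantization of activations and weights to int8.
    x_q, x_stats, _outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(x, threshold=6.0)
    w_q, w_stats, _ = torch.ops.bitsandbytes.int8_vectorwise_quant(weight)

    # int32 accumulation of x_q @ w_q.T ...
    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(x_q, w_q)

    # ... then dequantize using both sets of stats, with an optional fused bias.
    return torch.ops.bitsandbytes.int8_mm_dequant(out_i32, x_stats, w_stats, dtype=torch.float16, bias=bias)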

Optional

  • int8_vectorwise_dequant(A: Tensor, stats: Tensor)
    • Dequantizes an int8 tensor that was quantized with int8_vectorwise_quant.
    • A default implementation in PyTorch is provided, which should work with any backend (a reference sketch follows this list).
    • This is a utility used by Transformers, Diffusers, PEFT, and others.
  • int8_double_quant(A: Tensor, threshold: float = 0.0)
    • Quantizes the input tensor using the LLM.int8() algorithm across both dimensions.
    • This is only useful for full int8 training (e.g. not LoRA), and as such, we only recommend implementing int8_vectorwise_quant.
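
A naive reference implementation of int8_vectorwise_dequant might look roughly like this (assuming stats holds per-row absmax values and the int8 values were scaled into [-127, 127]; the PR's actual fallback may differ):

import torch

def int8_vectorwise_dequant_ref(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    # A: int8 tensor of shape (rows, cols); stats: float32 per-row absmax of shape (rows,)
    return A.to(torch.float32) * stats.view(-1, 1) / 127.0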

NF4/FP4

Minimal requirements

  • dequantize_4bit(A: Tensor, absmax: Tensor, blocksize: int, quant_type: Literal["nf4", "fp4"], shape: int[], dtype) -> Tensor
    • Dequantizes a packed 4bit tensor into the specified floating point dtype.
    • Note: Unlike bitsandbytes.functional.dequantize_4bit, this operator does not dequantize the absmax tensor. If double quantization is used, the absmax must first be dequantized with dequantize_blockwise.
  • quantize_4bit(A: Tensor, blocksize: int, quant_type: Literal["nf4", "fp4"], quant_storage=torch.uint8) -> (Tensor, Tensor)
    • Quantizes a floating point tensor into a packed 4bit tensor (see the round-trip sketch after this list).
    • Returns a tensor with the quantized data packed into bytes, backed by the specified storage type. The float32 absmax scaling factors are additionally returned.
    • Note: Unlike bitsandbytes.functional.quantize_4bit, this operator does not quantize the absmax tensor. If double quantization is used, the absmax must be quantized separately with quantize_blockwise.
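
A rough round-trip sketch (the torch.ops.bitsandbytes namespace and exact calling convention here are assumptions for illustration):

import torch

W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# Quantize: two 4-bit codes are packed per uint8 byte; absmax holds one float32 scale per block.
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(
    W, blocksize=64, quant_type="nf4", quant_storage=torch.uint8
)

# Dequantize back to fp16; absmax is passed as-is (no nested quantization in this sketch).
W_restored = torch.ops.bitsandbytes.dequantize_4bit(
    packed, absmax, blocksize=64, quant_type="nf4", shape=list(W.shape), dtype=torch.float16
)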

Double quantization (aka compressed_statistics or nested)

  • dequantize_blockwise(A: Tensor, absmax: Tensor, code: Tensor, blocksize: int, dtype) -> Tensor
    • Dequantizes an 8bit tensor that was quantized with quantize_blockwise.
    • Returns the dequantized tensor in the specified dtype.
  • quantize_blockwise(A: Tensor, code: Tensor, blocksize: int) -> (Tensor, Tensor)
    • Quantizes into an 8bit blocked data type defined by code.
    • The blocksize will typically be 256 for usage with NF4/FP4 and optimizers.
    • Returns the quantized tensor in uint8 format, along with float32 absmax (a nested-quantization sketch follows this list).
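
To illustrate the nesting, the float32 absmax produced by quantize_4bit can itself be compressed and later restored (a sketch under the same namespace assumption as above; the codebook shown is a placeholder):

import torch

# Stand-in for the float32 absmax returned by quantize_4bit in the sketch above.
absmax = torch.rand(4096 * 4096 // 64, dtype=torch.float32, device="cuda")
code = torch.linspace(-1.0, 1.0, 256, device="cuda")  # placeholder codebook; the real one is a dynamic 8-bit map

# Double quantization: compress the float32 absmax blockwise down to uint8.
absmax_q, absmax_absmax = torch.ops.bitsandbytes.quantize_blockwise(absmax, code=code, blocksize=256)

# Restore the float32 absmax before handing it to dequantize_4bit (see the notes above).
absmax_fp32 = torch.ops.bitsandbytes.dequantize_blockwise(
    absmax_q, absmax_absmax, code=code, blocksize=256, dtype=torch.float32
)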

Optional

  • gemv_4bit
    • Fast path for bsz=1 inference with 4bit quantization. This operator is subject to some future revision.

Optimizers

Optimizer functionality will be implemented to support the custom operators in a future update.

@matthewdouglas added the high priority (first issues that will be worked on) and cross-platform labels on Feb 27, 2025

Comment on lines +19 to +22
torch.library.define(
    "bitsandbytes::int8_linear_dequant",
    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)

Hey! I'm the main maintainer of custom operators in PyTorch. I'm curious -- why not use the torch.library.custom_op API instead of torch.library.define?

It would look something like:

from typing import Optional

import torch
from torch import Tensor

@torch.library.custom_op("bitsandbytes::int8_linear_dequant", mutates_args=())
def int8_linear_dequant(A: Tensor, B: Tensor, row_stats: Tensor, col_stats: Tensor, bias: Optional[Tensor] = None, dtype: torch.dtype = torch.float16) -> Tensor:
    raise NotImplementedError("")

@int8_linear_dequant.register_fake
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # Fake (meta) implementation: computes only the output shape/dtype for tracing.
    shapeC = (*A.shape[:-1], B.shape[0])
    return torch.empty(shapeC, device=A.device, dtype=dtype)

We generally encourage people to use torch.library.custom_op because the custom ops produced from it are guarded against various footguns compared to torch.library.Library.define.

@matthewdouglas (Member, Author) commented:

Hey! Thanks for the feedback :)

While the custom_op API does look to be convenient, there are two main reasons it was avoided:

  1. I'm not sure we're ready to bump our minimum PyTorch requirement to 2.4.0+. That said, we're not strictly opposed to it.
  2. I've heard from some others that there was significant overhead introduced with the use of custom_op.

I am curious, is it still reasonable to make use of infer_schema, and is that API available in torch < 2.4?


Thanks for the feedback! It's not clear to me if we have fully fixed the performance issues, but I will check.
torch.library.infer_schema is only available in 2.5+, so if your goal is to support older PyTorch versions, you are doing the right thing.
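
For reference, the infer_schema path on torch >= 2.5 would look roughly like this (the exact schema string produced is an assumption):

from typing import Optional

import torch
from torch import Tensor

def int8_linear_dequant(
    A: Tensor,
    B: Tensor,
    row_stats: Tensor,
    col_stats: Tensor,
    bias: Optional[Tensor] = None,
    dtype: torch.dtype = torch.float16,
) -> Tensor: ...

# Derive the schema string from the annotations instead of writing it by hand.
schema = torch.library.infer_schema(int8_linear_dequant, mutates_args=())
# Expected to be roughly:
# "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor"
torch.library.define("bitsandbytes::int8_linear_dequant", schema)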
