[Quantization] Add Quanto backend #10756
base: main
Conversation
@@ -32,7 +42,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
Replacing direct imports here with dummy objects. It's a better guard for cases where imports in the quant configs might break the main diffusers import.
Nice. Very specific and thorough!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
Have we verified this? Last time I checked only weight-quantized models were compatible with `torch.compile`. Cc: @dacorvo.
True, but this should be fixed in pytorch 2.6 (I did not check though).
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
Perhaps a comment to note that only `weights` will be quantized.
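For reference, a rough sketch of how the completed snippet with that comment could look (the `from_pretrained` call with `subfolder="transformer"` is an assumption based on how the other diffusers quantization backends are used, not taken from this diff):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
# Note: only the weights are quantized here; activations stay in torch_dtype.
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```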
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
Oh lovely. Not to digress from this PR, but would it make sense to also do something similar for bitsandbytes and torchao for `from_single_file()`, or not yet?
TorchAO should just work out of the box. We can add a section in the docs page.
For BnB the conversion step in single file is still a bottleneck. We need to figure out how to handle that gracefully.
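For context, a hedged sketch of what the TorchAO `from_single_file()` path could look like, mirroring the Quanto snippet above (the `"int8_weight_only"` quant type and the ability to pass `TorchAoConfig` here are assumptions, untested in this thread):

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
# Assumes TorchAoConfig is accepted by from_single_file the same way QuantoConfig is in this PR.
quantization_config = TorchAoConfig("int8_weight_only")
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
```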
- int4
- int2

### Activations
Let's show an example from the docs as well?
Additionally, we could refer the users to this blog post so that they have a sense of the savings around memory and latency?
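For example, something along these lines could go under the Activations section (a sketch; it assumes `"int8"` is an accepted `activations` value, mirroring the weights options listed above):

```python
from diffusers import QuantoConfig

# Quantize both weights and activations; leaving activations=None (the default)
# quantizes weights only.
quantization_config = QuantoConfig(weights="int8", activations="int8")
```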
"optimum_quanto>=0.2.6", | ||
"gguf>=0.10.0", | ||
"torchao>=0.7.0", | ||
"bitsandbytes>=0.43.3", |
👀
Do we want to perhaps segregate them in terms of different quantization backends? Like when someone does `pip install diffusers[quanto]`, only quanto along with the other required libraries would be installed, and so on?
+1
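For what that suggestion could look like, a hypothetical extras split in `setup.py` (the extra names are illustrative, not what the PR ships; version pins copied from the diff above):

```python
extras = {
    "quanto": ["optimum_quanto>=0.2.6"],
    "gguf": ["gguf>=0.10.0"],
    "torchao": ["torchao>=0.7.0"],
    "bitsandbytes": ["bitsandbytes>=0.43.3"],
}
```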
@@ -1041,7 +1041,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model,
state_dict,
device=param_device,
dtype=torch_dtype,
Why is this going away?
Oh this is a mistake. Thanks for catching.
self,
weights="int8",
activations=None,
modules_to_not_convert: Optional[List] = None,
Let's also include an example of using this in the docs? In https://huggingface.co/blog/quanto-diffusers, we showed one through `exclude="proj_out"`.
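For instance, a hedged sketch of the equivalent with the new config argument (assuming `modules_to_not_convert` takes a list of module names, analogous to the blog post's `exclude="proj_out"`):

```python
from diffusers import QuantoConfig

# Keep proj_out in full precision while quantizing the remaining linear layers.
quantization_config = QuantoConfig(weights="int8", modules_to_not_convert=["proj_out"])
```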
weights="int8",
activations=None,
No strong opinions, but should these be `weight_dtype` and `activation_dtype`?
+1
raise ValueError(
    "You are using `device_map='auto'` on an optimum-quanto quantized model. To automatically compute"
    " the appropriate device map, you should upgrade your `accelerate` library,"
    "`pip install --upgrade accelerate` or install it from source."
)
The `device_map`-related error seems irrelevant here?
Hmm I copied this from transformers. But yeah the error doesn't seem relevant here.
def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list):
    # Quanto imports diffusers internally. These are placed here to avoid circular imports
    from optimum.quanto import QLinear, qfloat8, qint2, qint4, qint8
Quanto does support Conv layers, though. Should we consider them in this PR?
Is there a model that benefits from quantized Conv layers? I recall that it didn't work so great for SD UNets?
Perhaps a note could be nice if we're not adding QConv. In my experiments, I was able to quantize the VAE and obtain decent results.
model = _replace_layers(model, quantization_config, modules_to_not_convert)

return model
Should it also return `has_been_replaced`, similar to how it's done in `bitsandbytes`?
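Roughly what that suggestion could look like (a sketch only; it assumes `_replace_layers` is updated to propagate the flag, which the current diff does not do):

```python
def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list):
    model, has_been_replaced = _replace_layers(model, quantization_config, modules_to_not_convert)
    if not has_been_replaced:
        logger.warning(
            "No linear modules were found in your model; Quanto quantization was not applied. "
            "Please check your model architecture."
        )
    return model, has_been_replaced
```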
return {"weights": "int4"}


class FluxTransformerInt2(FluxTransformerQuantoMixin, unittest.TestCase):
Do we want to add one integration test on a full Flux pipeline?
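Something like the following could serve as that integration test (a sketch; the prompt, step count, and expected shape are placeholders, and the model id is taken from the docs snippet above):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig


def test_quanto_flux_pipeline_integration():
    quantization_config = QuantoConfig(weights="int8")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    image = pipe("a photo of a cat", num_inference_steps=2, output_type="np").images[0]
    # Default FLUX.1-dev resolution; adjust if the pipeline defaults change.
    assert image.shape == (1024, 1024, 3)
```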
def get_dummy_init_kwargs(self):
    return {"weights": "float8"}
We could also test for `activations` and `modules_to_not_convert`. Not sure if those are being tested already.
def get_dummy_init_kwargs(self):
    return {"weights": "int8"}

def test_torch_compile(self):
Why is this only a part of `FluxTransformerInt8`? This could be a part of `FluxTransformerQuantoMixin`, no?
@nightly
@require_big_gpu_with_torch_cuda
Are we adding this marker because of the `torch.compile()` test? Otherwise, the tests seem fine to be executed without the marker.
Thanks for working on this! Looks really good already.
I have left some comments but my major comments are on tests.
I think we need to also add the test to our CI.
Co-authored-by: Sayak Paul <[email protected]>
Nothing much to add from my end as Sayak covered most of it :) Just some comments regarding tests (maybe not required at this point so can be ignored):
- If it is possible to use quanto with sequential cpu offloading, maybe we could add a test (the context for this is that sequential cpu offloading failed with torchao when it was initially added IIRC).
- If training is supported, maybe worth adding a simple test for that.
- As torch.compile is supported, maybe a test that runs forward on a compiled flux module would be cool.
- If possible to use Quanto with device_map, maybe a test for that too.
- It looks like serialization is not supported, going by the `is_serializable` method, but there's also an example showcasing that it is possible. If it is, maybe a test for that too would be nice (a rough sketch follows below).
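On the last point, a save/load round-trip test could look roughly like this (a sketch; it assumes serialization is actually supported, that the helper methods exist on the test mixin, and that the model output exposes `.sample`):

```python
import tempfile

import torch


def test_quanto_serialization_roundtrip(self):
    model = self.get_quantized_model().to("cuda")  # placeholder helper
    inputs = self.get_dummy_inputs()  # placeholder helper

    with torch.no_grad():
        expected = model(**inputs).sample

    with tempfile.TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)
        reloaded = model.__class__.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to("cuda")
        with torch.no_grad():
            actual = reloaded(**inputs).sample

    assert torch.allclose(expected, actual, atol=1e-3)
```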
weights="int8",
activations=None,
+1
"optimum_quanto>=0.2.6", | ||
"gguf>=0.10.0", | ||
"torchao>=0.7.0", | ||
"bitsandbytes>=0.43.3", |
+1
src/diffusers/quantizers/auto.py
"torchao": TorchAoConfig, | ||
"quanto": QuantoConfig, |
(nit): maybe this should go above torchao to keep quantizers in alphabetical order (does not really have to be addressed and we can do it in order of quantization backend addition as well)
with torch.no_grad():
    model(**inputs)
max_memory = torch.cuda.max_memory_allocated()
assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
Maybe I'm missing some context, but how was the `expected_memory_use_in_gb = 5` obtained?
Would maybe write the test as (rough sketch below):
- Record memory when running forward with normal model as X
- Record memory when running with quanto model as Y
- Assert Y < X
My memory didn't serve me right, and going back to the BnB and TorchAO tests, it seems like we don't test memory usage this way there. We only check the memory footprint without a forward pass (so just the model memory and no activations). Maybe we should consider adding a similar test to yours there.
Indeed.
def is_serializable(self, safe_serialization=None):
    return False
In the docs, I see that it is mentioned that we can save/load the quanto quantized models, but here we return `False`. Is this supposed to be `True`?
## Using `torch.compile` with Quanto

Currently the Quanto backend only supports `torch.compile` for `int8` weights and activations.
Not sure where this restriction comes from... did you have issues with float8 or int4 because of the custom kernels?
So I tested with Float8 weight quantization and ran into this error
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(1, 4096, 64), dtype=torch.bfloat16), MarlinF8QBytesTensor(MarlinF8PackedTensor(FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32)), scale=FakeTensor(..., device='cuda:0', size=(1, 3072), dtype=torch.bfloat16), dtype=torch.bfloat16)), **{'bias': Parameter(FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.bfloat16))}):
Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in quanto.gemm_f16f8_marlin.default(FakeTensor(..., device='cuda:0', size=(4096, 64), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(1, 3072), dtype=torch.bfloat16), tensor([...], device='cuda:0', size=(768,), dtype=torch.int32), 8, 4096, 3072, 64)
from user code:
File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 482, in forward
hidden_states = self.x_embedder(hidden_states)
File "/home/dhruv/optimum-quanto/optimum/quanto/nn/qlinear.py", line 50, in forward
return torch.nn.functional.linear(input, self.qweight, bias=self.bias)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
And with INT4 I was running into what looks like a dtype issue, which I don't seem to run into when I'm not using compile
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128))), **{'dropout_p': 0.0, 'is_causal': False}):
Expected query, key, and value to have the same dtype, but got query.dtype: c10::BFloat16 key.dtype: c10::BFloat16 and value.dtype: float instead.
from user code:
File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 529, in forward
encoder_hidden_states, hidden_states = block(
File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 188, in forward
attention_outputs = self.attn(
File "/home/dhruv/diffusers/src/diffusers/models/attention_processor.py", line 595, in forward
return self.processor(
File "/home/dhruv/diffusers/src/diffusers/models/attention_processor.py", line 2328, in __call__
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Did you try out PyTorch nightlies? How's the performance improvement with `torch.compile()` and int8?
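For reference, the int8-plus-compile path being discussed would look roughly like this (a sketch; the `mode` argument is just an illustrative choice, and the `from_pretrained`/`subfolder` usage is assumed from the earlier snippets):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

quantization_config = QuantoConfig(weights="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
).to("cuda")
# Compile the quantized transformer; per the docs snippet above, int8 is the
# only configuration the Quanto backend currently supports with torch.compile.
transformer = torch.compile(transformer, mode="max-autotune")
```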
from accelerate import init_empty_weights


def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list):
This looks like a rewrite of quanto's `quantize` method: why reimplement it here?
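For comparison, this is roughly the optimum-quanto API being referred to, shown only to contrast with the in-tree helper (the tiny stand-in model and the `exclude` pattern are illustrative):

```python
import torch.nn as nn
from optimum.quanto import freeze, qint8, quantize

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))  # stand-in model

# In-place quantization of the supported layers, optionally excluding some by module name.
quantize(model, weights=qint8, exclude=["1"])
# Replace dynamic quantized weights with packed, frozen ones.
freeze(model)
```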
logger = logging.get_logger(__name__)


class QuantoQuantizer(DiffusersQuantizer):
It looks very similar to `quanto.QuantizedDiffusersModel`. I understand that you want to align backends, but maybe you could have used inheritance or composition to avoid rewriting too much code.
I followed the implementation used in transformers. Since we make use of Diffusers `from_pretrained` to load and quantize the models, we would have to create a Quantizer object to do this. I'm not sure using `QuantizedDiffusersModel` would apply here, since we would need the methods in this class to be defined for the other checks used in `from_pretrained`.
Additionally, we skip replacing Conv and Layernorms in this implementation, since quantizing Conv layers in diffusion models isn't typical, and quantizing the Layernorms currently leads to an error when trying to load the state dict with Diffusers:
huggingface/optimum-quanto#371
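A hedged sketch of the linear-only replacement described above (module and helper names are illustrative, not the exact PR code; it assumes `QLinear.from_module` from optimum-quanto as the way to build the quantized module):

```python
import torch.nn as nn
from optimum.quanto import QLinear


def _replace_linear_layers(model, weights_qtype, activations_qtype, modules_to_not_convert):
    for name, module in model.named_children():
        if name in modules_to_not_convert:
            continue
        if isinstance(module, nn.Linear):
            qlinear = QLinear.from_module(module, weights=weights_qtype, activations=activations_qtype)
            setattr(model, name, qlinear)
        else:
            # Recurse into children; Conv and LayerNorm modules are deliberately left untouched.
            _replace_linear_layers(module, weights_qtype, activations_qtype, modules_to_not_convert)
    return model
```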
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.