
[Quantization] Add Quanto backend #10756

Open
DN6 wants to merge 26 commits into main
Conversation

DN6 (Collaborator) commented on Feb 10, 2025:

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@DN6 requested review from a-r-r-o-w and sayakpaul and removed the request for a-r-r-o-w on February 10, 2025 07:26
@@ -32,7 +42,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
DN6 (Collaborator, Author):

Replacing direct imports here with dummy objects. It's a better guard for cases where imports in the quant configs might break the main diffusers import.
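
For context, the dummy-object guard looks roughly like this (a sketch only; the placeholder class and availability flag below are illustrative, not the actual diffusers internals):

```python
# Illustrative sketch: if the optional backend import fails, a placeholder is
# exported instead of the real config, so `import diffusers` keeps working and
# the error surfaces only when the config is actually instantiated.
try:
    from optimum.quanto import qint8  # noqa: F401
    _quanto_available = True
except Exception:  # broad on purpose: a broken install should not break the top-level import
    _quanto_available = False


class _DummyQuantoConfig:
    """Stand-in that raises a helpful error only when instantiated."""

    def __init__(self, *args, **kwargs):
        raise ImportError("QuantoConfig requires `optimum-quanto`: pip install optimum-quanto")


if _quanto_available:
    from diffusers.quantizers.quantization_config import QuantoConfig  # real config added in this PR
else:
    QuantoConfig = _DummyQuantoConfig
```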

Member:

Nice. Very specific and thorough!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
Member:

Have we verified this? Last time I checked only weight-quantized models were compatible with torch.compile. Cc: @dacorvo.

Comment:

True, but this should be fixed in pytorch 2.6 (I did not check though).

from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
Member:

Perhaps a comment to note that only weights will be quantized.
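
Something along these lines would make it explicit (a sketch built from the snippet above; the `subfolder`/`from_pretrained` usage follows the standard diffusers loading path):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
# Only the weights are quantized to float8 here; activations stay in
# `torch_dtype` unless `activations=...` is also passed to QuantoConfig.
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```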

Comment on lines 56 to 58
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
Member:

Oh lovely. Not to digress from this PR, but would it make sense to also do something similar for bitsandbytes and torchao with from_single_file(), or not yet?

DN6 (Collaborator, Author):

TorchAO should just work out of the box. We can add a section in the docs page.

For BnB the conversion step in single file is still a bottleneck. We need to figure out how to handle that gracefully.
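
If TorchAO does work out of the box, the docs section could mirror the Quanto example above. An untested sketch (the `"int8wo"` quant type is just one of the options TorchAoConfig accepts):

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = TorchAoConfig("int8wo")  # int8 weight-only quantization
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
```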

- int4
- int2

### Activations
Member:

Let's show an example from the docs as well?

Additionally, we could refer the users to this blog post so that they have a sense of the savings around memory and latency?
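
For instance, something like this could go under the Activations section (illustrative; the dtype strings follow the `weights`/`activations` arguments added in this PR):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

# Quantize both weights and activations to int8 (weight-only quantization is
# what you get when `activations` is left as None).
quantization_config = QuantoConfig(weights="int8", activations="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```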

Comment on lines +130 to +133
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
Member:

👀

Do we want to perhaps segregate them in terms of different quantization backends? Like when someone does pip install diffusers[quanto], only quanto along with the other required libraries would be installed and so on?
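
Rough sketch of the idea (not the actual setup.py; version pins copied from the hunk above):

```python
# Per-backend extras so `pip install diffusers[quanto]` only pulls what that
# backend needs, instead of bundling every quantization library together.
extras = {
    "quanto": ["optimum_quanto>=0.2.6", "accelerate"],
    "gguf": ["gguf>=0.10.0", "accelerate"],
    "torchao": ["torchao>=0.7.0", "accelerate"],
    "bitsandbytes": ["bitsandbytes>=0.43.3", "accelerate"],
}
```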

Member:

+1

@@ -1041,7 +1041,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model,
state_dict,
device=param_device,
dtype=torch_dtype,
Member:

Why is this going away?

DN6 (Collaborator, Author):

Oh this is a mistake. Thanks for catching.

self,
weights="int8",
activations=None,
modules_to_not_convert: Optional[List] = None,
Member:

Let's also include an example of using this in the docs? In https://huggingface.co/blog/quanto-diffusers, we showed one through exclude="proj_out".
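
A docs example along those lines might look like this (a sketch; `"proj_out"` mirrors the module excluded in the blog post):

```python
from diffusers import FluxTransformer2DModel, QuantoConfig

# Keep the output projection in full precision while quantizing the rest.
quantization_config = QuantoConfig(weights="int8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
)
```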

Comment on lines 698 to 699
weights="int8",
activations=None,
Member:

No strong opinions but should these be weight_dtype and activation_dtype?

Member:

+1

Comment on lines 108 to 112
raise ValueError(
"You are using `device_map='auto'` on an optimum-quanto quantized model. To automatically compute"
" the appropriate device map, you should upgrade your `accelerate` library,"
"`pip install --upgrade accelerate` or install it from source."
)
Member:

device_map related error seems irrelevant here?

DN6 (Collaborator, Author):

Hmm I copied this from transformers. But yeah the error doesn't seem relevant here.


def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list):
# Quanto imports diffusers internally. These are placed here to avoid circular imports
from optimum.quanto import QLinear, qfloat8, qint2, qint4, qint8
Member:

Quanto does support Conv layers, though. Should we consider them in this PR?

DN6 (Collaborator, Author):

Is there a model that benefits from quantized Conv layers? I recall that it didn't work so great for SD UNets?

Member:

Perhaps a note could be nice if we're not adding QConv. In my experiments, I was able to quantize the VAE and obtain decent results.
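
For reference, quantizing the VAE with optimum-quanto directly looks roughly like this (a sketch of that kind of experiment, not code from this PR):

```python
import torch
from diffusers import AutoencoderKL
from optimum.quanto import freeze, qint8, quantize

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
)
quantize(vae, weights=qint8)  # covers Conv layers too, unlike the PR's Linear-only path
freeze(vae)
```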


model = _replace_layers(model, quantization_config, modules_to_not_convert)

return model
Member:

Should it also return has_been_replaced similar to how it's done in bitsandbytes?
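
Roughly the bitsandbytes-style shape (illustrative sketch, not the PR's actual helper):

```python
import torch


def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert, has_been_replaced=False):
    # Walk the module tree, swap eligible layers, and bubble up whether anything
    # changed so the caller can warn when no layer was converted (as bitsandbytes does).
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear) and name not in modules_to_not_convert:
            # ... replace `module` with a QLinear built from `quantization_config` ...
            has_been_replaced = True
        _, has_been_replaced = _replace_with_quanto_layers(
            module, quantization_config, modules_to_not_convert, has_been_replaced
        )
    return model, has_been_replaced
```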

return {"weights": "int4"}


class FluxTransformerInt2(FluxTransformerQuantoMixin, unittest.TestCase):
Member:

Do we want to add one integration test on a full Flux pipeline?
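
Very rough shape of such a test (model repo taken from the PR; the prompt, step count, and assertion are placeholders):

```python
import unittest

import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig


class FluxPipelineQuantoIntegrationTest(unittest.TestCase):
    def test_quantized_transformer_in_pipeline(self):
        quantization_config = QuantoConfig(weights="int8")
        transformer = FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
        ).to("cuda")
        image = pipe("a photo of a cat", num_inference_steps=2, output_type="np").images[0]
        # Placeholder check; a real test would compare against expected slices.
        self.assertEqual(image.shape[-1], 3)
```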

Comment on lines 33 to 34
def get_dummy_init_kwargs(self):
return {"weights": "float8"}
Member:

We could also test for activations and modules_to_not_convert. Not sure if those are being tested already.

def get_dummy_init_kwargs(self):
return {"weights": "int8"}

def test_torch_compile(self):
Member:

Why is this only part of FluxTransformerInt8? It could be part of FluxTransformerQuantoMixin, no?



@nightly
@require_big_gpu_with_torch_cuda
Member:

Are we adding this marker because of the torch.compile() test? Otherwise, the tests seem fine to be executed without the marker.

sayakpaul (Member) left a comment:

Thanks for working on this! Looks really good already.

I have left some comments but my major comments are on tests.

I think we need to also add the test to our CI.

a-r-r-o-w (Member) left a comment:

Nothing much to add from my end as Sayak covered most of it :) Just some comments regarding tests (maybe not required at this point so can be ignored)

If it is possible to use quanto with sequential cpu offloading, maybe we could add a test (the context for this is that sequential cpu offloading failed with torchao when it was initially added IIRC).

If training is supported, maybe worth adding a simple test for that.

As torch.compile is supported, maybe a test that runs forward on a compiled flux module would be cool.

If possible to use Quanto with device_map, maybe a test for that too.

Looking at the is_serializable method, it seems serialization is not supported, but there's also an example showcasing that it is possible. If it is, maybe a test for that would be nice too.
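
On the serialization point, a round-trip test could look roughly like this (the mixin helper names and the `.sample` comparison are assumptions for illustration):

```python
import tempfile

import torch


class FluxTransformerQuantoSerializationTests:  # would hang off the Quanto test mixin
    def test_serialization_roundtrip(self):
        model = self.get_quantized_model()  # hypothetical helper
        inputs = self.get_dummy_inputs()  # hypothetical helper

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            reloaded = model.__class__.from_pretrained(tmp_dir, torch_dtype=torch.bfloat16)
            with torch.no_grad():
                out_a = model(**inputs).sample
                out_b = reloaded(**inputs).sample

        assert torch.allclose(out_a, out_b, atol=1e-3)
```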

Comment on lines 698 to 699
weights="int8",
activations=None,
Member:

+1

Comment on lines +130 to +133
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
Member:

+1

Comment on lines 48 to 49
"torchao": TorchAoConfig,
"quanto": QuantoConfig,
Member:

(nit): maybe this should go above torchao to keep the quantizers in alphabetical order (does not really have to be addressed; we can also keep them in order of quantization backend addition)

with torch.no_grad():
model(**inputs)
max_memory = torch.cuda.max_memory_allocated()
assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
Member:

Maybe I'm missing some context, but how was the expected_memory_use_in_gb = 5 obtained?

Would maybe write the test as (see the sketch below):

- Record memory when running forward with the normal model as X
- Record memory when running with the quanto model as Y
- Assert Y < X
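
A sketch of what that comparison might look like inside the test class (the model/input helper names are hypothetical):

```python
import torch


class QuantoMemoryComparisonSketch:  # would hang off the existing test mixin
    def test_quantized_model_uses_less_memory(self):
        inputs = self.get_dummy_inputs()  # hypothetical helper

        torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            self.get_unquantized_model().to("cuda")(**inputs)  # hypothetical helper
        baseline = torch.cuda.max_memory_allocated()

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            self.get_quantized_model().to("cuda")(**inputs)  # hypothetical helper
        quantized = torch.cuda.max_memory_allocated()

        assert quantized < baseline
```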

Member:

My memory didn't serve me right: going back to the BnB and TorchAO tests, it seems we don't test memory usage this way there. We only check the memory footprint without a forward pass (so just the model memory, no activations). Maybe we should consider adding a similar test to yours there.

Member:

Indeed.

Comment on lines 166 to 167
def is_serializable(self, safe_serialization=None):
return False
Member:

In the docs, I see that it is mentioned that we can save/load the quanto quantized models, but here we return False. Is this supposed to be True?


## Using `torch.compile` with Quanto

Currently the Quanto backend only supports `torch.compile` for `int8` weights and activations.
Comment:

Not sure where this restriction comes from... did you have issues with float8 or int4 because of the custom kernels?

DN6 (Collaborator, Author):

So I tested with Float8 weight quantization and ran into this error:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(1, 4096, 64), dtype=torch.bfloat16), MarlinF8QBytesTensor(MarlinF8PackedTensor(FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32)), scale=FakeTensor(..., device='cuda:0', size=(1, 3072), dtype=torch.bfloat16), dtype=torch.bfloat16)), **{'bias': Parameter(FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.bfloat16))}):
Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in quanto.gemm_f16f8_marlin.default(FakeTensor(..., device='cuda:0', size=(4096, 64), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(1, 3072), dtype=torch.bfloat16), tensor([...], device='cuda:0', size=(768,), dtype=torch.int32), 8, 4096, 3072, 64)

from user code:
   File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 482, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/home/dhruv/optimum-quanto/optimum/quanto/nn/qlinear.py", line 50, in forward
    return torch.nn.functional.linear(input, self.qweight, bias=self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

And with INT4 I was running into what looks like a dtype issue, which I don't seem to run into when I'm not using compile:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128))), **{'dropout_p': 0.0, 'is_causal': False}):
Expected query, key, and value to have the same dtype, but got query.dtype: c10::BFloat16 key.dtype: c10::BFloat16 and value.dtype: float instead.

from user code:
   File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 529, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 188, in forward
    attention_outputs = self.attn(
  File "/home/dhruv/diffusers/src/diffusers/models/attention_processor.py", line 595, in forward
    return self.processor(
  File "/home/dhruv/diffusers/src/diffusers/models/attention_processor.py", line 2328, in __call__
    hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Member:

Did you try out PyTorch nightlies? How's the performance improvement with torch.compile() and int8?

from accelerate import init_empty_weights


def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list):
Comment:

This looks like a rewrite of the quanto quantize method: why reimplement it here?

logger = logging.get_logger(__name__)


class QuantoQuantizer(DiffusersQuantizer):
dacorvo commented on Feb 11, 2025:

It looks very similar to quanto.QuantizedDiffusersModel. I understand that you want to align backends, but maybe you could have used inheritance or composition to avoid rewriting too much code.

DN6 (Collaborator, Author):

I followed the implementation used in transformers. Since we make use of Diffusers from_pretrained to load and quantize the models, we would have to create a Quantizer object to do this.

I'm not sure using QuantizedDiffusersModel would apply here since we would need the methods in this class to be defined for the other checks used in from_pretrained.

Additionally, we skip replacing Conv and Layernorms in this implementation since quantizing Conv layers in Diffusion models isn't typical. And quantizing the Layernorms currently leads to an error when trying to load the state dict with Diffusers:
huggingface/optimum-quanto#371

@yiyixuxu added the roadmap (Add to current release roadmap) label on Feb 20, 2025