Support AWQ models (#1049)
* Support AWQ models

* Add tests

* Add dependencies

* Fix tests

* enable awq export only if ov supports it

* fix style (#2)

* disable awq and gptq install for old torch (#3)

* fix style

* disable autogptq and autoawq install for old transformers testing

* separate common quant models patching and gptq (#4)

* disable windows install (#5)

* separate common quant models patching and gptq

* disable awq windows

* skip logits check for quantized models (#6)

* fix test after rebase

* fix testing condition for 2024.6 and unpatch in case of failure

* Fix qwen2-vl tests (#1084)

* Skip private model loading test for external contributors (#1082)

* Fix reshaping unet if timestep is 0d tensor (#1083)

* Disable kv cache compression for fp vlm (#1080)

* add necessary packages in test_openvino_full

* fix code style after rebase (#7)

---------

Co-authored-by: eaidova <[email protected]>
Co-authored-by: Nikita Savelyev <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
4 people authored Dec 23, 2024
1 parent ea6fa42 commit 29b2ac9
Showing 7 changed files with 265 additions and 155 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/test_openvino.yml
@@ -50,6 +50,11 @@ jobs:
name: Install specific dependencies and versions required for older transformers
run: |
pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.* diffusers==0.30.* transformers_stream_generator
- if: ${{ matrix.transformers-version == 'latest' && matrix.test-pattern == '*modeling*'}}
name: Install auto-gptq, autoawq
run: |
pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
- if: ${{ matrix.test-pattern == '*modeling*' }}
name: Uninstall NNCF
5 changes: 5 additions & 0 deletions .github/workflows/test_openvino_full.yml
@@ -78,6 +78,11 @@ jobs:
if: ${{ matrix.transformers-version != 'latest' }}
run: pip install transformers==${{ matrix.transformers-version }}

- if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
name: Install auto-gptq, autoawq
run: |
pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
- name: Pip freeze
run: pip freeze

5 changes: 5 additions & 0 deletions .github/workflows/test_openvino_slow.yml
@@ -49,6 +49,11 @@ jobs:
name: Install specific dependencies and versions required for older transformers
run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.* diffusers==0.30.* transformers_stream_generator

- if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
name: Install auto-gptq, autoawq
run: |
pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
- name: Pip freeze
run: pip freeze

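The three workflow changes above install auto-gptq and autoawq from the PyTorch CPU wheel index only when the latest transformers version is tested, restricted further to the modeling test pattern or to non-Windows runners depending on the workflow. Tests that touch GPTQ/AWQ checkpoints therefore have to skip cleanly when those backends are missing. A minimal sketch of such a guard; the decorator names and module checks here are illustrative, not part of this commit:

# Illustrative guard (not part of this commit): skip AWQ/GPTQ tests on runners
# where the optional backends were not installed by the steps above.
import importlib.util

import pytest

# autoawq is imported as "awq", auto-gptq as "auto_gptq".
_HAS_AUTOAWQ = importlib.util.find_spec("awq") is not None
_HAS_AUTOGPTQ = importlib.util.find_spec("auto_gptq") is not None

require_autoawq = pytest.mark.skipif(not _HAS_AUTOAWQ, reason="autoawq is not installed")
require_autogptq = pytest.mark.skipif(not _HAS_AUTOGPTQ, reason="auto-gptq is not installed")


@require_autoawq
def test_awq_checkpoint_export():
    # A real test would export an AWQ checkpoint to OpenVINO here; this sketch
    # only shows where the guard attaches.
    assert _HAS_AUTOAWQ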
289 changes: 150 additions & 139 deletions optimum/exporters/openvino/__main__.py
@@ -232,6 +232,7 @@ def main_export(
)

do_gptq_patching = False
do_quant_patching = False
custom_architecture = False
patch_16bit = False
loading_kwargs = model_loading_kwargs or {}
@@ -247,7 +248,11 @@
trust_remote_code=trust_remote_code,
)
quantization_config = getattr(config, "quantization_config", None)
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
supported_quant_methods = ["gptq"]
if is_openvino_version(">=", "2024.6.0"):
supported_quant_methods.append("awq")
do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
do_gptq_patching = do_quant_patching and quantization_config["quant_method"] == "gptq"
model_type = config.model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
@@ -296,7 +301,6 @@
if (
dtype is None
and framework == "pt"
and not do_gptq_patching
and (
task.startswith("text-generation")
or getattr(config, "model_type", None) in MULTI_MODAL_TEXT_GENERATION_MODELS
@@ -315,28 +319,28 @@
patch_16bit = True
loading_kwargs["torch_dtype"] = dtype
# Patch the modules to export of GPTQ models w/o GPU
if do_gptq_patching:
torch.set_default_dtype(torch.float32)
if do_quant_patching:
orig_cuda_check = torch.cuda.is_available
torch.cuda.is_available = lambda: True

from optimum.gptq import GPTQQuantizer
if do_gptq_patching:
from optimum.gptq import GPTQQuantizer

orig_post_init_model = GPTQQuantizer.post_init_model
orig_post_init_model = GPTQQuantizer.post_init_model

def post_init_model(self, model):
from auto_gptq import exllama_set_max_input_length
def post_init_model(self, model):
from auto_gptq import exllama_set_max_input_length

class StoreAttr(object):
pass
class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
model = exllama_set_max_input_length(model, self.max_input_length)
return model
model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
model = exllama_set_max_input_length(model, self.max_input_length)
return model

GPTQQuantizer.post_init_model = post_init_model
GPTQQuantizer.post_init_model = post_init_model
elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
dtype = deduce_diffusers_dtype(
model_name_or_path,
@@ -351,143 +355,150 @@ class StoreAttr(object):
loading_kwargs["torch_dtype"] = dtype
patch_16bit = True

if library_name == "open_clip":
model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
else:
model = TasksManager.get_model_from_task(
task,
model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
framework=framework,
device=device,
library_name=library_name,
**loading_kwargs,
)
try:
if library_name == "open_clip":
model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
else:
model = TasksManager.get_model_from_task(
task,
model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
framework=framework,
device=device,
library_name=library_name,
**loading_kwargs,
)

needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None

if needs_pad_token_id:
if pad_token_id is not None:
model.config.pad_token_id = pad_token_id
else:
tok = AutoTokenizer.from_pretrained(model_name_or_path)
pad_token_id = getattr(tok, "pad_token_id", None)
if pad_token_id is None:
raise ValueError(
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
)
model.config.pad_token_id = pad_token_id
if needs_pad_token_id:
if pad_token_id is not None:
model.config.pad_token_id = pad_token_id
else:
tok = AutoTokenizer.from_pretrained(model_name_or_path)
pad_token_id = getattr(tok, "pad_token_id", None)
if pad_token_id is None:
raise ValueError(
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
)
model.config.pad_token_id = pad_token_id

if hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")

if (
not custom_architecture
and library_name != "diffusers"
and task + "-with-past"
in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name)
):
# Make -with-past the default if --task was not explicitly specified
if original_task == "auto":
task = task + "-with-past"
if hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
logger.info(
f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
f" if needed, please pass `--task {task}-with-past` to export using the past key values."
model_type = model.config.model_type.replace("_", "-")

if (
not custom_architecture
and library_name != "diffusers"
and task + "-with-past"
in TasksManager.get_supported_tasks_for_model_type(
model_type, exporter="openvino", library_name=library_name
)
):
# Make -with-past the default if --task was not explicitly specified
if original_task == "auto":
task = task + "-with-past"
else:
logger.info(
f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
f" if needed, please pass `--task {task}-with-past` to export using the past key values."
)

if original_task == "auto":
synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
if synonyms_for_task:
synonyms_for_task = ", ".join(synonyms_for_task)
possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
else:
possible_synonyms = ""
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
if original_task == "auto":
synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
if synonyms_for_task:
synonyms_for_task = ", ".join(synonyms_for_task)
possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
else:
possible_synonyms = ""
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")

preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)

submodel_paths = export_from_model(
model=model,
output=output,
task=task,
ov_config=ov_config,
stateful=stateful,
model_kwargs=model_kwargs,
custom_export_configs=custom_export_configs,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
device=device,
trust_remote_code=trust_remote_code,
patch_16bit_model=patch_16bit,
**kwargs_shapes,
)
submodel_paths = export_from_model(
model=model,
output=output,
task=task,
ov_config=ov_config,
stateful=stateful,
model_kwargs=model_kwargs,
custom_export_configs=custom_export_configs,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
device=device,
trust_remote_code=trust_remote_code,
patch_16bit_model=patch_16bit,
**kwargs_shapes,
)

if convert_tokenizer:
maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)

clear_class_registry()
del model
gc.collect()

for submodel_path in submodel_paths:
submodel_path = Path(output) / submodel_path
submodel = core.read_model(submodel_path)

quantization_config = None
if ov_config is None:
num_parameters = 0
for op in submodel.get_ops():
if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
num_parameters += reduce(operator.mul, op.shape, 1)
del op
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
quantization_config = {"bits": 8, "sym": False}
logger.info("The model weights will be quantized to int8_asym.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 "
"requires nncf. Please install it with `pip install nncf`"
)
break
else:
quantization_config = ov_config.quantization_config
if quantization_config is None:
del submodel
gc.collect()
continue
if convert_tokenizer:
maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)

if not is_nncf_available():
raise ImportError("Quantization of the weights requires nncf, please install it with `pip install nncf`")
clear_class_registry()
del model
gc.collect()

from optimum.intel.openvino.quantization import _weight_only_quantization
for submodel_path in submodel_paths:
submodel_path = Path(output) / submodel_path
submodel = core.read_model(submodel_path)

quantization_config = None
if ov_config is None:
num_parameters = 0
for op in submodel.get_ops():
if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
num_parameters += reduce(operator.mul, op.shape, 1)
del op
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
quantization_config = {"bits": 8, "sym": False}
logger.info("The model weights will be quantized to int8_asym.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 "
"requires nncf. Please install it with `pip install nncf`"
)
break
else:
quantization_config = ov_config.quantization_config
if quantization_config is None:
del submodel
gc.collect()
continue

if not is_nncf_available():
raise ImportError(
"Quantization of the weights requires nncf, please install it with `pip install nncf`"
)

_weight_only_quantization(submodel, quantization_config)
compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
del submodel
gc.collect()
from optimum.intel.openvino.quantization import _weight_only_quantization

submodel_path.unlink()
submodel_path.with_suffix(".bin").unlink()
compressed_submodel_path.rename(submodel_path)
compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))
_weight_only_quantization(submodel, quantization_config)
compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
del submodel
gc.collect()

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
GPTQQuantizer.post_init_model = orig_post_init_model
submodel_path.unlink()
submodel_path.with_suffix(".bin").unlink()
compressed_submodel_path.rename(submodel_path)
compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))

finally:
# Unpatch modules after quantized model export
if do_quant_patching:
torch.cuda.is_available = orig_cuda_check
if do_gptq_patching:
GPTQQuantizer.post_init_model = orig_post_init_model


def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
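Condensed, the control flow this diff adds around model loading and export in main_export looks roughly like the sketch below. It is a simplification, not the literal implementation: loading and export are stubbed out behind a callback, the GPTQQuantizer.post_init_model patch is omitted, and the import path of is_openvino_version is assumed to be optimum.intel.utils.import_utils.

# Simplified sketch of the quantized-checkpoint patching flow added above.
import torch

from optimum.intel.utils.import_utils import is_openvino_version  # assumed import path


def export_with_quant_patching(config, load_and_export):
    quantization_config = getattr(config, "quantization_config", None)

    # AWQ export is only enabled when the installed OpenVINO supports it.
    supported_quant_methods = ["gptq"]
    if is_openvino_version(">=", "2024.6.0"):
        supported_quant_methods.append("awq")

    do_quant_patching = bool(quantization_config) and quantization_config["quant_method"] in supported_quant_methods

    if do_quant_patching:
        # GPTQ/AWQ checkpoints expect a CUDA device during loading, so the
        # export pretends one is available on CPU-only hosts.
        orig_cuda_check = torch.cuda.is_available
        torch.cuda.is_available = lambda: True

    try:
        return load_and_export()
    finally:
        # Unpatch even if loading or export fails partway through.
        if do_quant_patching:
            torch.cuda.is_available = orig_cuda_check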
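Separately, the per-submodel loop above (pre-existing code, now moved inside the try block) applies default int8_asym weight compression only when no OVConfig is given and the submodel is large enough. The size check in isolation, as a sketch reusing the same openvino calls as the diff (on older releases Type lives under openvino.runtime):

# Sketch of the size check used for the default int8_asym weight compression.
import operator
from functools import reduce

from openvino import Type  # openvino.runtime.Type on older releases


def exceeds_uncompressed_size(submodel, max_uncompressed_size: int) -> bool:
    num_parameters = 0
    for op in submodel.get_ops():
        # Only floating-point constants count towards the model size.
        if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
            num_parameters += reduce(operator.mul, op.shape, 1)
    return num_parameters >= max_uncompressed_size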
6 changes: 5 additions & 1 deletion optimum/exporters/openvino/convert.py
@@ -456,7 +456,11 @@ def ts_patched_forward(*args, **kwargs):
from openvino.frontend.pytorch.patch_model import unpatch_model

unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
model.to(torch.float32)
for m in model.modules():
if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
):
m.float()

return export_pytorch_via_onnx(
model,
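In convert.py, the blanket model.to(torch.float32) is replaced by a selective upcast: only modules that own fp16 or bf16 parameters or buffers are cast, so quantized modules whose own tensors are not half precision are left as they are. A standalone sketch of the same idea, assuming a regular torch.nn.Module:

# Sketch of the selective float32 upcast introduced in convert.py.
import torch


def upcast_half_precision_modules(model: torch.nn.Module) -> torch.nn.Module:
    half_dtypes = (torch.float16, torch.bfloat16)
    for module in model.modules():
        owns_half_param = any(p.dtype in half_dtypes for p in module.parameters(recurse=False))
        owns_half_buffer = any(b.dtype in half_dtypes for b in module.buffers(recurse=False))
        if owns_half_param or owns_half_buffer:
            # Module.float() casts floating-point parameters/buffers to float32
            # (recursing into submodules); integer tensors, e.g. packed
            # quantized weights, are not affected.
            module.float()
    return model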