diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml
index db35324a9..755236ebe 100644
--- a/.github/workflows/test_openvino.yml
+++ b/.github/workflows/test_openvino.yml
@@ -50,6 +50,11 @@ jobs:
         name: Install specific dependencies and versions required for older transformers
         run: |
           pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.* diffusers==0.30.* transformers_stream_generator
+
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.test-pattern == '*modeling*'}}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
 
       - if: ${{ matrix.test-pattern == '*modeling*' }}
         name: Uninstall NNCF
diff --git a/.github/workflows/test_openvino_full.yml b/.github/workflows/test_openvino_full.yml
index 914035b75..ab852c653 100644
--- a/.github/workflows/test_openvino_full.yml
+++ b/.github/workflows/test_openvino_full.yml
@@ -78,6 +78,11 @@ jobs:
         if: ${{ matrix.transformers-version != 'latest' }}
         run: pip install transformers==${{ matrix.transformers-version }}
 
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
+
       - name: Pip freeze
         run: pip freeze
 
diff --git a/.github/workflows/test_openvino_slow.yml b/.github/workflows/test_openvino_slow.yml
index 8c3d9b2d3..a4e8a046b 100644
--- a/.github/workflows/test_openvino_slow.yml
+++ b/.github/workflows/test_openvino_slow.yml
@@ -49,6 +49,11 @@ jobs:
         name: Install specific dependencies and versions required for older transformers
         run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.*, diffusers==0.30.* transformers_stream_generator
 
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
+
       - name: Pip freeze
         run: pip freeze
 
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 592cd85a4..09f2eaa10 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -232,6 +232,7 @@ def main_export(
     )
 
     do_gptq_patching = False
+    do_quant_patching = False
     custom_architecture = False
     patch_16bit = False
     loading_kwargs = model_loading_kwargs or {}
@@ -247,7 +248,11 @@ def main_export(
             trust_remote_code=trust_remote_code,
         )
         quantization_config = getattr(config, "quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+        supported_quant_methods = ["gptq"]
+        if is_openvino_version(">=", "2024.6.0"):
+            supported_quant_methods.append("awq")
+        do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
+        do_gptq_patching = do_quant_patching and quantization_config["quant_method"] == "gptq"
         model_type = config.model_type.replace("_", "-")
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             custom_architecture = True
@@ -296,7 +301,6 @@ def main_export(
         if (
             dtype is None
             and framework == "pt"
-            and not do_gptq_patching
             and (
                 task.startswith("text-generation")
                 or getattr(config, "model_type", None) in MULTI_MODAL_TEXT_GENERATION_MODELS
@@ -315,28 +319,28 @@ def main_export(
                     patch_16bit = True
                 loading_kwargs["torch_dtype"] = dtype
         # Patch the modules to export of GPTQ models w/o GPU
-        if do_gptq_patching:
-            torch.set_default_dtype(torch.float32)
+        if do_quant_patching:
             orig_cuda_check = torch.cuda.is_available
             torch.cuda.is_available = lambda: True
 
-            from optimum.gptq import GPTQQuantizer
+            if do_gptq_patching:
+                from optimum.gptq import GPTQQuantizer
 
-            orig_post_init_model = GPTQQuantizer.post_init_model
+                orig_post_init_model = GPTQQuantizer.post_init_model
 
-            def post_init_model(self, model):
-                from auto_gptq import exllama_set_max_input_length
+                def post_init_model(self, model):
+                    from auto_gptq import exllama_set_max_input_length
 
-                class StoreAttr(object):
-                    pass
+                    class StoreAttr(object):
+                        pass
 
-                model.quantize_config = StoreAttr()
-                model.quantize_config.desc_act = self.desc_act
-                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                    model = exllama_set_max_input_length(model, self.max_input_length)
-                return model
+                    model.quantize_config = StoreAttr()
+                    model.quantize_config.desc_act = self.desc_act
+                    if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                        model = exllama_set_max_input_length(model, self.max_input_length)
+                    return model
 
-            GPTQQuantizer.post_init_model = post_init_model
+                GPTQQuantizer.post_init_model = post_init_model
     elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
         dtype = deduce_diffusers_dtype(
             model_name_or_path,
@@ -351,143 +355,150 @@ class StoreAttr(object):
             loading_kwargs["torch_dtype"] = dtype
             patch_16bit = True
 
-    if library_name == "open_clip":
-        model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
-    else:
-        model = TasksManager.get_model_from_task(
-            task,
-            model_name_or_path,
-            subfolder=subfolder,
-            revision=revision,
-            cache_dir=cache_dir,
-            token=token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            framework=framework,
-            device=device,
-            library_name=library_name,
-            **loading_kwargs,
-        )
+    try:
+        if library_name == "open_clip":
+            model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        else:
+            model = TasksManager.get_model_from_task(
+                task,
+                model_name_or_path,
+                subfolder=subfolder,
+                revision=revision,
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                force_download=force_download,
+                trust_remote_code=trust_remote_code,
+                framework=framework,
+                device=device,
+                library_name=library_name,
+                **loading_kwargs,
+            )
 
-    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
+        needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
 
-    if needs_pad_token_id:
-        if pad_token_id is not None:
-            model.config.pad_token_id = pad_token_id
-        else:
-            tok = AutoTokenizer.from_pretrained(model_name_or_path)
-            pad_token_id = getattr(tok, "pad_token_id", None)
-            if pad_token_id is None:
-                raise ValueError(
-                    "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
-                )
-            model.config.pad_token_id = pad_token_id
+        if needs_pad_token_id:
+            if pad_token_id is not None:
+                model.config.pad_token_id = pad_token_id
+            else:
+                tok = AutoTokenizer.from_pretrained(model_name_or_path)
+                pad_token_id = getattr(tok, "pad_token_id", None)
+                if pad_token_id is None:
+                    raise ValueError(
+                        "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
+                    )
+                model.config.pad_token_id = pad_token_id
 
-    if hasattr(model.config, "export_model_type"):
-        model_type = model.config.export_model_type.replace("_", "-")
-    else:
-        model_type = model.config.model_type.replace("_", "-")
-
-    if (
-        not custom_architecture
-        and library_name != "diffusers"
-        and task + "-with-past"
-        in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name)
-    ):
-        # Make -with-past the default if --task was not explicitely specified
-        if original_task == "auto":
-            task = task + "-with-past"
+        if hasattr(model.config, "export_model_type"):
+            model_type = model.config.export_model_type.replace("_", "-")
         else:
-            logger.info(
-                f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
-                f" if needed, please pass `--task {task}-with-past` to export using the past key values."
+            model_type = model.config.model_type.replace("_", "-")
+
+        if (
+            not custom_architecture
+            and library_name != "diffusers"
+            and task + "-with-past"
+            in TasksManager.get_supported_tasks_for_model_type(
+                model_type, exporter="openvino", library_name=library_name
             )
+        ):
+            # Make -with-past the default if --task was not explicitely specified
+            if original_task == "auto":
+                task = task + "-with-past"
+            else:
+                logger.info(
+                    f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
+                    f" if needed, please pass `--task {task}-with-past` to export using the past key values."
+                )
 
-    if original_task == "auto":
-        synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
-        if synonyms_for_task:
-            synonyms_for_task = ", ".join(synonyms_for_task)
-            possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
-        else:
-            possible_synonyms = ""
-        logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
+        if original_task == "auto":
+            synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
+            if synonyms_for_task:
+                synonyms_for_task = ", ".join(synonyms_for_task)
+                possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
+            else:
+                possible_synonyms = ""
+            logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
 
-    preprocessors = maybe_load_preprocessors(
-        model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
-    )
+        preprocessors = maybe_load_preprocessors(
+            model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
+        )
 
-    submodel_paths = export_from_model(
-        model=model,
-        output=output,
-        task=task,
-        ov_config=ov_config,
-        stateful=stateful,
-        model_kwargs=model_kwargs,
-        custom_export_configs=custom_export_configs,
-        fn_get_submodels=fn_get_submodels,
-        preprocessors=preprocessors,
-        device=device,
-        trust_remote_code=trust_remote_code,
-        patch_16bit_model=patch_16bit,
-        **kwargs_shapes,
-    )
+        submodel_paths = export_from_model(
+            model=model,
+            output=output,
+            task=task,
+            ov_config=ov_config,
+            stateful=stateful,
+            model_kwargs=model_kwargs,
+            custom_export_configs=custom_export_configs,
+            fn_get_submodels=fn_get_submodels,
+            preprocessors=preprocessors,
+            device=device,
+            trust_remote_code=trust_remote_code,
+            patch_16bit_model=patch_16bit,
+            **kwargs_shapes,
+        )
 
-    if convert_tokenizer:
-        maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
-
-    clear_class_registry()
-    del model
-    gc.collect()
-
-    for submodel_path in submodel_paths:
-        submodel_path = Path(output) / submodel_path
-        submodel = core.read_model(submodel_path)
-
-        quantization_config = None
-        if ov_config is None:
-            num_parameters = 0
-            for op in submodel.get_ops():
-                if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
-                    num_parameters += reduce(operator.mul, op.shape, 1)
-                    del op
-                    if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
-                        if is_nncf_available():
-                            quantization_config = {"bits": 8, "sym": False}
-                            logger.info("The model weights will be quantized to int8_asym.")
-                        else:
-                            logger.warning(
-                                "The model will be converted with no weights quantization. Quantization of the weights to int8 "
-                                "requires nncf. Please install it with `pip install nncf`"
-                            )
-                        break
-        else:
-            quantization_config = ov_config.quantization_config
-        if quantization_config is None:
-            del submodel
-            gc.collect()
-            continue
+        if convert_tokenizer:
+            maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
 
-        if not is_nncf_available():
-            raise ImportError("Quantization of the weights requires nncf, please install it with `pip install nncf`")
+        clear_class_registry()
+        del model
+        gc.collect()
 
-        from optimum.intel.openvino.quantization import _weight_only_quantization
+        for submodel_path in submodel_paths:
+            submodel_path = Path(output) / submodel_path
+            submodel = core.read_model(submodel_path)
+
+            quantization_config = None
+            if ov_config is None:
+                num_parameters = 0
+                for op in submodel.get_ops():
+                    if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
+                        num_parameters += reduce(operator.mul, op.shape, 1)
+                        del op
+                        if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
+                            if is_nncf_available():
+                                quantization_config = {"bits": 8, "sym": False}
+                                logger.info("The model weights will be quantized to int8_asym.")
+                            else:
+                                logger.warning(
+                                    "The model will be converted with no weights quantization. Quantization of the weights to int8 "
+                                    "requires nncf. Please install it with `pip install nncf`"
+                                )
+                            break
+            else:
+                quantization_config = ov_config.quantization_config
+            if quantization_config is None:
+                del submodel
+                gc.collect()
+                continue
+
+            if not is_nncf_available():
+                raise ImportError(
+                    "Quantization of the weights requires nncf, please install it with `pip install nncf`"
+                )
 
-        _weight_only_quantization(submodel, quantization_config)
-        compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
-        save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
-        del submodel
-        gc.collect()
+            from optimum.intel.openvino.quantization import _weight_only_quantization
 
-        submodel_path.unlink()
-        submodel_path.with_suffix(".bin").unlink()
-        compressed_submodel_path.rename(submodel_path)
-        compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))
+            _weight_only_quantization(submodel, quantization_config)
+            compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
+            save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
+            del submodel
+            gc.collect()
 
-    # Unpatch modules after GPTQ export
-    if do_gptq_patching:
-        torch.cuda.is_available = orig_cuda_check
-        GPTQQuantizer.post_init_model = orig_post_init_model
+            submodel_path.unlink()
+            submodel_path.with_suffix(".bin").unlink()
+            compressed_submodel_path.rename(submodel_path)
+            compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))
+
+    finally:
+        # Unpatch modules after quantized model export
+        if do_quant_patching:
+            torch.cuda.is_available = orig_cuda_check
+            if do_gptq_patching:
+                GPTQQuantizer.post_init_model = orig_post_init_model
 
 
 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 432e52d0c..636bc00f0 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -456,7 +456,11 @@ def ts_patched_forward(*args, **kwargs):
             from openvino.frontend.pytorch.patch_model import unpatch_model
 
             unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
-            model.to(torch.float32)
+            for m in model.modules():
+                if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
+                    b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
+                ):
+                    m.float()
 
             return export_pytorch_via_onnx(
                 model,
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index dc43c9458..0f166a635 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -15,6 +15,7 @@
 import copy
 import gc
 import os
+import platform
 import tempfile
 import time
 import unittest
@@ -62,7 +63,7 @@
 )
 from transformers.onnx.utils import get_preprocessor
 from transformers.testing_utils import slow
-from utils_tests import MODEL_NAMES, TEST_IMAGE_URL
+from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference
 
 from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
 from optimum.intel import (
@@ -872,7 +873,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "gpt_neo",
         "gpt_neox",
         "llama",
-        # "llama_gptq",
         "marian",
         "minicpm",
         "mistral",
@@ -917,6 +917,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
             "minicpm3",
         )
 
+    # gptq and awq install disabled for windows test environment
+    if platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("opt_gptq",)
+
+    # autoawq install disabled for windows test environment
+    if is_openvino_version(">=", "2024.6.0") and platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("mixtral_awq",)
+
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",
@@ -949,9 +957,6 @@ def test_compare_to_transformers(self, model_arch):
         if is_openvino_version("<", "2024.1"):
             not_stateful.extend(["llama", "gemma", "gpt_bigcode"])
 
-        if "gptq" in model_arch:
-            self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")
-
         set_seed(SEED)
 
         model_kwargs = {}
@@ -978,20 +983,27 @@ def test_compare_to_transformers(self, model_arch):
         if is_stateful:
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
 
+        if "awq" in model_arch or "gptq" in model_arch:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
         if model_arch in ["qwen", "arctic", "glm4"]:
             transformers_model.to(torch.float32)
 
         with torch.no_grad():
-            transformers_outputs = transformers_model(**tokens)
+            with patch_awq_for_inference("awq" in model_arch):
+                transformers_outputs = transformers_model(**tokens)
 
         # Compare tensor outputs
         atol = 1e-3 if model_arch == "minicpm" else 1e-4
-        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))
+        # quantized models have different logits value range
+        if "awq" not in model_arch and "gptq" not in model_arch:
+            self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))
 
         # Qwen tokenizer does not support padding
-
         if model_arch in ["qwen"]:
             return
 
@@ -1025,7 +1037,12 @@ def test_compare_to_transformers(self, model_arch):
             from transformers.cache_utils import DynamicCache
 
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
+        print(f"ov_outputs: {ov_outputs}")
+        print(f"transformers_outputs: {transformers_outputs}")
         self.assertTrue(
             torch.allclose(ov_outputs, transformers_outputs),
             "OV output {ov_outputs}\nTransformers output {transformers_output}",
@@ -1261,8 +1278,13 @@ def test_beam_search(self, model_arch):
         ov_model_stateless = OVModelForCausalLM.from_pretrained(
             model_id, export=True, use_cache=True, stateful=False, **model_kwargs
         )
+        if "awq" in model_arch or "gptq" in model_arch:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
 
         if model_arch == "arctic":
             transformers_model.to(torch.float32)
@@ -1288,9 +1310,10 @@ def test_beam_search(self, model_arch):
         if model_arch == "gemma2":
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(
-            **tokens, generation_config=gen_config, **additional_inputs
-        )
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
 
         set_seed(SEED)
         ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
         self.assertTrue(
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index d139d8cb2..2011e11f0 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
+from contextlib import contextmanager
 from typing import Dict, List, Union
 
 import numpy as np
@@ -81,12 +82,12 @@
     "longt5": "hf-internal-testing/tiny-random-longt5",
     "llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
     "llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-    "llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
     "llava": "katuni4ka/tiny-random-llava",
     "llava_next": "katuni4ka/tiny-random-llava-next",
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "opt125m": "facebook/opt-125m",
+    "opt_gptq": "ybelkada/opt-125m-gptq-4bit",
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "minicpm": "katuni4ka/tiny-random-minicpm",
@@ -95,6 +96,7 @@
     "mistral": "echarlaix/tiny-random-mistral",
     "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
     "mixtral": "TitanML/tiny-mixtral",
+    "mixtral_awq": "TitanML/tiny-mixtral-AWQ-4bit",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
@@ -226,6 +228,61 @@ def get_num_quantized_nodes(model):
     return num_fake_quantize, num_weight_nodes
 
 
+@contextmanager
+def mock_torch_cuda_is_available(to_patch):
+    original_is_available = torch.cuda.is_available
+    if to_patch:
+        torch.cuda.is_available = lambda: True
+    try:
+        yield
+    finally:
+        if to_patch:
+            torch.cuda.is_available = original_is_available
+
+
+@contextmanager
+def patch_awq_for_inference(to_patch):
+    orig_gemm_forward = None
+    if to_patch:
+        # patch GEMM module to allow inference without CUDA GPU
+        from awq.modules.linear.gemm import WQLinearMMFunction
+        from awq.utils.packing_utils import dequantize_gemm
+
+        def new_forward(
+            ctx,
+            x,
+            qweight,
+            qzeros,
+            scales,
+            w_bit=4,
+            group_size=128,
+            bias=None,
+            out_features=0,
+        ):
+            ctx.out_features = out_features
+
+            out_shape = x.shape[:-1] + (out_features,)
+            x = x.to(torch.float16)
+
+            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
+            out = torch.matmul(x, out)
+
+            out = out + bias if bias is not None else out
+            out = out.reshape(out_shape)
+
+            if len(out.shape) == 2:
+                out = out.unsqueeze(0)
+            return out
+
+        orig_gemm_forward = WQLinearMMFunction.forward
+        WQLinearMMFunction.forward = new_forward
+    try:
+        yield
+    finally:
+        if orig_gemm_forward is not None:
+            WQLinearMMFunction.forward = orig_gemm_forward
+
+
 def compare_num_quantized_nodes_per_model(
     test_case: unittest.TestCase,
     models: List[Union[ov.Model, OVBaseModel]],