From 3888824b620c37d87536d84d8b35a8be5f985ac8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 9 Dec 2024 12:31:18 +0000 Subject: [PATCH 01/11] enable IPEXModelForSeq2SeqLM Signed-off-by: jiqing-feng --- optimum/intel/__init__.py | 2 + optimum/intel/ipex/__init__.py | 1 + optimum/intel/ipex/modeling_base.py | 85 ++++++++++++++++++++++- optimum/intel/ipex/utils.py | 1 + optimum/intel/pipelines/pipeline_base.py | 7 ++ optimum/intel/utils/dummy_ipex_objects.py | 11 +++ 6 files changed, 104 insertions(+), 3 deletions(-) diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index 0230394d29..6091243894 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -46,6 +46,7 @@ else: _import_structure["ipex"] = [ "IPEXModelForCausalLM", + "IPEXModelForSeq2SeqLM", "IPEXModelForSequenceClassification", "IPEXModelForMaskedLM", "IPEXModelForTokenClassification", @@ -237,6 +238,7 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) diff --git a/optimum/intel/ipex/__init__.py b/optimum/intel/ipex/__init__.py index c1f711acfc..79ed4734d3 100644 --- a/optimum/intel/ipex/__init__.py +++ b/optimum/intel/ipex/__init__.py @@ -19,6 +19,7 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 8611bddd21..161684d9c2 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -30,6 +30,7 @@ AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, GenerationConfig, @@ -57,6 +58,7 @@ logger = logging.getLogger(__name__) +_IPEX_SUPPORTED_GENERATION_TASKS = ("text-generation", "text2text-generation") _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2") _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation") _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0" @@ -106,9 +108,9 @@ def __init__( # Non-generation tasks can use torch.compile to get acceleration. if ( - model.device.type == "cpu" - and self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS - and config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES + self.model.device.type == "cpu" + and self.export_feature not in _IPEX_SUPPORTED_GENERATION_TASKS + and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE) ): from torch._inductor import config @@ -336,6 +338,83 @@ def generate(self, *args, **kwargs): return result +class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin): + auto_model_class = AutoModelForSeq2SeqLM + export_feature = "text2text-generation" + + def __init__( + self, + model, + config: PretrainedConfig = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + use_cache: bool = True, + **kwargs, + ): + super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache) + + self._supports_cache_class = getattr(model, "_supports_cache_class", None) + self._supports_sdpa = getattr(model, "_supports_sdpa", None) + self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) + self._supports_static_cache = getattr(model, "_supports_static_cache", None) + + GenerationMixin.__init__(self) + + model_type = self.config.model_type.replace("_", "-") + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config) + + self.config.is_decoder = False + self.config.is_encoder_decoder = True + + self.generation_config = GenerationConfig.from_model_config(self.config) + try: + self.model_cls = get_class_from_dynamic_module( + self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir + ) + except AttributeError: + self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping) + + if hasattr(self.model_cls, "_convert_to_standard_cache"): + self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache + if ( + self._supports_static_cache + and self.model.device.type == "cpu" + and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES + and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE) + ): + from torch._inductor import config + + # Use static cache for torch.compile + self.model.config.cache_implementation = "static" + self.config.cache_implementation = "static" + # System level optimization + torch._inductor.config.cpp_wrapper = True + os.environ["TORCHINDUCTOR_FREEZING"] = "1" + logger.info("Enable torch.compile optimization, start warm up") + self.model.forward = torch.compile(self.model.forward) + inputs = prepare_jit_inputs(model, self.export_feature, False) + self.model.generate(**inputs, max_length=4) + self.model.generate(**inputs, max_length=4) + logger.info("Warm up end") + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + + def _reorder_cache(self, *args, **kwargs): + return self.model._reorder_cache(*args, **kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def get_encoder(self, *args, **kwargs): + return self.model.get_encoder(*args, **kwargs) + + def _ipex_crop_past_key_values(model, past_key_values, max_length): if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"): if isinstance(past_key_values, IPEXPagedCache): diff --git a/optimum/intel/ipex/utils.py b/optimum/intel/ipex/utils.py index 3d3feb3db2..23126bcd4c 100644 --- a/optimum/intel/ipex/utils.py +++ b/optimum/intel/ipex/utils.py @@ -16,6 +16,7 @@ _HEAD_TO_AUTOMODELS = { "feature-extraction": "IPEXModel", "text-generation": "IPEXModelForCausalLM", + "text2text-generation": "IPEXModelForSeq2SeqLM", "text-classification": "IPEXModelForSequenceClassification", "token-classification": "IPEXModelForTokenClassification", "question-answering": "IPEXModelForQuestionAnswering", diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py index d26d8c42b6..e301f1e995 100644 --- a/optimum/intel/pipelines/pipeline_base.py +++ b/optimum/intel/pipelines/pipeline_base.py @@ -58,6 +58,7 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) @@ -69,6 +70,12 @@ "default": "gpt2", "type": "text", }, + "text2text-generation": { + "impl": TextGenerationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "google-t5/t5-small", + "type": "text", + }, "fill-mask": { "impl": FillMaskPipeline, "class": (IPEXModelForMaskedLM,), diff --git a/optimum/intel/utils/dummy_ipex_objects.py b/optimum/intel/utils/dummy_ipex_objects.py index 4bd7eee630..2a0db565d3 100644 --- a/optimum/intel/utils/dummy_ipex_objects.py +++ b/optimum/intel/utils/dummy_ipex_objects.py @@ -70,6 +70,17 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["ipex"]) +class IPEXModelForSeq2SeqLM(metaclass=DummyObject): + _backends = ["ipex"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["ipex"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["ipex"]) + + class IPEXModelForQuestionAnswering(metaclass=DummyObject): _backends = ["ipex"] From f9fa8074ec251de2060fc4a81e1eb5805fdcef4b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 9 Dec 2024 13:23:53 +0000 Subject: [PATCH 02/11] set static cache Signed-off-by: jiqing-feng --- optimum/intel/ipex/modeling_base.py | 16 +++++++++++----- optimum/intel/pipelines/pipeline_base.py | 16 ++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 161684d9c2..7038515ad3 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -383,17 +383,14 @@ def __init__( ): from torch._inductor import config - # Use static cache for torch.compile - self.model.config.cache_implementation = "static" - self.config.cache_implementation = "static" # System level optimization torch._inductor.config.cpp_wrapper = True os.environ["TORCHINDUCTOR_FREEZING"] = "1" logger.info("Enable torch.compile optimization, start warm up") self.model.forward = torch.compile(self.model.forward) inputs = prepare_jit_inputs(model, self.export_feature, False) - self.model.generate(**inputs, max_length=4) - self.model.generate(**inputs, max_length=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4) logger.info("Warm up end") @torch.no_grad() @@ -405,6 +402,15 @@ def forward( ) -> CausalLMOutputWithPast: return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], **kwargs: Dict + ) -> Tuple[GenerationConfig, Dict]: + generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) + # Use static cache for torch.compile + setattr(generation_config, "cache_implementation", "static") + + return generation_config, model_kwargs + def _reorder_cache(self, *args, **kwargs): return self.model._reorder_cache(*args, **kwargs) diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py index e301f1e995..7bbcbc00cc 100644 --- a/optimum/intel/pipelines/pipeline_base.py +++ b/optimum/intel/pipelines/pipeline_base.py @@ -70,10 +70,22 @@ "default": "gpt2", "type": "text", }, + "summarization": { + "impl": SummarizationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-base", + "type": "text", + }, + "translation": { + "impl": TranslationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-small", + "type": "text", + }, "text2text-generation": { - "impl": TextGenerationPipeline, + "impl": Text2TextGenerationPipeline, "class": (IPEXModelForSeq2SeqLM,), - "default": "google-t5/t5-small", + "default": "t5-small", "type": "text", }, "fill-mask": { From 202df432124c3a0f3c132c5d4654bea20931a9d6 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 9 Dec 2024 14:42:21 +0000 Subject: [PATCH 03/11] add tests for IPEXModelForSeq2SeqLM Signed-off-by: jiqing-feng --- tests/ipex/test_modeling.py | 122 +++++++++++++++++++++++++++++++++++ tests/ipex/test_pipelines.py | 44 +++++++++++++ 2 files changed, 166 insertions(+) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 7f1104d7f7..9e921d30af 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -26,6 +26,7 @@ from transformers import ( AutoFeatureExtractor, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, AutoTokenizer, GenerationConfig, @@ -38,6 +39,7 @@ IPEXModel, IPEXModelForAudioClassification, IPEXModelForCausalLM, + IPEXModelForSeq2SeqLM, IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, @@ -510,3 +512,123 @@ def test_patched_model(self): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + + +class IPEXModelForSeq2SeqLMTest(unittest.TestCase): + IPEX_MODEL_CLASS = IPEXModelForSeq2SeqLM + SUPPORTED_ARCHITECTURES = ("t5",) + GENERATION_LENGTH = 2 + SPEEDUP_CACHE = 1.0 + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + # Test model forward do not need cache. + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype) + transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype) + self.assertIsInstance(ipex_model.config, PretrainedConfig) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample", + return_tensors="pt", + return_token_type_ids=False if model_arch in ("llama", "llama2") else None, + ) + decoder_start_token_id = transformers_model.config.decoder_start_token_id if model_arch != "mbart" else 2 + decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} + outputs = ipex_model(**tokens, **decoder_inputs) + + self.assertIsInstance(outputs.logits, torch.Tensor) + + with torch.no_grad(): + transformers_outputs = transformers_model(**tokens, **decoder_inputs) + + # Test re-load model + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype) + loaded_model_outputs = loaded_model(**tokens, **decoder_inputs) + + # Test init method + init_model = self.IPEX_MODEL_CLASS(transformers_model) + init_model_outputs = init_model(**tokens, **decoder_inputs) + + # Compare tensor outputs + self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + # To avoid float pointing error + self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7)) + self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline(self, model_arch): + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype) + model.config.encoder_no_repeat_ngram_size = 0 + # model.to("cpu") + pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer) + outputs = pipe("This is a sample", max_new_tokens=10, do_sample=False) + self.assertEqual(pipe.device, model.device) + + def test_compare_with_and_without_past_key_values(self): + model_id = "hf-internal-testing/tiny-random-t5" + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_with_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=True, torch_dtype=dtype) + device = model_with_pkv.device + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt").to(device) + # Warmup + model_with_pkv.generate(**tokens) + with Timer() as with_pkv_timer: + outputs_model_with_pkv = model_with_pkv.generate( + **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + model_without_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=False, torch_dtype=dtype) + # Warmup + model_without_pkv.generate(**tokens) + with Timer() as without_pkv_timer: + outputs_model_without_pkv = model_without_pkv.generate( + **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) + self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1) + self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1) + + @parameterized.expand( + grid_parameters( + { + "model_arch": SUPPORTED_ARCHITECTURES, + "use_cache": [True, False], + } + ) + ) + def test_ipex_beam_search(self, test_name, model_arch, use_cache): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=use_cache, torch_dtype=dtype) + device = model.device + transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype).to(device) + self.assertEqual(model.use_cache, use_cache) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + # Test with batch_size is 1 and 2. + texts = ["This is a sample", ["This is the first input", "This is the second input"]] + generation_configs = ( + GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=False), + GenerationConfig( + max_new_tokens=4, do_sample=False, top_p=0.9, top_k=0, pad_token_id=tokenizer.eos_token_id + ), + ) + for text in texts: + tokens = tokenizer(text, padding=True, return_tensors="pt").to(device) + for generation_config in generation_configs: + outputs = model.generate(**tokens, generation_config=generation_config) + transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config) + self.assertIsInstance(outputs, torch.Tensor) + self.assertTrue(torch.equal(outputs, transformers_outputs)) diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index 77790e19f4..e67a4a6e18 100644 --- a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -28,6 +28,7 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) @@ -82,6 +83,7 @@ class PipelinesIntegrationTest(unittest.TestCase): "resnet", "vit", ) + TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("t5",) @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_token_classification_pipeline_inference(self, model_arch): @@ -215,3 +217,45 @@ def test_pipeline_load_from_jit_model(self, model_arch): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) self.assertGreaterEqual(ipex_output[0]["score"], 0.0) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_text2text_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("text2text-generation", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("text2text-generation", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"]) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_summarization_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("summarization", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("summarization", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["summary_text"], ipex_output[0]["summary_text"]) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_translation_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("translation", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("translation", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["translation_text"], ipex_output[0]["translation_text"]) From 44880736a165f7548a2e9c8e4485a3f11f03e3f6 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 9 Dec 2024 14:54:48 +0000 Subject: [PATCH 04/11] add docs Signed-off-by: jiqing-feng --- docs/source/ipex/inference.mdx | 1 + docs/source/ipex/models.mdx | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/source/ipex/inference.mdx b/docs/source/ipex/inference.mdx index 54b586924d..bcbf234566 100644 --- a/docs/source/ipex/inference.mdx +++ b/docs/source/ipex/inference.mdx @@ -43,3 +43,4 @@ As shown in the table below, each task is associated with a class enabling to au | `IPEXModelForMaskedLM` | `fill-mask` | | `IPEXModelForAudioClassification` | `audio-classification` | | `IPEXModelForCausalLM` | `text-generation` | +| `IPEXModelForSeq2SeqLM` | `text2text-generation` | diff --git a/docs/source/ipex/models.mdx b/docs/source/ipex/models.mdx index 346ca26599..b8cd6c482f 100644 --- a/docs/source/ipex/models.mdx +++ b/docs/source/ipex/models.mdx @@ -40,6 +40,7 @@ Here is the list of the supported architectures : - Roberta - Roformer - SqueezeBert +- T5 - UniSpeech - Vit - Wav2Vec2 From 16fecf8b1bbbcd496eebcef99f7c8d18dbef44d1 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 9 Dec 2024 15:01:31 +0000 Subject: [PATCH 05/11] fix readme Signed-off-by: jiqing-feng --- docs/source/ipex/inference.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/ipex/inference.mdx b/docs/source/ipex/inference.mdx index bcbf234566..72826da595 100644 --- a/docs/source/ipex/inference.mdx +++ b/docs/source/ipex/inference.mdx @@ -14,7 +14,7 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m ## Loading -You can load your model and apply IPEX optimizations (apply torch.compile for non-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators. +You can load your model and apply IPEX optimizations (apply torch.compile except text-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators. For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22. ```diff From 4225bf03b03bc403754e6dd10d2dd285c129bde2 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 11 Dec 2024 12:06:44 +0000 Subject: [PATCH 06/11] refactor compile Signed-off-by: jiqing-feng --- optimum/intel/ipex/modeling_base.py | 119 +++++++++++++++------------- 1 file changed, 64 insertions(+), 55 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 7038515ad3..55b6e93522 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -58,12 +58,11 @@ logger = logging.getLogger(__name__) -_IPEX_SUPPORTED_GENERATION_TASKS = ("text-generation", "text2text-generation") _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2") _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation") _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0" -# TODO: Already fixed in torch 2.6, will enable when torch upgrading to 2.6 -_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit") +# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6 +_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit", "llama", "falcon", "gpt2") def _is_patched_with_ipex(model, task, use_cache: bool = True): @@ -86,15 +85,21 @@ def __init__( model, config: PretrainedConfig = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + warmup: Optional[bool] = True, **kwargs, ): config = config or model.config OptimizedModel.__init__(self, model=model, config=config) + self._supports_cache_class = getattr(model, "_supports_cache_class", None) + self._supports_sdpa = getattr(model, "_supports_sdpa", None) + self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) + self._supports_static_cache = getattr(model, "_supports_static_cache", None) self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32 self.use_cache = kwargs.get("use_cache", False) self.model_save_dir = model_save_dir self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache) + self.compiled = False self.input_names = set(inspect.signature(model.forward).parameters) @@ -106,25 +111,10 @@ def __init__( if hasattr(self.auto_model_class, "register"): self.auto_model_class.register(AutoConfig, self.__class__) - # Non-generation tasks can use torch.compile to get acceleration. - if ( - self.model.device.type == "cpu" - and self.export_feature not in _IPEX_SUPPORTED_GENERATION_TASKS - and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES - and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE) - ): - from torch._inductor import config - - # System level optimization - torch._inductor.config.cpp_wrapper = True - os.environ["TORCHINDUCTOR_FREEZING"] = "1" - logger.info("Enable torch.compile optimization, start warm up") - self.model.forward = torch.compile(self.model.forward) - inputs = prepare_jit_inputs(model, self.export_feature, False) - with torch.no_grad(): - self.model(**inputs) - self.model(**inputs) - logger.info("Warm up end") + self.maybe_apply_torch_compile() + + if warmup: + self._init_warmup() @classmethod def _from_transformers(cls, *args, **kwargs): @@ -194,6 +184,31 @@ def to(self, device: Union[torch.device, str]): def can_generate(self): return isinstance(self, GenerationMixin) + def maybe_apply_torch_compile(self): + if ( + not self.model.device.type != "cpu" + or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES + or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE) + ): + return + if self.use_cache and not self._supports_static_cache: + return + from torch._inductor import config + + # System level optimization + torch._inductor.config.cpp_wrapper = True + os.environ["TORCHINDUCTOR_FREEZING"] = "1" + logger.info("Enable torch.compile optimization") + self.model.forward = torch.compile(self.model.forward) + self.compiled = True + + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + with torch.no_grad(): + self.model(**inputs) + self.model(**inputs) + logger.info("Warm up end") + class IPEXModelForSequenceClassification(IPEXModel): auto_model_class = AutoModelForSequenceClassification @@ -238,16 +253,10 @@ def __init__( config: PretrainedConfig = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, use_cache: bool = True, + warmup: Optional[bool] = True, **kwargs, ): - super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache) - - self._supports_cache_class = getattr(model, "_supports_cache_class", None) - self._supports_sdpa = getattr(model, "_supports_sdpa", None) - self._supports_cache_class = getattr(model, "_supports_cache_class", None) - self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) - self._supports_static_cache = getattr(model, "_supports_static_cache", None) - + super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache) if self._add_patch: self._supports_cache_class = True GenerationMixin.__init__(self) @@ -271,6 +280,9 @@ def __init__( if hasattr(self.model_cls, "_convert_to_bloom_cache"): self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache + if warmup: + self._init_warmup() + @torch.no_grad() def forward( self, @@ -285,6 +297,9 @@ def _prepare_generation_config( ) -> Tuple[GenerationConfig, Dict]: generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) generation_method = generation_config.get_generation_mode().value + if self.compiled: + # Use static cache for torch compile + generation_config.cache_implementation = "static" if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS: raise ValueError( f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" @@ -337,6 +352,12 @@ def generate(self, *args, **kwargs): return result + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + logger.info("Warm up end") + class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin): auto_model_class = AutoModelForSeq2SeqLM @@ -348,15 +369,10 @@ def __init__( config: PretrainedConfig = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, use_cache: bool = True, + warmup: Optional[bool] = True, **kwargs, ): - super().__init__(model, config, model_save_dir=model_save_dir, use_cache=use_cache) - - self._supports_cache_class = getattr(model, "_supports_cache_class", None) - self._supports_sdpa = getattr(model, "_supports_sdpa", None) - self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) - self._supports_static_cache = getattr(model, "_supports_static_cache", None) - + super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache) GenerationMixin.__init__(self) model_type = self.config.model_type.replace("_", "-") @@ -375,23 +391,9 @@ def __init__( if hasattr(self.model_cls, "_convert_to_standard_cache"): self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache - if ( - self._supports_static_cache - and self.model.device.type == "cpu" - and self.config.model_type not in _COMPILE_NOT_READY_MODEL_TYPES - and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_COMPILE) - ): - from torch._inductor import config - - # System level optimization - torch._inductor.config.cpp_wrapper = True - os.environ["TORCHINDUCTOR_FREEZING"] = "1" - logger.info("Enable torch.compile optimization, start warm up") - self.model.forward = torch.compile(self.model.forward) - inputs = prepare_jit_inputs(model, self.export_feature, False) - self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4) - self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=4) - logger.info("Warm up end") + + if warmup: + self._init_warmup() @torch.no_grad() def forward( @@ -407,7 +409,8 @@ def _prepare_generation_config( ) -> Tuple[GenerationConfig, Dict]: generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) # Use static cache for torch.compile - setattr(generation_config, "cache_implementation", "static") + if self.compiled: + generation_config.cache_implementation = "static" return generation_config, model_kwargs @@ -420,6 +423,12 @@ def prepare_inputs_for_generation(self, *args, **kwargs): def get_encoder(self, *args, **kwargs): return self.model.get_encoder(*args, **kwargs) + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + logger.info("Warm up end") + def _ipex_crop_past_key_values(model, past_key_values, max_length): if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"): From 2ac7ecf1c4af12a32a7f78fad91e0ba86d8c689d Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 11 Dec 2024 12:10:23 +0000 Subject: [PATCH 07/11] fix check Signed-off-by: jiqing-feng --- optimum/intel/ipex/modeling_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 55b6e93522..a040fdcf57 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -297,7 +297,7 @@ def _prepare_generation_config( ) -> Tuple[GenerationConfig, Dict]: generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) generation_method = generation_config.get_generation_mode().value - if self.compiled: + if self.compiled and generation_config.cache_implementation != "ipex_paged": # Use static cache for torch compile generation_config.cache_implementation = "static" if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS: From 24b988ccc60092f56e28d22903d3b5ae02b70c96 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 11 Dec 2024 12:12:59 +0000 Subject: [PATCH 08/11] fix ruff check Signed-off-by: jiqing-feng --- optimum/intel/ipex/modeling_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index a040fdcf57..de004db403 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -193,10 +193,10 @@ def maybe_apply_torch_compile(self): return if self.use_cache and not self._supports_static_cache: return - from torch._inductor import config + from torch._inductor import config as inductor_config # System level optimization - torch._inductor.config.cpp_wrapper = True + inductor_config.cpp_wrapper = True os.environ["TORCHINDUCTOR_FREEZING"] = "1" logger.info("Enable torch.compile optimization") self.model.forward = torch.compile(self.model.forward) From 22458d291c69ac3a63c5fa9554d1f7dba4f1b5fe Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 16 Dec 2024 14:21:40 +0000 Subject: [PATCH 09/11] fix check Signed-off-by: jiqing-feng --- optimum/intel/ipex/modeling_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index de004db403..2cdccb771d 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -186,7 +186,7 @@ def can_generate(self): def maybe_apply_torch_compile(self): if ( - not self.model.device.type != "cpu" + self.model.device.type != "cpu" or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE) ): From c11e01c1ae33b245fa07a6dbf9f75309b40248b2 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 16 Dec 2024 16:46:41 +0000 Subject: [PATCH 10/11] fix tests Signed-off-by: jiqing-feng --- optimum/intel/ipex/modeling_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 2cdccb771d..e9a19b33b9 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -62,7 +62,7 @@ _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation") _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0" # TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6 -_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "beit", "llama", "falcon", "gpt2") +_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2") def _is_patched_with_ipex(model, task, use_cache: bool = True): @@ -297,7 +297,7 @@ def _prepare_generation_config( ) -> Tuple[GenerationConfig, Dict]: generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) generation_method = generation_config.get_generation_mode().value - if self.compiled and generation_config.cache_implementation != "ipex_paged": + if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache: # Use static cache for torch compile generation_config.cache_implementation = "static" if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS: From 8ff5fb84e228ecf69d2d394810ec18baad2d7425 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 17 Dec 2024 10:55:43 +0000 Subject: [PATCH 11/11] fix opt tests Signed-off-by: jiqing-feng --- tests/ipex/test_modeling.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 68ca27534c..679783741b 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -47,7 +47,7 @@ IPEXSentenceTransformer, ) from optimum.utils.testing_utils import grid_parameters, require_sentence_transformers -from optimum.intel.utils.import_utils import is_sentence_transformers_available +from optimum.intel.utils.import_utils import is_sentence_transformers_available, is_torch_version if is_sentence_transformers_available(): from sentence_transformers import SentenceTransformer @@ -331,6 +331,9 @@ def test_ipex_beam_search(self, test_name, model_arch, use_cache): model = IPEXModelForCausalLM.from_pretrained( model_id, use_cache=use_cache, torch_dtype=dtype, device_map=DEVICE ) + # It will be removed when torch 2.6 released + if model_arch == "opt" and not use_cache and model.compiled and is_torch_version("<", "2.6.0"): + return if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: self.assertTrue(model.add_patch) transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)