diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index 025a40e057..ffd084d4e6 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -119,6 +119,15 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models."
         ),
     )
+    optional_group.add_argument(
+        "--all-layers",
+        action="store_true",
+        default=None,
+        help=(
+            "Whether embeddings and last MatMul layers should be compressed to INT4. If not provided, they are "
+            "compressed to INT8 when weight compression is applied."
+        ),
+    )
     optional_group.add_argument(
         "--disable-stateful",
         action="store_true",
@@ -198,6 +207,7 @@ def run(self):
                 and self.args.ratio is None
                 and self.args.group_size is None
                 and self.args.sym is None
+                and self.args.all_layers is None
                 and self.args.model in _DEFAULT_4BIT_CONFIGS
             ):
                 quantization_config = _DEFAULT_4BIT_CONFIGS[self.args.model]
@@ -207,6 +217,7 @@ def run(self):
                     "ratio": 1 if is_int8 else (self.args.ratio or 0.8),
                     "sym": self.args.sym or False,
                     "group_size": -1 if is_int8 else self.args.group_size,
+                    "all_layers": None if is_int8 else self.args.all_layers,
                 }

             if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index baa34a5cd0..3b214f77e4 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -614,7 +614,12 @@ def export_from_model(
         model.config.save_pretrained(output)
         generation_config = getattr(model, "generation_config", None)
         if generation_config is not None:
-            generation_config.save_pretrained(output)
+            try:
+                generation_config.save_pretrained(output)
+            except Exception as exception:
+                logger.warning(
+                    f"The generation config will not be saved, saving failed with the following error:\n{exception}"
+                )

         model_name_or_path = model.config._name_or_path
         maybe_save_preprocessors(model_name_or_path, output, trust_remote_code=trust_remote_code)
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 5b6a83a6cd..00269d1ba2 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -42,15 +42,18 @@
 from optimum.utils.normalized_config import NormalizedTextConfig

 from .model_patcher import (
+    AquilaModelPatcher,
     BaichuanModelPatcher,
     ChatGLMModelPatcher,
     GemmaModelPatcher,
-    InternLMPatcher,
+    InternLM2Patcher,
+    InternLMModelPatcher,
     LlamaModelPatcher,
     MixtralModelPatcher,
     MPTModelPatcher,
     Phi3ModelPatcher,
     QwenModelPatcher,
+    XverseModelPatcher,
 )

@@ -461,7 +464,7 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
-        return InternLMPatcher(self, model, model_kwargs=model_kwargs)
+        return InternLM2Patcher(self, model, model_kwargs=model_kwargs)

 @register_in_tasks_manager("orion", *["text-generation", "text-generation-with-past"], library_name="transformers")
@@ -501,6 +504,12 @@ def patch_model_for_export(
     library_name="transformers",
 )
 class Phi3OpenVINOConfig(PhiOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        MistralDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None ) -> "ModelPatcher": @@ -608,3 +617,140 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return { "sample": {0: "batch_size", 2: "height", 3: "width"}, } + + +@register_in_tasks_manager( + "persimmon", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class PersimmonOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager("biogpt", *["text-generation", "text-generation-with-past"], library_name="transformers") +class BioGPTOpenVINOConfig(TextDecoderOnnxConfig): + # BioGPT does not require position_ids input. + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager( + "gpt-neox-japanese", *["text-generation", "text-generation-with-past"], library_name="transformers" +) +class GPTNeoxJapaneseOpenVINOConfig(TextDecoderOnnxConfig): + # GPTNeoxJapanese does not require position_ids input. + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager( + "cohere", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class CohereOpenVINOConfig(LlamaOpenVINOConfig): + pass + + +@register_in_tasks_manager("xglm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + num_attention_heads="attention_heads", hidden_size="d_model" + ) + + +class AquilaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task, + normalized_config, + batch_size, + sequence_length, + random_batch_size_range, + random_sequence_length_range, + **kwargs, + ) + self.num_key_value_heads = getattr( + normalized_config, "num_key_value_heads", normalized_config.num_attention_heads + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + shape = ( + self.batch_size, + self.num_key_value_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + +@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers") +class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, AquilaDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = 
AquilaDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return AquilaModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("xverse", *["text-generation", "text-generation-with-past"], library_name="transformers") +class XverseMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return XverseModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("internlm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class InternLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return InternLMModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 1699c6d362..678cd39e3b 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -946,7 +946,7 @@ def __exit__(self, exc_type, exc_value, traceback): block.attn.forward = block.attn._orig_forward -def _internlm_attention_forward( +def _internlm2_attention_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -1037,14 +1037,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return attn_output, attn_weights, past_key_value -class InternLMPatcher(DecoderModelPatcher): +class InternLM2Patcher(DecoderModelPatcher): def __enter__(self): super().__enter__() if is_torch_version(">=", "2.1.0"): for block in self._model.model.layers: block.attention._orig_forward = block.attention.forward - block.attention.forward = types.MethodType(_internlm_attention_forward, block.attention) + block.attention.forward = types.MethodType(_internlm2_attention_forward, block.attention) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) @@ -1053,15 +1053,380 @@ def __exit__(self, exc_type, exc_value, traceback): block.attention.forward = block.attention._orig_forward +# Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L426 +def _phi3_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if 
output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    # TODO: remove the llama imports once a transformers release with phi3 support is available
+    try:
+        from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
+    except ImportError:
+        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    query_pos = self.num_heads * self.head_dim
+    query_states = qkv[..., :query_pos]
+    key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+    value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + class Phi3ModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() - # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113 # init inv_freq for torchscript tracing for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + if layer.self_attn.rotary_emb.inv_freq is None: rotary_emb = layer.self_attn.rotary_emb layer.self_attn.rotary_emb.inv_freq = 1.0 / ( rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _aquila_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + bsz, q_len, _ = hidden_states.size() + + if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if hasattr(self, "num_key_value_groups"): + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +class AquilaModelPatcher(DecoderModelPatcher): + def 
__enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_aquila_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +def _xverse_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +def _internlm_self_attn_sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + def rotate_half(x): + """Rotates half the hidden dims of 
the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + if output_attentions: + return self._orig_forward( + hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache + ) + + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim)) + ) + attn_weights = None + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class XverseModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_xverse_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward + + +class InternLMModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if is_torch_version(">=", "2.1.0"): + orig_self_attn_fwd = layer.self_attn.forward + layer.self_attn.forward = types.MethodType(_internlm_self_attn_sdpa_forward, layer.self_attn) + layer.self_attn._orig_forward = orig_self_attn_fwd + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.self_attn, "_orig_forward"): + layer.self_attn.forward = layer.self_attn._orig_forward diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index d2963d55a1..e929a4ddb8 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -39,6 +39,7 @@ GenerationConfig, GenerationMixin, PretrainedConfig, + is_torch_xpu_available, ) from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput @@ 
-52,7 +53,7 @@ from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model from ..generation.modeling import prepare_jit_inputs from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version -from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask +from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device logger = logging.getLogger(__name__) @@ -128,10 +129,14 @@ def __init__( **kwargs, ): OptimizedModel.__init__(self, model=model, config=config) - # To do: add XPU support - self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 + if is_torch_xpu_available(check_device=True): + self._device = torch.device("xpu:0") + elif torch.cuda.is_available(): + self._device = torch.device("cuda:0") + else: + self._device = torch.device("cpu") self.model.to(self._device) + self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 self.model_save_dir = model_save_dir self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature) @@ -321,6 +326,8 @@ def _init_warmup(self): if not self._is_ipex_exported: use_cache = "past_key_values" in self.input_names dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache) + if self._device.type != "cpu": + dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device) for _ in range(2): self(**dummy_inputs) diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 45961a86ff..17305b947e 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -48,7 +48,7 @@ from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available from ..utils.constant import _TASK_ALIASES -from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available +from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available, is_diffusers_available from ..utils.modeling_utils import get_model_device from .configuration import OVConfig, OVQuantizationConfig, OVQuantizationMethod, OVWeightQuantizationConfig from .modeling_base import OVBaseModel @@ -325,7 +325,8 @@ def _quantize_ovbasemodel( remove_unused_columns: bool = True, **kwargs, ): - from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionPipelineBase + if is_diffusers_available(): + from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionPipelineBase if save_directory is not None: save_directory = Path(save_directory) @@ -335,7 +336,7 @@ def _quantize_ovbasemodel( if calibration_dataset is not None: # Process custom calibration dataset - if isinstance(self.model, OVStableDiffusionPipelineBase): + if is_diffusers_available() and isinstance(self.model, OVStableDiffusionPipelineBase): calibration_dataset = self._prepare_unet_dataset( quantization_config.num_samples, dataset=calibration_dataset ) @@ -373,7 +374,7 @@ def _quantize_ovbasemodel( if isinstance(self.model, OVModelForCausalLM): calibration_dataset = self._prepare_builtin_dataset(quantization_config) - elif isinstance(self.model, OVStableDiffusionPipelineBase): + elif is_diffusers_available() and isinstance(self.model, OVStableDiffusionPipelineBase): calibration_dataset = self._prepare_unet_dataset( 
quantization_config.num_samples, dataset_name=quantization_config.dataset ) @@ -385,7 +386,7 @@ def _quantize_ovbasemodel( if quantization_config.quant_method == OVQuantizationMethod.HYBRID: if calibration_dataset is None: raise ValueError("Calibration dataset is required to run hybrid quantization.") - if isinstance(self.model, OVStableDiffusionPipelineBase): + if is_diffusers_available() and isinstance(self.model, OVStableDiffusionPipelineBase): # Apply weight-only quantization to all SD submodels except UNet quantization_config_copy = copy.deepcopy(quantization_config) quantization_config_copy.dataset = None diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 99ad42aafa..a2cd728354 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -169,3 +169,16 @@ def get_model_device(model: torch.nn.Module) -> torch.device: # The model had no parameters at all, doesn't matter which device to choose device = torch.device("cpu") return device + + +def recursive_to_device(value, device): + """ + Recursivley move the tensor element in `value` to `device` + """ + if isinstance(value, (tuple, list)): + return type(value)(recursive_to_device(v, device) for v in value) + elif isinstance(value, dict): + return {k: recursive_to_device(v, device) for k, v in value.items()} + elif isinstance(value, torch.Tensor): + return value.to(device) + return value diff --git a/tests/ipex/test_inference.py b/tests/ipex/test_inference.py index b65d3c9b8e..1a452fe408 100644 --- a/tests/ipex/test_inference.py +++ b/tests/ipex/test_inference.py @@ -16,8 +16,6 @@ import torch from parameterized import parameterized - -# TODO : add more tasks from transformers import ( AutoModelForCausalLM, AutoModelForQuestionAnswering, @@ -26,60 +24,51 @@ AutoTokenizer, pipeline, ) +from utils_tests import MODEL_NAMES from optimum.intel import inference_mode as ipex_inference_mode from optimum.intel.ipex.modeling_base import IPEXModel -MODEL_NAMES = { - "bert": "hf-internal-testing/tiny-random-bert", - "bloom": "hf-internal-testing/tiny-random-BloomModel", - "distilbert": "hf-internal-testing/tiny-random-distilbert", - "roberta": "hf-internal-testing/tiny-random-roberta", - "gptj": "hf-internal-testing/tiny-random-gptj", - "gpt2": "hf-internal-testing/tiny-random-gpt2", - "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", - "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", - "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "llama": "fxmarty/tiny-llama-fast-tokenizer", - "llama2": "Jiqing/tiny_random_llama2", - "opt": "hf-internal-testing/tiny-random-OPTModel", - "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", -} - _CLASSIFICATION_TASK_TO_AUTOMODELS = { "text-classification": AutoModelForSequenceClassification, "token-classification": AutoModelForTokenClassification, } -class IPEXIntegrationTest(unittest.TestCase): - CLASSIFICATION_SUPPORTED_ARCHITECTURES = ( +class IPEXClassificationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( "bert", "distilbert", "roberta", ) - TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ( - "bloom", - "gptj", - "gpt2", - "gpt_neo", - "gpt_bigcode", - "llama", - "llama2", - "opt", - "mpt", - ) + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = "This is a sample input" + for task, auto_model_class in 
_CLASSIFICATION_TASK_TO_AUTOMODELS.items(): + model = auto_model_class.from_pretrained(model_id, torch_dtype=torch.float32) + pipe = pipeline(task, model=model, tokenizer=tokenizer) - QA_SUPPORTED_ARCHITECTURES = ( + with torch.inference_mode(): + outputs = pipe(inputs) + with ipex_inference_mode(pipe, dtype=model.config.torch_dtype, verbose=False, jit=True) as ipex_pipe: + outputs_ipex = ipex_pipe(inputs) + self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule)) + self.assertEqual(outputs[0]["score"], outputs_ipex[0]["score"]) + + +class IPEXQuestionAnsweringTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( "bert", "distilbert", "roberta", ) - @parameterized.expand(QA_SUPPORTED_ARCHITECTURES) - def test_question_answering_pipeline_inference(self, model_arch): + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=torch.float32) @@ -95,24 +84,22 @@ def test_question_answering_pipeline_inference(self, model_arch): self.assertEqual(outputs["start"], outputs_ipex["start"]) self.assertEqual(outputs["end"], outputs_ipex["end"]) - @parameterized.expand(CLASSIFICATION_SUPPORTED_ARCHITECTURES) - def test_classification_pipeline_inference(self, model_arch): - model_id = MODEL_NAMES[model_arch] - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = "This is a sample input" - for task, auto_model_class in _CLASSIFICATION_TASK_TO_AUTOMODELS.items(): - model = auto_model_class.from_pretrained(model_id, torch_dtype=torch.float32) - pipe = pipeline(task, model=model, tokenizer=tokenizer) - with torch.inference_mode(): - outputs = pipe(inputs) - with ipex_inference_mode(pipe, dtype=model.config.torch_dtype, verbose=False, jit=True) as ipex_pipe: - outputs_ipex = ipex_pipe(inputs) - self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule)) - self.assertEqual(outputs[0]["score"], outputs_ipex[0]["score"]) +class IPEXTextGenerationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( + "bloom", + "gptj", + "gpt2", + "gpt_neo", + "gpt_bigcode", + "llama", + "llama2", + "opt", + "mpt", + ) - @parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES) - def test_text_generation_pipeline_inference(self, model_arch): + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, return_dict=False) model = model.eval() diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 94a5ca9e16..2a2f18f6f8 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -45,53 +45,11 @@ ) from optimum.intel.utils.import_utils import is_ipex_version from optimum.utils.testing_utils import grid_parameters +from utils_tests import MODEL_NAMES SEED = 42 -MODEL_NAMES = { - "albert": "hf-internal-testing/tiny-random-albert", - "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", - "bert": "hf-internal-testing/tiny-random-bert", - "bart": "hf-internal-testing/tiny-random-bart", - "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", - "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", - "bloom": "hf-internal-testing/tiny-random-BloomModel", - "convbert": 
"hf-internal-testing/tiny-random-ConvBertForSequenceClassification", - "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", - "convnext": "hf-internal-testing/tiny-random-convnext", - "distilbert": "hf-internal-testing/tiny-random-distilbert", - "electra": "hf-internal-testing/tiny-random-electra", - "flaubert": "hf-internal-testing/tiny-random-flaubert", - "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "gpt2": "hf-internal-testing/tiny-random-gpt2", - "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", - "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", - "gptj": "hf-internal-testing/tiny-random-GPTJModel", - "levit": "hf-internal-testing/tiny-random-LevitModel", - "llama": "fxmarty/tiny-llama-fast-tokenizer", - "llama2": "Jiqing/tiny_random_llama2", - "marian": "sshleifer/tiny-marian-en-de", - "mbart": "hf-internal-testing/tiny-random-mbart", - "mistral": "echarlaix/tiny-random-mistral", - "mobilenet_v1": "google/mobilenet_v1_0.75_192", - "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", - "mobilevit": "hf-internal-testing/tiny-random-mobilevit", - "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", - "mt5": "stas/mt5-tiny-random", - "opt": "hf-internal-testing/tiny-random-OPTModel", - "phi": "echarlaix/tiny-random-PhiForCausalLM", - "resnet": "hf-internal-testing/tiny-random-resnet", - "roberta": "hf-internal-testing/tiny-random-roberta", - "roformer": "hf-internal-testing/tiny-random-roformer", - "squeezebert": "hf-internal-testing/tiny-random-squeezebert", - "t5": "hf-internal-testing/tiny-random-t5", - "unispeech": "hf-internal-testing/tiny-random-unispeech", - "vit": "hf-internal-testing/tiny-random-vit", - "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", - "xlm": "hf-internal-testing/tiny-random-xlm", -} - class Timer(object): def __enter__(self): diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index 89a27ab2c8..c4ae471a0f 100644 --- a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -20,6 +20,7 @@ from parameterized import parameterized from transformers import AutoTokenizer from transformers.pipelines import pipeline as transformers_pipeline +from utils_tests import MODEL_NAMES from optimum.intel.ipex.modeling_base import ( IPEXModelForAudioClassification, @@ -33,50 +34,6 @@ from optimum.intel.pipelines import pipeline as ipex_pipeline -MODEL_NAMES = { - "albert": "hf-internal-testing/tiny-random-albert", - "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", - "bert": "hf-internal-testing/tiny-random-bert", - "bart": "hf-internal-testing/tiny-random-bart", - "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", - "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", - "bloom": "hf-internal-testing/tiny-random-BloomModel", - "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", - "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", - "convnext": "hf-internal-testing/tiny-random-convnext", - "distilbert": "hf-internal-testing/tiny-random-distilbert", - "electra": "hf-internal-testing/tiny-random-electra", - "flaubert": "hf-internal-testing/tiny-random-flaubert", - "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "gpt2": "hf-internal-testing/tiny-random-gpt2", - "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", - "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", - "gptj": "hf-internal-testing/tiny-random-GPTJModel", - "levit": 
"hf-internal-testing/tiny-random-LevitModel", - "llama": "fxmarty/tiny-llama-fast-tokenizer", - "llama2": "Jiqing/tiny_random_llama2", - "marian": "sshleifer/tiny-marian-en-de", - "mbart": "hf-internal-testing/tiny-random-mbart", - "mistral": "echarlaix/tiny-random-mistral", - "mobilenet_v1": "google/mobilenet_v1_0.75_192", - "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", - "mobilevit": "hf-internal-testing/tiny-random-mobilevit", - "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", - "mt5": "stas/mt5-tiny-random", - "opt": "hf-internal-testing/tiny-random-OPTModel", - "phi": "echarlaix/tiny-random-PhiForCausalLM", - "resnet": "hf-internal-testing/tiny-random-resnet", - "roberta": "hf-internal-testing/tiny-random-roberta", - "roformer": "hf-internal-testing/tiny-random-roformer", - "squeezebert": "hf-internal-testing/tiny-random-squeezebert", - "t5": "hf-internal-testing/tiny-random-t5", - "unispeech": "hf-internal-testing/tiny-random-unispeech", - "vit": "hf-internal-testing/tiny-random-vit", - "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", - "xlm": "hf-internal-testing/tiny-random-xlm", -} - - class PipelinesIntegrationTest(unittest.TestCase): COMMON_SUPPORTED_ARCHITECTURES = ( "albert", diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py new file mode 100644 index 0000000000..a14f0bf7ca --- /dev/null +++ b/tests/ipex/utils_tests.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +MODEL_NAMES = { + "albert": "hf-internal-testing/tiny-random-albert", + "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", + "bert": "hf-internal-testing/tiny-random-bert", + "bart": "hf-internal-testing/tiny-random-bart", + "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", + "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", + "bloom": "hf-internal-testing/tiny-random-BloomModel", + "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", + "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", + "convnext": "hf-internal-testing/tiny-random-convnext", + "distilbert": "hf-internal-testing/tiny-random-distilbert", + "electra": "hf-internal-testing/tiny-random-electra", + "flaubert": "hf-internal-testing/tiny-random-flaubert", + "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", + "gpt2": "hf-internal-testing/tiny-random-gpt2", + "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", + "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", + "gptj": "hf-internal-testing/tiny-random-GPTJModel", + "levit": "hf-internal-testing/tiny-random-LevitModel", + "llama": "fxmarty/tiny-llama-fast-tokenizer", + "llama2": "Jiqing/tiny_random_llama2", + "marian": "sshleifer/tiny-marian-en-de", + "mbart": "hf-internal-testing/tiny-random-mbart", + "mistral": "echarlaix/tiny-random-mistral", + "mobilenet_v1": "google/mobilenet_v1_0.75_192", + "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", + "mobilevit": "hf-internal-testing/tiny-random-mobilevit", + "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", + "mt5": "stas/mt5-tiny-random", + "opt": "hf-internal-testing/tiny-random-OPTModel", + "phi": "echarlaix/tiny-random-PhiForCausalLM", + "resnet": "hf-internal-testing/tiny-random-resnet", + "roberta": "hf-internal-testing/tiny-random-roberta", + "roformer": "hf-internal-testing/tiny-random-roformer", + "squeezebert": "hf-internal-testing/tiny-random-squeezebert", + "t5": "hf-internal-testing/tiny-random-t5", + "unispeech": "hf-internal-testing/tiny-random-unispeech", + "vit": "hf-internal-testing/tiny-random-vit", + "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", + "xlm": "hf-internal-testing/tiny-random-xlm", +} diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index cac79abaee..cce25bbae1 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -18,7 +18,6 @@ from parameterized import parameterized from utils_tests import ( - _ARCHITECTURES_TO_EXPECTED_INT4_INT8, _ARCHITECTURES_TO_EXPECTED_INT8, MODEL_NAMES, get_num_quantized_nodes, @@ -84,14 +83,13 @@ class OVCLIExportTestCase(unittest.TestCase): ("latent-consistency", 50, 135), ) - SUPPORTED_4BIT_ARCHITECTURES = (("text-generation-with-past", "opt125m"),) - - SUPPORTED_4BIT_OPTIONS = ["int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"] - - TEST_4BIT_CONFIGURATONS = [] - for arch in SUPPORTED_4BIT_ARCHITECTURES: - for option in SUPPORTED_4BIT_OPTIONS: - TEST_4BIT_CONFIGURATONS.append([arch[0], arch[1], option]) + TEST_4BIT_CONFIGURATONS = [ + ("text-generation-with-past", "opt125m", "int4_sym_g128", 62, 86), + ("text-generation-with-past", "opt125m", "int4_asym_g128", 62, 86), + ("text-generation-with-past", "opt125m", "int4_sym_g64", 62, 86), + ("text-generation-with-past", "opt125m", "int4_asym_g64", 62, 86), + ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 
32), + ] def _openvino_export( self, model_name: str, task: str, compression_option: str = None, compression_ratio: float = None @@ -197,17 +195,16 @@ def test_exporters_cli_hybrid_quantization(self, model_type: str, exp_num_fq: in self.assertEqual(exp_num_fq, num_fq) @parameterized.expand(TEST_4BIT_CONFIGURATONS) - def test_exporters_cli_int4(self, task: str, model_type: str, option: str): + def test_exporters_cli_int4(self, task: str, model_type: str, option: str, expected_int8: int, expected_int4: int): with TemporaryDirectory() as tmpdir: subprocess.run( - f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}", + f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --task {task} --weight-format {option} {tmpdir}", shell=True, check=True, ) model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {} model = eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs) - expected_int8, expected_int4 = _ARCHITECTURES_TO_EXPECTED_INT4_INT8[model_type] _, num_int8, num_int4 = get_num_quantized_nodes(model) self.assertEqual(expected_int8, num_int8) self.assertEqual(expected_int4, num_int4) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 75c95c1563..1191a93908 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -552,6 +552,15 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "orion", "falcon", "falcon-40b", + "persimmon", + "biogpt", + "gpt_neox_japanese", + "cohere", + "xglm", + "aquila", + "aquila2", + "xverse", + "internlm", ) GENERATION_LENGTH = 100 REMOTE_CODE_MODELS = ( @@ -564,6 +573,10 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "internlm2", "orion", "phi3", + "aquila", + "aquila2", + "xverse", + "internlm", ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -591,6 +604,7 @@ def test_compare_to_transformers(self, model_arch): self.assertEqual(ov_model.stateful, ov_model.config.model_type not in not_stateful) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) tokens = tokenizer("This is a sample output", return_tensors="pt") + tokens.pop("token_type_ids", None) ov_outputs = ov_model(**tokens) self.assertTrue("logits" in ov_outputs) @@ -617,11 +631,15 @@ def test_compare_to_transformers(self, model_arch): if model_arch == "qwen": return - if model_arch != "chatglm": + if model_arch not in ["chatglm", "persimmon"]: tokenizer.pad_token_id = tokenizer.eos_token_id + + if model_arch == "persimmon": + tokenizer.pad_token_id = tokenizer.bos_token_id # Compare batched generation tokenizer.padding_side = "left" tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + tokens.pop("token_type_ids", None) ov_model.generation_config.eos_token_id = None transformers_model.generation_config.eos_token_id = None ov_model.config.eos_token_id = None @@ -845,6 +863,7 @@ def test_beam_search(self, model_arch): transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) tokenizer.pad_token_id = tokenizer.eos_token_id tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True) + tokens.pop("token_type_ids", None) ov_model_stateful.generation_config.eos_token_id = None ov_model_stateless.generation_config.eos_token_id = None transformers_model.generation_config.eos_token_id = None diff --git 
a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 9f28e40a4b..d4364d192a 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -18,6 +18,8 @@ MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-albert", + "aquila": "katuni4ka/tiny-random-aquilachat", + "aquila2": "katuni4ka/tiny-random-aquila2", "audio_spectrogram_transformer": "Ericwang/tiny-random-ast", "bge": "BAAI/bge-small-en-v1.5", "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", @@ -26,11 +28,13 @@ "baichuan2": "katuni4ka/tiny-random-baichuan2", "baichuan2-13b": "katuni4ka/tiny-random-baichuan2-13b", "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus", + "biogpt": "hf-tiny-model-private/tiny-random-BioGptForCausalLM", "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", "bloom": "hf-internal-testing/tiny-random-BloomModel", "camembert": "hf-internal-testing/tiny-random-camembert", "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", + "cohere": "hf-internal-testing/tiny-random-CohereForCausalLM", "chatglm": "katuni4ka/tiny-random-chatglm2", "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", @@ -51,9 +55,11 @@ "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", + "gpt_neox_japanese": "hf-internal-testing/tiny-random-GPTNeoXJapaneseForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-ibert", + "internlm": "katuni4ka/tiny-random-internlm", "internlm2": "katuni4ka/tiny-random-internlm2", "levit": "hf-internal-testing/tiny-random-LevitModel", "longt5": "hf-internal-testing/tiny-random-longt5", @@ -78,6 +84,7 @@ "olmo": "katuni4ka/tiny-random-olmo-hf", "orion": "katuni4ka/tiny-random-orion", "pegasus": "hf-internal-testing/tiny-random-pegasus", + "persimmon": "hf-internal-testing/tiny-random-PersimmonForCausalLM", "pix2struct": "fxmarty/pix2struct-tiny-random", "phi": "echarlaix/tiny-random-PhiForCausalLM", "phi3": "katuni4ka/tiny-random-phi3", @@ -115,6 +122,8 @@ "whisper": "openai/whisper-tiny.en", "xlm": "hf-internal-testing/tiny-random-xlm", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", + "xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM", + "xverse": "katuni4ka/tiny-random-xverse", } @@ -140,8 +149,6 @@ "stable-diffusion-xl-refiner": (366, 34, 42, 66), } -_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (62, 86)} - def get_num_quantized_nodes(ov_model): num_fake_quantize = 0
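# Usage note: a hedged sketch of how the new `--all-layers` flag can be combined with the other 4-bit
# options on the command line, mirroring the `llama_awq` configuration exercised in TEST_4BIT_CONFIGURATONS
# above; the model ID and output directory are placeholders, not values taken from this patch:
#
#   optimum-cli export openvino --model <model_id> --task text-generation-with-past \
#       --weight-format int4 --ratio 1.0 --sym --group-size 16 --all-layers <output_dir>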