diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
index 3d90ad12fb..ee6082d1ae 100644
--- a/optimum/exporters/ipex/model_patcher.py
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,13 +27,11 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _llama_model_forward,
@@ -59,12 +57,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)
 
 
-def convert_class(m, target_m, new_class, config=None):
+def convert_class(m, target_m, new_class, device, config):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config)
+            new_m = new_class(sub_m, device, config)
             setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config)
+        convert_class(sub_m, target_m, new_class, device, config)
 
 
 def patch_op(m, target_m, new_op_name, new_op):
@@ -82,7 +80,7 @@ def _patch_llama_model(model):
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
     return model
 
 
@@ -98,7 +96,7 @@ def _patch_falcon_model(model):
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
-    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
     return model
 
 
@@ -106,13 +104,12 @@ def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
     1. Use IPEX paged attention
+    2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
     return model
 
 
@@ -121,7 +118,7 @@ def _patch_bert_model(model):
     Patch bert model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
 
 
@@ -130,7 +127,7 @@ def _patch_vit_model(model):
     Patch vit model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
 
 
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 41dd5693df..cb22bc1c22 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -33,6 +33,7 @@
 logger = logging.getLogger(__name__)
 
 _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
+_accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "musa"]
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
@@ -137,6 +138,32 @@ def forward(self, x, y, z):
         return x
 
 
+# Adapted from https://github.com/huggingface/accelerate/blob/v1.2.1/src/accelerate/hooks.py#L183
+def _remove_hooks_for_ipex(module, recurse):
+    if hasattr(module, "_hf_hook"):
+        module._hf_hook.detach_hook(module)
+        delattr(module, "_hf_hook")
+
+    if hasattr(module, "_old_forward"):
+        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
+        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
+        if "GraphModuleImpl" in str(type(module)):
+            module.__class__.forward = module.__class__.forward.__get__(module)
+        else:
+            module.forward = module.__class__.forward.__get__(module)
+        delattr(module, "_old_forward")
+
+    # Remove accelerate added warning hooks from dispatch_model
+    for attr in _accelerate_added_attributes:
+        module.__dict__.pop(attr, None)
+
+    if recurse:
+        for child in module.children():
+            _remove_hooks_for_ipex(child, recurse)
+
+    return module
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
 def _ipex_rms_layer_norm_forward(self, hidden_states):
     return rms_norm(hidden_states, self.weight, self.variance_epsilon)
@@ -531,84 +558,12 @@ def _gpt2_model_forward(
     )
 
 
-# To pass input_lens, adapted from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt2/modeling_gpt2.py#L602
-def _gpt2_block_forward(
-    self,
-    hidden_states: Optional[Tuple[torch.FloatTensor]],
-    layer_past: Optional[Tuple[torch.Tensor]] = None,
-    attention_mask: Optional[torch.FloatTensor] = None,
-    head_mask: Optional[torch.FloatTensor] = None,
-    encoder_hidden_states: Optional[torch.Tensor] = None,
-    encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = False,
-    output_attentions: Optional[bool] = False,
-    **kwargs,
-) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-    residual = hidden_states
-    hidden_states = self.ln_1(hidden_states)
-    attn_outputs = self.attn(
-        hidden_states,
-        layer_past=layer_past,
-        attention_mask=attention_mask,
-        head_mask=head_mask,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        **kwargs,
-    )
-    attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-    outputs = attn_outputs[1:]
-    # residual connection
-    if hasattr(self.attn, "linear_add"):
-        hidden_states = self.attn.linear_add(attn_output, residual)
-    else:
-        hidden_states = attn_output + residual
-
-    if encoder_hidden_states is not None:
-        # add one self-attention block for cross-attention
-        if not hasattr(self, "crossattention"):
-            raise ValueError(
-                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
-                "cross-attention layers by setting `config.add_cross_attention=True`"
-            )
-        residual = hidden_states
-        hidden_states = self.ln_cross_attn(hidden_states)
-        cross_attn_outputs = self.crossattention(
-            hidden_states,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            output_attentions=output_attentions,
-            **kwargs,
-        )
-        attn_output = cross_attn_outputs[0]
-        # residual connection
-        hidden_states = residual + attn_output
-        outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
-
-    residual = hidden_states
-    hidden_states = self.ln_2(hidden_states)
-    feed_forward_hidden_states = self.mlp(hidden_states)
-    # residual connection
-    if hasattr(self.mlp, "linear_add"):
-        hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual)
-    else:
-        hidden_states = residual + feed_forward_hidden_states
-
-    if use_cache:
-        outputs = (hidden_states,) + outputs
-    else:
-        outputs = (hidden_states,) + outputs[1:]
-
-    return outputs  # hidden_states, present, (attentions, cross_attentions)
-
-
 class _IPEXAttention(nn.Module):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
+        self.module_device = device
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
@@ -745,32 +700,38 @@ def forward(
 
 
 class _IPEXLlamaAttention(_IPEXAttention):
-    def __init__(self, module, config) -> None:
-        super().__init__(module, config)
-        concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
-        bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
-        use_bias = bias_list != []
-        self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
-        self.concat_qkv.weight = nn.Parameter(concat_weight)
-        if use_bias:
-            concat_bias = torch.concat(bias_list, 0).contiguous()
-            self.concat_linear.bias = nn.Parameter(concat_bias)
-        self.q_slice = self.q_proj.weight.shape[0]
-        self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
-        self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
-        if self.module_device.type == "cpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mha_linear_add = LinearAdd(module.o_proj)
-
-        elif self.module_device.type == "xpu":
-            if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mha_linear_add = XPULinearAdd(module.o_proj)
+    def __init__(self, module, device, config) -> None:
+        super().__init__(module, device, config)
+        if getattr(config, "quantization_config", None) is None:
+            concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
+            bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
+            use_bias = bias_list != []
+            self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
+            self.concat_qkv.weight = nn.Parameter(concat_weight)
+            if use_bias:
+                concat_bias = torch.concat(bias_list, 0).contiguous()
+                self.concat_linear.bias = nn.Parameter(concat_bias)
+            self.q_slice = self.q_proj.weight.shape[0]
+            self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
+            self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
+            if self.module_device.type == "cpu":
+                if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
["LinearAllreduce"]: + self.mha_linear_add = LinearAdd(module.o_proj) + + elif self.module_device.type == "xpu": + if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mha_linear_add = XPULinearAdd(module.o_proj) def qkv_gemm(self, hidden_states): - qkv_out = self.concat_qkv(hidden_states) - query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim) - key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim) - value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) + if hasattr(self, "concat_qkv"): + qkv_out = self.concat_qkv(hidden_states) + query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim) + key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim) + value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) + else: + query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim) + key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim) + value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim) return query, key, value @@ -781,9 +742,9 @@ def rope(self, query, key, **kwargs): class _IPEXFalconAttention(_IPEXAttention): - def __init__(self, module, config): + def __init__(self, module, device, config): self.num_key_value_heads = config.num_key_value_heads - super().__init__(module, config) + super().__init__(module, device, config) self.q_slice = self.head_dim * config.num_kv_heads self.k_slice = self.q_slice + self.head_dim self.v_slice = self.k_slice + self.head_dim @@ -808,26 +769,30 @@ def rope(self, query, key, **kwargs): class _IPEXGPT2Attention(_IPEXAttention): - def __init__(self, module, config) -> None: + def __init__(self, module, device, config) -> None: self.num_key_value_heads = config.num_key_value_heads - super().__init__(module, config) + super().__init__(module, device, config) _setattr_from_module(self, module) - self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1]) - self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t()) - self.c_attn_linear.bias = self.c_attn.bias - self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) - self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) - self.c_proj_linear.bias = self.c_proj.bias - if self.module_device.type == "cpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = LinearAdd(self.c_proj_linear) - - elif self.module_device.type == "xpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = XPULinearAdd(self.c_proj_linear) + if getattr(config, "quantization_config", None) is None: + self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1]) + self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t()) + self.c_attn_linear.bias = self.c_attn.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) def qkv_gemm(self, hidden_states): - query, key, value = 
+        if hasattr(self, "c_attn_linear"):
+            query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1)
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
         query = query.view(-1, self.num_heads, self.head_dim)
         key = key.view(-1, self.num_heads, self.head_dim)
         value = value.view(-1, self.num_heads, self.head_dim)
@@ -847,21 +812,22 @@ def postprocess_attention_output(self, attn_output):
 
 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
 class _IPEXLlamaMLP(nn.Module):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        self.module_device = next(module.parameters()).device
-        if self.module_device.type == "cpu":
-            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mlp_linear_add = LinearAdd(module.down_proj)
-            self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
-        elif self.module_device.type == "xpu":
-            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-            if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
-                self.mlp_linear_add = XPULinearAdd(module.down_proj)
-            self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
+        self.module_device = device
+        if getattr(config, "quantization_config", None) is None:
+            if self.module_device.type == "cpu":
+                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                    self.mlp_linear_add = LinearAdd(module.down_proj)
+                self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
+            elif self.module_device.type == "xpu":
+                # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+                if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
+                    self.mlp_linear_add = XPULinearAdd(module.down_proj)
+                self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
 
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
         if hasattr(self, "linear_silu_mul"):
@@ -879,21 +845,22 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **
 
 
 class _IPEXFalconMLP(nn.Module):
-    def __init__(self, module, config) -> None:
+    def __init__(self, module, device, config) -> None:
         super().__init__()
         _setattr_from_module(self, module)
         self.config = config
-        # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
-        self.module_device = next(module.parameters()).device
-        if self.module_device.type == "cpu":
-            self.linear_gelu = LinearGelu(module.dense_h_to_4h)
-        elif self.module_device.type == "xpu":
-            self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
-        if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
+        self.module_device = device
+        if getattr(config, "quantization_config", None) is None:
+            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
             if self.module_device.type == "cpu":
-                self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
+                self.linear_gelu = LinearGelu(module.dense_h_to_4h)
             elif self.module_device.type == "xpu":
-                self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
+                self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
+            if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
["LinearAllreduce"]: + if self.module_device.type == "cpu": + self.linear_add_add = LinearAddAdd(module.dense_4h_to_h) + elif self.module_device.type == "xpu": + self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h) def forward( self, @@ -902,38 +869,44 @@ def forward( residual: torch.Tensor = None, **kwargs, ): - mlp_hidden_states = self.linear_gelu(hidden_states) + if hasattr(self, "linear_gelu"): + mlp_hidden_states = self.linear_gelu(hidden_states) + else: + mlp_hidden_states = self.act(self.dense_h_to_4h(hidden_states)) + if hasattr(self, "linear_add_add"): output = self.linear_add_add(mlp_hidden_states, attention_output, residual) else: - mlp_output = self.mlp.dense_4h_to_h(mlp_hidden_states) + mlp_output = self.dense_4h_to_h(mlp_hidden_states) output = mlp_output + attention_output + residual return output class _IPEXGPT2MLP(nn.Module): - def __init__(self, module, config) -> None: + def __init__(self, module, device, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - self.module_device = next(module.parameters()).device - self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1]) - self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t()) - self.c_fc_linear.bias = self.c_fc.bias - self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) - self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) - self.c_proj_linear.bias = self.c_proj.bias - if self.module_device.type == "cpu": - self.linear_new_gelu = LinearNewGelu(self.c_fc_linear) - - if self.module_device.type == "cpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = LinearAdd(self.c_proj_linear) - - elif self.module_device.type == "xpu": - if self.c_proj_linear not in ["LinearAllreduce"]: - self.linear_add = XPULinearAdd(self.c_proj_linear) + self.module_device = device + + if getattr(config, "quantization_config", None) is None: + self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1]) + self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t()) + self.c_fc_linear.bias = self.c_fc.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + self.linear_new_gelu = LinearNewGelu(self.c_fc_linear) + + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: if hasattr(self, "linear_new_gelu"): @@ -948,11 +921,13 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayer(nn.Module): - def __init__(self, module, config): + def __init__(self, module, device, config): super().__init__() _setattr_from_module(self, module) - self.self_attn = _IPEXLlamaAttention(module.self_attn, config) - self.mlp = _IPEXLlamaMLP(module.mlp, config) + self.self_attn = _IPEXLlamaAttention(module.self_attn, device, config) + self.mlp = _IPEXLlamaMLP(module.mlp, device, config) + if getattr(config, 
"quantization_config", None): + _remove_hooks_for_ipex(self, True) def forward(self, hidden_states: torch.Tensor, **kwargs): # Please see the original model's forward to check the parameter @@ -981,11 +956,13 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): class _IPEXFalconDecoderLayer(nn.Module): - def __init__(self, module, config): + def __init__(self, module, device, config): super().__init__() _setattr_from_module(self, module) - self.self_attention = _IPEXFalconAttention(module.self_attention, config) - self.mlp = _IPEXFalconMLP(module.mlp, config) + self.self_attention = _IPEXFalconAttention(module.self_attention, device, config) + self.mlp = _IPEXFalconMLP(module.mlp, device, config) + if getattr(config, "quantization_config", None): + _remove_hooks_for_ipex(self, True) def forward(self, hidden_states: torch.Tensor, **kwargs): # Please see the original model's forward to check the parameter @@ -1006,17 +983,102 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return outputs +class _IPEXGPT2Block(nn.Module): + def __init__(self, module, device, config): + super().__init__() + _setattr_from_module(self, module) + self.attn = _IPEXGPT2Attention(module.attn, device, config) + self.mlp = _IPEXGPT2MLP(module.mlp, device, config) + if getattr(config, "quantization_config", None): + _remove_hooks_for_ipex(self, True) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + if hasattr(self.attn, "linear_add"): + hidden_states = self.attn.linear_add(attn_output, residual) + else: + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + if hasattr(self.mlp, "linear_add"): + hidden_states = 
+        else:
+            hidden_states = residual + feed_forward_hidden_states
+
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
+        return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
 class _IPEXIntermediate(nn.Module):
-    def __init__(self, module, config):
+    def __init__(self, module, device, config):
         super().__init__()
         _setattr_from_module(self, module)
-        self.module_device = next(module.parameters()).device
-        if self.module_device.type == "cpu":
-            self.linear_gelu = LinearGelu(module.dense)
-        elif self.module_device.type == "xpu":
-            self.linear_gelu = XPULinearGelu(module.dense)
+        self.module_device = device
+        if getattr(config, "quantization_config", None) is None:
+            if self.module_device.type == "cpu":
+                self.linear_gelu = LinearGelu(module.dense)
+            elif self.module_device.type == "xpu":
+                self.linear_gelu = XPULinearGelu(module.dense)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear_gelu(hidden_states)
+        if hasattr(self, "linear_gelu"):
+            hidden_states = self.linear_gelu(hidden_states)
+        else:
+            hidden_states = self.dense(hidden_states)
+            hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 3263e31db3..81172090be 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -189,6 +189,7 @@ def maybe_apply_torch_compile(self):
             self.model.device.type != "cpu"
             or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
            or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
+            or getattr(self.config, "quantization_config", None)
         ):
             return
         if self.use_cache and not self._supports_static_cache:
@@ -297,6 +298,7 @@ def forward(
     def _prepare_generation_config(
         self, generation_config: Optional[GenerationConfig], **kwargs: Dict
     ) -> Tuple[GenerationConfig, Dict]:
+        kwargs["use_cache"] = self.use_cache
         generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         generation_method = generation_config.get_generation_mode().value
         if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache:
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index 419e1bb42a..ac311f40f6 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -29,10 +29,12 @@
     AutoModelForSeq2SeqLM,
     AutoModelForQuestionAnswering,
     AutoTokenizer,
+    BitsAndBytesConfig,
     GenerationConfig,
     PretrainedConfig,
     pipeline,
     set_seed,
+    is_bitsandbytes_available,
 )
 from optimum.intel import (
     IPEXModel,
@@ -433,6 +435,51 @@ def test_patched_model(self, model_arch):
         )
         self.assertTrue(torch.allclose(ipex_outputs.logits[0], exported_outputs.logits[0], atol=1e-7))
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @unittest.skipIf(not is_bitsandbytes_available(), reason="Test requires bitsandbytes")
+    def test_bnb(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        # Test that the model forward does not need the cache.
+        ipex_model = IPEXModelForCausalLM.from_pretrained(
+            model_id, torch_dtype=dtype, device_map=DEVICE, quantization_config=quantization_config
+        )
+        self.assertIsInstance(ipex_model.config, PretrainedConfig)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokens = tokenizer(
+            "This is a sample",
+            return_tensors="pt",
+            return_token_type_ids=False if model_arch in ("llama", "llama2") else None,
+        ).to(DEVICE)
+        inputs = ipex_model.prepare_inputs_for_generation(**tokens)
+        outputs = ipex_model(**inputs)
+
+        self.assertIsInstance(outputs.logits, torch.Tensor)
+
+        transformers_model = AutoModelForCausalLM.from_pretrained(
+            model_id, torch_dtype=dtype, device_map=DEVICE, quantization_config=quantization_config
+        )
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**tokens)
+
+        # Test re-load model
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ipex_model.save_pretrained(tmpdirname)
+            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE)
+            loaded_model_outputs = loaded_model(**inputs)
+
+        # Test init method
+        init_model = self.IPEX_MODEL_CLASS(transformers_model)
+        init_model_outputs = init_model(**inputs)
+
+        # Compare tensor outputs
+        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
+        # To avoid floating-point error
+        self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
+        self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))
+
 
 class IPEXModelForAudioClassificationTest(unittest.TestCase):
     IPEX_MODEL_CLASS = IPEXModelForAudioClassification