Enable quant model support #1074

Open · wants to merge 30 commits into base: main

Commits (30)
3888824  enable IPEXModelForSeq2SeqLM (jiqing-feng, Dec 9, 2024)
f9fa807  set static cache (jiqing-feng, Dec 9, 2024)
202df43  add tests for IPEXModelForSeq2SeqLM (jiqing-feng, Dec 9, 2024)
4488073  add docs (jiqing-feng, Dec 9, 2024)
16fecf8  fix readme (jiqing-feng, Dec 9, 2024)
de501f4  Merge branch 'main' into text2text (jiqing-feng, Dec 10, 2024)
4225bf0  refactor compile (jiqing-feng, Dec 11, 2024)
2ac7ecf  fix check (jiqing-feng, Dec 11, 2024)
24b988c  fix ruff check (jiqing-feng, Dec 11, 2024)
5c4f9a1  Merge branch 'huggingface:main' into text2text (jiqing-feng, Dec 16, 2024)
46b93a4  enable quantized model (jiqing-feng, Dec 16, 2024)
82d39ce  add bnb test (jiqing-feng, Dec 16, 2024)
7dc08da  add bnb tests in yaml (jiqing-feng, Dec 16, 2024)
30027ff  fix tests (jiqing-feng, Dec 16, 2024)
314db04  disable bnb tests (jiqing-feng, Dec 16, 2024)
87656ca  fix gpt2 (jiqing-feng, Dec 16, 2024)
9a7e931  Merge branch 'main' into quant (jiqing-feng, Dec 18, 2024)
b0cec9c  set actual device (jiqing-feng, Dec 18, 2024)
94cf35d  assign device when convert class (jiqing-feng, Dec 18, 2024)
9af46d1  fix class init (jiqing-feng, Dec 18, 2024)
18b2a6a  fix ipex attn init (jiqing-feng, Dec 18, 2024)
9f6db33  rm set device on config (jiqing-feng, Dec 18, 2024)
6d8a969  fix format (jiqing-feng, Dec 18, 2024)
dd811f9  fix mlp class init (jiqing-feng, Dec 18, 2024)
d91eefb  Merge branch 'huggingface:main' into quant (jiqing-feng, Jan 14, 2025)
f094cad  Merge branch 'main' into quant (jiqing-feng, Jan 21, 2025)
dab4a78  add use_cache param when init generation config (jiqing-feng, Jan 21, 2025)
6bf3b8b  fix gpt2 quant model (jiqing-feng, Jan 21, 2025)
356d51d  fix falcon linear fusion (jiqing-feng, Jan 22, 2025)
d1eee87  fix falcon (jiqing-feng, Jan 22, 2025)
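
Taken together, the commits enable loading checkpoints that are already quantized through the IPEX model classes (the bnb commits cover bitsandbytes). A rough usage sketch, assuming `IPEXModelForCausalLM.from_pretrained` forwards `quantization_config` to transformers the way other optimum-intel entry points do; the model id and generation arguments are illustrative, not taken from this PR:

```python
from transformers import AutoTokenizer, BitsAndBytesConfig
from optimum.intel import IPEXModelForCausalLM

# Illustrative only: load a 4-bit bitsandbytes-quantized model through IPEX.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = IPEXModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder model id
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```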
optimum/exporters/ipex/model_patcher.py (25 changes: 11 additions & 14 deletions)
@@ -14,7 +14,7 @@
 
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -27,13 +27,11 @@
 
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
-    _IPEXGPT2MLP,
     _falcon_model_forward,
-    _gpt2_block_forward,
     _gpt2_model_forward,
     _ipex_rms_layer_norm_forward,
     _IPEXFalconDecoderLayer,
-    _IPEXGPT2Attention,
+    _IPEXGPT2Block,
     _IPEXIntermediate,
     _IPEXLlamaDecoderLayer,
     _llama_model_forward,
@@ -59,12 +57,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
         convert_functions(sub_m, target_m, new_function_name, new_function)
 
 
-def convert_class(m, target_m, new_class, config=None):
+def convert_class(m, target_m, new_class, device, config):
     for name, sub_m in m.named_children():
         if isinstance(sub_m, target_m):
-            new_m = new_class(sub_m, config)
+            new_m = new_class(sub_m, device, config)
             setattr(m, name, new_m)
-        convert_class(sub_m, target_m, new_class, config)
+        convert_class(sub_m, target_m, new_class, device, config)
 
 
 def patch_op(m, target_m, new_op_name, new_op):
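
The `convert_class` change is the core of the device plumbing in this PR: replacement classes now receive an explicit `device` instead of an optional `config` (see the "assign device when convert class" commit). A minimal runnable sketch of how this recursive swap behaves; `ToyBlock` and `WrappedBlock` are hypothetical stand-ins, not classes from this repo:

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical module standing in for a decoder layer to be patched."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class WrappedBlock(nn.Module):
    """Hypothetical replacement mirroring the new_class(sub_m, device, config) call."""

    def __init__(self, module, device, config):
        super().__init__()
        self.module = module.to(device)  # placement comes from the explicit device
        self.config = config

    def forward(self, x):
        return self.module(x)


# Same logic as the patched convert_class: walk children, swap matches, recurse.
def convert_class(m, target_m, new_class, device, config):
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, device, config)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, device, config)


model = nn.Sequential(ToyBlock(), nn.ReLU(), ToyBlock())
convert_class(model, ToyBlock, WrappedBlock, torch.device("cpu"), config=None)
print(model)  # children 0 and 2 are now WrappedBlock instances
```

Passing `model.device` at the call sites matters for quantized checkpoints, where the correct placement cannot be inferred from the config alone.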
@@ -82,7 +80,7 @@ def _patch_llama_model(model):
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
     convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
-    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
+    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
     return model


@@ -98,21 +96,20 @@ def _patch_falcon_model(model):
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, FalconModel, "forward", _falcon_model_forward)
     replace_customized_linear_with_linear(model)
-    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
+    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
     return model
 
 
 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
     1. Use IPEX paged attention
+    2. Linear fusion with (Linear + Add)
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
     convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
-    convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
-    convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
-    convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config)
+    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
     return model
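
For background, "Linear fusion with (Linear + Add)" folds the residual addition into the linear kernel, so one op computes `x @ W.T + b + residual`. A plain-PyTorch sketch of the idea; this `LinearAdd` is a hypothetical illustration, not the IPEX fused implementation:

```python
import torch
from torch import nn


class LinearAdd(nn.Module):
    """Hypothetical fused (Linear + residual Add), written as two plain ops."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.weight = linear.weight
        self.bias = linear.bias

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        # A real fused kernel performs this in a single pass over memory.
        return nn.functional.linear(x, self.weight, self.bias) + residual


proj = nn.Linear(8, 8)
fused = LinearAdd(proj)
x, residual = torch.randn(2, 8), torch.randn(2, 8)
torch.testing.assert_close(fused(x, residual), proj(x) + residual)
```

Consolidating the separate `GPT2Attention`/`GPT2MLP` wrappers into a single `_IPEXGPT2Block` keeps that fusion decision in one place per block.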


@@ -121,7 +118,7 @@ def _patch_bert_model(model):
     Patch bert model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, BertIntermediate, _IPEXIntermediate)
+    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
     return model


@@ -130,7 +127,7 @@ def _patch_vit_model(model):
     Patch vit model:
     1. Linear fusion with Linear + Gelu
     """
-    convert_class(model, ViTIntermediate, _IPEXIntermediate)
+    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
     return model
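
Note that `_IPEXIntermediate` was previously constructed from the module alone; after this change every wrapper class is built with the same `(module, device, config)` triple. A short sketch of that uniform constructor shape; `_ToyIntermediate` is a hypothetical illustration of the convention, not the repo's class:

```python
from torch import nn


class _ToyIntermediate(nn.Module):
    """Hypothetical wrapper with the uniform (module, device, config) constructor,
    conceptually covering the Linear + Gelu fusion named in the docstrings."""

    def __init__(self, module, device, config):
        super().__init__()
        self.dense = module.dense  # BertIntermediate and ViTIntermediate both expose .dense
        self.device = device  # accepted for uniformity even if this wrapper never uses it
        self.config = config

    def forward(self, hidden_states):
        # A real fused kernel would run the linear and the gelu in one pass.
        return nn.functional.gelu(self.dense(hidden_states))
```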

