Enable IPEXModel with deepspeed (#732)
* support deepspeed

* enable IPEXModel with deepspeed

* make style

* add env setup

* revert changes for openvino

* rm json

* fix style

* rm import deepspeed

* fix comments

* fix deepspeed layer check

* fix layer check

* update deepspeed readme

* update examples

* rebase

* add jit trace in init method

* Update optimum/intel/ipex/modeling_base.py

Co-authored-by: Ella Charlaix <[email protected]>

* add tests for init method

* assign pipeline framework to pt

* delete deepspeed example

* fix config check

* fix tests

* fix floating point error

---------

Co-authored-by: Ella Charlaix <[email protected]>
jiqing-feng and echarlaix authored Jun 14, 2024
1 parent 0c2217d commit 458f271
Showing 4 changed files with 65 additions and 20 deletions.
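For context on the "deepspeed layer check" mentioned in the commit message: instead of threading an explicit `distributed` flag through the patched modules, the diffs below detect DeepSpeed tensor-parallel layers by class name and only apply the fused IPEX ops to plain linear layers. A minimal sketch of that check (an illustrative simplification, not the exact repository code):

```python
# Illustrative sketch only: DeepSpeed tensor parallelism replaces nn.Linear
# projections with LinearAllreduce, which cannot be folded into IPEX's fused
# LinearAdd op, so fusion is applied only when the projection is a plain Linear.
import torch.nn as nn


def can_fuse_linear_add(proj: nn.Module) -> bool:
    # Mirrors the class-name check used in modeling_utils.py below.
    return proj.__class__.__name__ not in ["LinearAllreduce"]


fused = can_fuse_linear_add(nn.Linear(16, 16))  # True for a plain Linear
```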
6 changes: 3 additions & 3 deletions optimum/exporters/ipex/model_patcher.py
@@ -49,12 +49,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m, target_m, new_class, config, distributed=False):
def convert_class(m, target_m, new_class, config):
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
new_m = new_class(sub_m, config, distributed)
new_m = new_class(sub_m, config)
setattr(m, name, new_m)
convert_class(sub_m, target_m, new_class, config, distributed)
convert_class(sub_m, target_m, new_class, config)


def patch_op(m, target_m, new_op_name, new_op):
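For reference, a hedged usage sketch of the simplified `convert_class` helper above; the exact patching entry point in the exporter may differ, and `_IPEXLlamaDecoderLayer` is the class defined in the next file:

```python
# Sketch under assumptions: convert_class and _IPEXLlamaDecoderLayer come from
# the modules shown in this commit; the real call site may look different.
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from optimum.exporters.ipex.model_patcher import convert_class
from optimum.exporters.ipex.modeling_utils import _IPEXLlamaDecoderLayer


def patch_llama(model):
    # Recursively swap every LlamaDecoderLayer for its IPEX counterpart,
    # passing only the config now that the `distributed` flag is gone.
    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
    return model
```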
20 changes: 9 additions & 11 deletions optimum/exporters/ipex/modeling_utils.py
@@ -138,19 +138,18 @@ def _llama_model_forward(

# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
class _IPEXLlamaAttention(nn.Module):
def __init__(self, module, config, distributed=False) -> None:
def __init__(self, module, config) -> None:
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(
f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding"
)
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.distributed = distributed
from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding

if not self.distributed:
self.mha_linear_add = LinearAdd(self.o_proj)
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)
del self.__dict__["_modules"]["o_proj"]
self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(
text_max_length=module.config.max_position_embeddings
@@ -296,18 +295,18 @@ def forward(

# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
class _IPEXLlamaMLP(nn.Module):
def __init__(self, module, config, distributed=False) -> None:
def __init__(self, module, config) -> None:
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(
f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports Linear2SiluMul, LinearAdd"
)
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.distributed = distributed
from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd

if not self.distributed:
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
del self.__dict__["_modules"]["down_proj"]
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
@@ -336,12 +335,11 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **

# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayer(nn.Module):
def __init__(self, module, config, distributed=False):
def __init__(self, module, config):
super().__init__()
_setattr_from_module(self, module)
self.distributed = distributed
self.self_attn = _IPEXLlamaAttention(module.self_attn, config, distributed)
self.mlp = _IPEXLlamaMLP(module.mlp, config, distributed)
self.self_attn = _IPEXLlamaAttention(module.self_attn, config)
self.mlp = _IPEXLlamaMLP(module.mlp, config)

def forward(
self,
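The effect of the new `o_proj`/`down_proj` checks above is that fusion becomes optional per layer; a simplified sketch (assumed, not the verbatim forward pass) of the resulting fallback pattern:

```python
# Simplified sketch (assumed, not verbatim): when o_proj was a plain Linear it
# has been replaced by the fused LinearAdd module, otherwise the DeepSpeed
# LinearAllreduce projection is kept and the residual is added separately.
def _attention_output(self, attn_output, residual):
    if hasattr(self, "mha_linear_add"):
        # Fused linear projection + residual add (IPEX LinearAdd).
        return self.mha_linear_add(attn_output, residual)
    # Unfused DeepSpeed path: original projection followed by the residual add.
    return self.o_proj(attn_output) + residual
```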
21 changes: 15 additions & 6 deletions optimum/intel/ipex/modeling_base.py
@@ -136,13 +136,22 @@ def __init__(
warmup: bool = True,
**kwargs,
):
OptimizedModel.__init__(self, model=model, config=config)
if is_torch_xpu_available(check_device=True):
self._device = torch.device("xpu:0")
elif torch.cuda.is_available():
self._device = torch.device("cuda:0")
else:
self._device = torch.device("cpu")

# CPU only support jit model for now.
if not isinstance(model, torch.jit.RecursiveScriptModule):
config = model.config if config is None else config
use_cache = getattr(model.config, "use_cache", False)
model = ipex_jit_trace(model, self.export_feature, use_cache)
config.torchscript = True

OptimizedModel.__init__(self, model=model, config=config)

self.model.to(self._device)
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
self.model_save_dir = model_save_dir
@@ -438,8 +447,8 @@ def __init__(
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
GenerationMixin.__init__(self)

model_type = config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
model_type = self.config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config)
self.use_cache = "past_key_values" in self.input_names

if use_cache ^ self.use_cache:
@@ -449,10 +458,10 @@
f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. "
"To export your model, simply set `export=True`."
)
config.is_decoder = True
config.is_encoder_decoder = False
self.config.is_decoder = True
self.config.is_encoder_decoder = False

self.generation_config = GenerationConfig.from_model_config(config)
self.generation_config = GenerationConfig.from_model_config(self.config)
try:
self.model_cls = get_class_from_dynamic_module(
self.config.auto_map["AutoModelForCausalLM"], model_save_dir
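The modeling_base.py change above means an eager transformers model handed to an IPEXModel constructor is now jit-traced inside `__init__`, as the new tests below exercise; a usage sketch (the model id is a placeholder):

```python
# Usage sketch; "gpt2" is only a placeholder checkpoint.
import torch
from transformers import AutoModelForCausalLM
from optimum.intel import IPEXModelForCausalLM

pt_model = AutoModelForCausalLM.from_pretrained("gpt2")

# New init path: wrapping an eager model triggers ipex_jit_trace in __init__.
ipex_model = IPEXModelForCausalLM(pt_model)
assert isinstance(ipex_model.model, torch.jit.RecursiveScriptModule)

# The existing export path keeps working as before.
ipex_model = IPEXModelForCausalLM.from_pretrained("gpt2", export=True)
```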
38 changes: 38 additions & 0 deletions tests/ipex/test_modeling.py
@@ -89,15 +89,22 @@ def test_compare_to_transformers(self, model_arch):
transformers_outputs = transformers_model(**tokens)
outputs = ipex_model(**tokens)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
loaded_model_outputs = loaded_model(**tokens)
# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model)
init_model_outputs = init_model(**tokens)
self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule)

# Compare tensor outputs
for output_name in {"logits", "last_hidden_state"}:
if output_name in transformers_outputs:
self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4))
self.assertTrue(torch.equal(outputs[output_name], loaded_model_outputs[output_name]))
self.assertTrue(torch.equal(outputs[output_name], init_model_outputs[output_name]))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
@@ -147,18 +154,26 @@ def test_compare_to_transformers(self, model_arch):
transformers_outputs = transformers_model(**tokens)
outputs = ipex_model(**tokens)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
loaded_model_outputs = loaded_model(**tokens)

# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model)
init_model_outputs = init_model(**tokens)
self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule)

self.assertIn("start_logits", outputs)
self.assertIn("end_logits", outputs)
# Compare tensor outputs
self.assertTrue(torch.allclose(outputs.start_logits, transformers_outputs.start_logits, atol=1e-4))
self.assertTrue(torch.allclose(outputs.end_logits, transformers_outputs.end_logits, atol=1e-4))
self.assertTrue(torch.equal(outputs.start_logits, loaded_model_outputs.start_logits))
self.assertTrue(torch.equal(outputs.end_logits, loaded_model_outputs.end_logits))
self.assertTrue(torch.equal(outputs.start_logits, init_model_outputs.start_logits))
self.assertTrue(torch.equal(outputs.end_logits, init_model_outputs.end_logits))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
@@ -220,13 +235,21 @@ def test_compare_to_transformers(self, model_arch):
with torch.no_grad():
transformers_outputs = transformers_model(**tokens)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
loaded_model_outputs = loaded_model(**inputs)

# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model)
init_model_outputs = init_model(**inputs)
self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule)

# Compare tensor outputs
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits))
self.assertTrue(torch.equal(outputs.logits, init_model_outputs.logits))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
@@ -354,13 +377,21 @@ def test_compare_to_transformers(self, model_arch):
transformers_outputs = transformers_model(**inputs)
outputs = ipex_model(**inputs)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
loaded_model_outputs = loaded_model(**inputs)

# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model)
init_model_outputs = init_model(**inputs)
self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule)

# Compare tensor outputs
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-3))
self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits))
self.assertTrue(torch.equal(outputs.logits, init_model_outputs.logits))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
@@ -400,15 +431,22 @@ def test_compare_to_transformers(self, model_arch):
transformers_outputs = transformers_model(**inputs)
outputs = ipex_model(**inputs)

# Test re-load model
with tempfile.TemporaryDirectory() as tmpdirname:
ipex_model.save_pretrained(tmpdirname)
loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname)
loaded_model_outputs = loaded_model(**inputs)

# Test init method
init_model = self.IPEX_MODEL_CLASS(transformers_model)
init_model_outputs = init_model(**inputs)
self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule)

self.assertIn("logits", outputs)
# Compare tensor outputs
self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4))
self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits))
self.assertTrue(torch.allclose(init_model_outputs.logits, transformers_outputs.logits, atol=1e-4))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline(self, model_arch):
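Related to the "assign pipeline framework to pt" commit above, a hedged sketch of running the traced model through a transformers pipeline; since the wrapped model is a TorchScript module rather than a plain PreTrainedModel, the framework is pinned to "pt":

```python
# Hedged sketch; the model id is a placeholder and the task is illustrative.
from transformers import AutoTokenizer, pipeline
from optimum.intel import IPEXModelForCausalLM

model_id = "gpt2"
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, framework="pt")
print(pipe("DeepSpeed with IPEX", max_new_tokens=16)[0]["generated_text"])
```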
