Skip to content

Commit

Permalink
Fix sentence transformers ipex support (#1128)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Jan 24, 2025
1 parent b49fcbb commit 833ab0d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
9 changes: 8 additions & 1 deletion optimum/intel/ipex/modeling_sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@
from transformers import MT5Config, T5Config
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from ..utils.import_utils import _sentence_transformers_version, is_sentence_transformers_version
from .modeling_base import IPEXModel


class IPEXTransformer(Transformer):
def __init__(self, *args, **kwargs):
if is_sentence_transformers_version("<", "3.4"):
raise ImportError(
f"Backend: ipex requires sentence-transformers>=3.4 but found {_sentence_transformers_version}. "
"You can install it with pip: `pip install --upgrade sentence-transformers`"
)

super().__init__(*args, **kwargs)
self.backend = "ipex"

def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
def _load_model(self, model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args) -> None:
self._load_ipex_model(model_name_or_path, config, cache_dir, **model_args)

def _load_ipex_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
Expand Down
21 changes: 18 additions & 3 deletions optimum/intel/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,10 @@


_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None
_sentence_transformers_available = "N/A"

_sentence_transformers_version = "N/A"
if _sentence_transformers_available:
try:
_sentence_transformers_available = importlib_metadata.version("sentence_transformers")
_sentence_transformers_version = importlib_metadata.version("sentence_transformers")
except importlib_metadata.PackageNotFoundError:
_sentence_transformers_available = False

Expand Down Expand Up @@ -449,6 +448,15 @@ def is_datasets_version(operation: str, version: str):
return compare_versions(parse(_datasets_version), operation, version)


def is_sentence_transformers_version(operation: str, version: str):
"""
Compare the current sentence-transformers version to a given reference with an operation.
"""
if not _sentence_transformers_available:
return False
return compare_versions(parse(_sentence_transformers_version), operation, version)


DIFFUSERS_IMPORT_ERROR = """
{0} requires the diffusers library but it was not found in your environment. You can install it with pip:
`pip install diffusers`. Please note that you may need to restart your runtime after installation.
Expand Down Expand Up @@ -484,6 +492,12 @@ def is_datasets_version(operation: str, version: str):
`pip install accelerate`. Please note that you may need to restart your runtime after installation.
"""

SENTENCE_TRANSFORMERS_IMPORT_ERROR = """
{0} requires the sentence-transformers library but it was not found in your environment. You can install it with pip:
`pip install sentence-transformers`. Please note that you may need to restart your runtime after installation.
"""


BACKENDS_MAPPING = OrderedDict(
[
("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
Expand All @@ -492,6 +506,7 @@ def is_datasets_version(operation: str, version: str):
("openvino", (is_openvino_available, OPENVINO_IMPORT_ERROR)),
("neural_compressor", (is_neural_compressor_available, NEURAL_COMPRESSOR_IMPORT_ERROR)),
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
("sentence_transformers", (is_sentence_transformers_available, SENTENCE_TRANSFORMERS_IMPORT_ERROR)),
]
)

Expand Down

0 comments on commit 833ab0d

Please sign in to comment.