Skip to content

Commit

Permalink
Fix OVTrainer for transformers >= v4.29.0 (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed May 31, 2023
1 parent 50324e6 commit 7019728
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
14 changes: 11 additions & 3 deletions optimum/intel/openvino/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from optimum.exporters.onnx import OnnxConfig

from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_transformers_version
from .configuration import OVConfig
from .quantization import OVDataLoader
from .training_args import OVTrainingArguments
Expand Down Expand Up @@ -285,9 +286,16 @@ def _inner_training_loop(
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()

if self.args.local_rank != -1:
if self.compression_controller is not None:
self.compression_controller.distributed()
if is_transformers_version("<", "4.29.0"):
is_distributed = self.args.local_rank != -1
else:
from accelerate.utils import DistributedType

is_distributed = self.args.distributed_state.distributed_type != DistributedType.NO

if self.compression_controller is not None and is_distributed:
self.compression_controller.distributed()

model = self._wrap_model(self.model_wrapped)

if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
],
"openvino": ["openvino>=2023.0.0.dev20230217", "onnx", "onnxruntime"],
"nncf": ["nncf>=2.4.0", "openvino-dev>=2023.0.0.dev20230217"],
"ipex": ["intel-extension-for-pytorch"],
"ipex": ["intel-extension-for-pytorch", "onnx"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
Expand Down

0 comments on commit 7019728

Please sign in to comment.