diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index e9b0397614f..39e048a6039 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -1297,7 +1297,7 @@ def set_deepspeed_weakref(self): if ds_config.get("train_batch_size", None) == "auto": del ds_config["train_batch_size"] - if compare_versions("transformers", "<", "4.33"): + if compare_versions("transformers", "<", "4.46"): from transformers.deepspeed import HfDeepSpeedConfig, unset_hf_deepspeed_config else: from transformers.integrations import HfDeepSpeedConfig, unset_hf_deepspeed_config diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 368a44675ff..044f67f5727 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -271,7 +271,7 @@ def test_init_zero3(self): with mockenv_context(**self.dist_env): accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin) # noqa: F841 - from transformers.deepspeed import is_deepspeed_zero3_enabled + from transformers.integrations import is_deepspeed_zero3_enabled assert is_deepspeed_zero3_enabled()