diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 0c86c796d22..e9b0397614f 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -2202,7 +2202,7 @@ def set_scheduler_args(self, scheduler):
             self.megatron_lm_default_args["min_lr"] = self.min_lr
 
     def set_tensorboard_logging_options(self):
-        from megatron.arguments import _add_logging_args
+        from megatron.training.arguments import _add_logging_args
 
         parser = argparse.ArgumentParser()
         parser = _add_logging_args(parser)
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 453042b27a0..f408e60d9d1 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -200,8 +200,8 @@ def is_megatron_lm_available():
         if importlib.util.find_spec("megatron") is not None:
             try:
                 megatron_version = parse(importlib.metadata.version("megatron-core"))
-                if compare_versions(megatron_version, "==", "0.5.0"):
-                    return importlib.util.find_spec(".data", "megatron")
+                if compare_versions(megatron_version, ">=", "0.8.0"):
+                    return importlib.util.find_spec(".training", "megatron")
             except Exception as e:
                 warnings.warn(f"Parse Megatron version failed. Exception:{e}")
     return False
diff --git a/src/accelerate/utils/megatron_lm.py b/src/accelerate/utils/megatron_lm.py
index 552cb6d35f2..38e2a9d2788 100644
--- a/src/accelerate/utils/megatron_lm.py
+++ b/src/accelerate/utils/megatron_lm.py
@@ -30,31 +30,39 @@
 
 
 if is_megatron_lm_available():
-    from megatron import (
+    from megatron.core import mpu, tensor_parallel
+    from megatron.core.distributed import DistributedDataParallel as LocalDDP
+    from megatron.core.distributed import finalize_model_grads
+    from megatron.core.enums import ModelType
+    from megatron.core.num_microbatches_calculator import get_num_microbatches
+    from megatron.core.optimizer import get_megatron_optimizer
+    from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank
+    from megatron.core.pipeline_parallel import get_forward_backward_func
+    from megatron.core.utils import get_model_config
+    from megatron.inference.text_generation.communication import broadcast_int_list, broadcast_tensor
+    from megatron.inference.text_generation.generation import (
+        beam_search_and_return_on_first_stage,
+        generate_tokens_probs_and_return_on_first_stage,
+    )
+    from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets
+    from megatron.legacy.model import BertModel, Float16Module, GPTModel, T5Model
+    from megatron.legacy.model.classification import Classification
+    from megatron.training import (
         get_args,
-        get_num_microbatches,
         get_tensorboard_writer,
         get_tokenizer,
         print_rank_last,
     )
-    from megatron.arguments import (
+    from megatron.training.arguments import (
         _add_data_args,
         _add_validation_args,
         core_transformer_config_from_args,
         parse_args,
         validate_args,
     )
-    from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
-    from megatron.core import mpu, tensor_parallel
-    from megatron.core.distributed import DistributedDataParallel as LocalDDP
-    from megatron.core.distributed import finalize_model_grads
-    from megatron.core.enums import ModelType
-    from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank
-    from megatron.core.pipeline_parallel import get_forward_backward_func
-    from megatron.core.utils import get_model_config
-    from megatron.data.dataset_utils import build_train_valid_test_datasets
-    from megatron.global_vars import set_global_variables
-    from megatron.initialize import (
+    from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
+    from megatron.training.global_vars import set_global_variables
+    from megatron.training.initialize import (
         _compile_dependencies,
         _init_autoresume,
         _initialize_distributed,
@@ -62,16 +70,8 @@
         set_jit_fusion_options,
         write_args_to_tensorboard,
     )
-    from megatron.model import BertModel, Float16Module, GPTModel, T5Model
-    from megatron.model.classification import Classification
-    from megatron.optimizer import get_megatron_optimizer
-    from megatron.text_generation.communication import broadcast_int_list, broadcast_tensor
-    from megatron.text_generation.generation import (
-        beam_search_and_return_on_first_stage,
-        generate_tokens_probs_and_return_on_first_stage,
-    )
-    from megatron.tokenizer.tokenizer import _vocab_size_with_padding
-    from megatron.training import (
+    from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
+    from megatron.training.training import (
         build_train_valid_test_data_iterators,
         get_optimizer_param_scheduler,
         num_floating_point_operations,
@@ -79,7 +79,7 @@
         train_step,
         training_log,
     )
-    from megatron.utils import (
+    from megatron.training.utils import (
         average_losses_across_data_parallel_group,
         calc_params_l2_norm,
         get_ltor_masks_and_position_ids,