Commit 2f39575

update Megatron-LM plugin code to version 0.8.0 or higher. (#3174)
* Adapted the Megatron-LM plugin code to target version 0.8.0 or higher (the module moves are summarized after the file list below).

* Updated the megatron import in set_tensorboard_logging_options.
eljandoubi authored Oct 24, 2024
1 parent 1ace241 commit 2f39575
Showing 3 changed files with 28 additions and 28 deletions.
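
At a glance, the import-path moves this commit tracks are summarized in the sketch below. The mapping is paraphrased from the diffs that follow and is not an exhaustive map of the Megatron-LM 0.8.0 reorganization; the constant name is illustrative only and nothing in the commit defines it.

# Illustrative only: where the Megatron-LM modules used by the plugin moved
# in >= 0.8.0, paraphrased from the diffs below (not exhaustive).
# Note: get_num_microbatches also moved from the megatron top-level package
# to megatron.core.num_microbatches_calculator.
MEGATRON_080_IMPORT_MOVES = {
    "megatron.arguments": "megatron.training.arguments",
    "megatron.checkpointing": "megatron.training.checkpointing",
    "megatron.global_vars": "megatron.training.global_vars",
    "megatron.initialize": "megatron.training.initialize",
    "megatron.tokenizer.tokenizer": "megatron.training.tokenizer.tokenizer",
    # Training-loop helpers only; get_args and friends now come from megatron.training.
    "megatron.training": "megatron.training.training",
    "megatron.utils": "megatron.training.utils",
    "megatron.model": "megatron.legacy.model",
    "megatron.data.dataset_utils": "megatron.legacy.data.dataset_utils",
    "megatron.text_generation": "megatron.inference.text_generation",
    "megatron.optimizer": "megatron.core.optimizer",
}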
2 changes: 1 addition & 1 deletion src/accelerate/utils/dataclasses.py
@@ -2202,7 +2202,7 @@ def set_scheduler_args(self, scheduler):
         self.megatron_lm_default_args["min_lr"] = self.min_lr
 
     def set_tensorboard_logging_options(self):
-        from megatron.arguments import _add_logging_args
+        from megatron.training.arguments import _add_logging_args
 
         parser = argparse.ArgumentParser()
         parser = _add_logging_args(parser)
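
For context on the hunk above, here is a minimal sketch of exercising the relocated helper, assuming megatron-core >= 0.8.0 is installed and that _add_logging_args registers only optional flags (both are assumptions here, not something this diff guarantees).

import argparse

from megatron.training.arguments import _add_logging_args  # path this commit adopts

# Parsing an empty argv surfaces every registered logging option with its
# default value (assumes none of the flags are marked required).
parser = _add_logging_args(argparse.ArgumentParser())
defaults, _ = parser.parse_known_args([])
print(sorted(vars(defaults)))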
4 changes: 2 additions & 2 deletions src/accelerate/utils/imports.py
@@ -200,8 +200,8 @@ def is_megatron_lm_available():
     if importlib.util.find_spec("megatron") is not None:
         try:
             megatron_version = parse(importlib.metadata.version("megatron-core"))
-            if compare_versions(megatron_version, "==", "0.5.0"):
-                return importlib.util.find_spec(".data", "megatron")
+            if compare_versions(megatron_version, ">=", "0.8.0"):
+                return importlib.util.find_spec(".training", "megatron")
         except Exception as e:
             warnings.warn(f"Parse Megatron version failed. Exception:{e}")
     return False
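
The same check, written as a standalone sketch against packaging directly rather than accelerate's compare_versions helper; the absolute find_spec("megatron.training") form is equivalent to the relative find_spec(".training", "megatron") call in the hunk above.

import importlib.metadata
import importlib.util

from packaging.version import parse


def megatron_lm_available() -> bool:
    # Standalone equivalent (sketch) of the updated guard; returns a plain
    # bool where the original returns the found module spec.
    if importlib.util.find_spec("megatron") is None:
        return False
    try:
        megatron_version = parse(importlib.metadata.version("megatron-core"))
    except importlib.metadata.PackageNotFoundError:
        return False
    # Megatron-LM >= 0.8.0 ships the megatron.training subpackage.
    return megatron_version >= parse("0.8.0") and importlib.util.find_spec("megatron.training") is not None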
50 changes: 25 additions & 25 deletions src/accelerate/utils/megatron_lm.py
@@ -30,56 +30,56 @@
 
 
 if is_megatron_lm_available():
-    from megatron import (
+    from megatron.core import mpu, tensor_parallel
+    from megatron.core.distributed import DistributedDataParallel as LocalDDP
+    from megatron.core.distributed import finalize_model_grads
+    from megatron.core.enums import ModelType
+    from megatron.core.num_microbatches_calculator import get_num_microbatches
+    from megatron.core.optimizer import get_megatron_optimizer
+    from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank
+    from megatron.core.pipeline_parallel import get_forward_backward_func
+    from megatron.core.utils import get_model_config
+    from megatron.inference.text_generation.communication import broadcast_int_list, broadcast_tensor
+    from megatron.inference.text_generation.generation import (
+        beam_search_and_return_on_first_stage,
+        generate_tokens_probs_and_return_on_first_stage,
+    )
+    from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets
+    from megatron.legacy.model import BertModel, Float16Module, GPTModel, T5Model
+    from megatron.legacy.model.classification import Classification
+    from megatron.training import (
         get_args,
-        get_num_microbatches,
         get_tensorboard_writer,
         get_tokenizer,
         print_rank_last,
     )
-    from megatron.arguments import (
+    from megatron.training.arguments import (
         _add_data_args,
         _add_validation_args,
         core_transformer_config_from_args,
         parse_args,
         validate_args,
     )
-    from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
-    from megatron.core import mpu, tensor_parallel
-    from megatron.core.distributed import DistributedDataParallel as LocalDDP
-    from megatron.core.distributed import finalize_model_grads
-    from megatron.core.enums import ModelType
-    from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank
-    from megatron.core.pipeline_parallel import get_forward_backward_func
-    from megatron.core.utils import get_model_config
-    from megatron.data.dataset_utils import build_train_valid_test_datasets
-    from megatron.global_vars import set_global_variables
-    from megatron.initialize import (
+    from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint
+    from megatron.training.global_vars import set_global_variables
+    from megatron.training.initialize import (
         _compile_dependencies,
         _init_autoresume,
         _initialize_distributed,
         _set_random_seed,
         set_jit_fusion_options,
         write_args_to_tensorboard,
     )
-    from megatron.model import BertModel, Float16Module, GPTModel, T5Model
-    from megatron.model.classification import Classification
-    from megatron.optimizer import get_megatron_optimizer
-    from megatron.text_generation.communication import broadcast_int_list, broadcast_tensor
-    from megatron.text_generation.generation import (
-        beam_search_and_return_on_first_stage,
-        generate_tokens_probs_and_return_on_first_stage,
-    )
-    from megatron.tokenizer.tokenizer import _vocab_size_with_padding
-    from megatron.training import (
+    from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
+    from megatron.training.training import (
         build_train_valid_test_data_iterators,
         get_optimizer_param_scheduler,
         num_floating_point_operations,
         setup_model_and_optimizer,
         train_step,
         training_log,
     )
-    from megatron.utils import (
+    from megatron.training.utils import (
         average_losses_across_data_parallel_group,
         calc_params_l2_norm,
         get_ltor_masks_and_position_ids,
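
Codebases that must straddle both layouts sometimes fall back at import time rather than hard-switching the way this commit does. Below is a minimal sketch of that alternative with one representative symbol; the fallback branch targets pre-0.8.0 trees and is not part of this commit.

try:
    # Megatron-LM >= 0.8.0 layout (the one this commit adopts).
    from megatron.training.arguments import parse_args
except ImportError:
    # Older flat layout; kept only as an illustrative fallback.
    from megatron.arguments import parse_args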
