Skip to content

Commit

Permalink
add grad clipping, use wandb run name for save dir
Browse files · Browse the repository at this point in the history
  • Loading branch information
Guitaricet committed Aug 6, 2023
1 parent ac3083c commit a89c534
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 275 deletions.
1 change: 0 additions & 1 deletion configs/pile_megatron_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# because we load it from yaml and then feed to NeoXArgs.from_dict().
# Use _ instead of - in the key names

"global_num_gpus": 8,
"pipe_parallel_size": 1,
"model_parallel_size": 1,

Expand Down
9 changes: 0 additions & 9 deletions peft_pretraining/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@ def check_args_torchrun_main(args):
logger.error("Are you sure? Not training LN is a bad idea.")
raise ValueError("Are you sure? Not training LN is a bad idea.")

if args.save_dir is None:
if args.model_config is not None:
# use checkpoints / model name, date and time as save directory
args.save_dir = f"checkpoints/{args.model_config.split('/')[-1].rstrip('.json')}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
elif args.model_name_or_path is not None:
args.save_dir = f"checkpoints/{args.model_name_or_path.split('/')[-1]}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
else:
raise ValueError("Either --args.save_dir or --model_config or --model_name_or_path must be specified")

if args.tags is not None:
args.tags = args.tags.split(",")

Expand Down
8 changes: 3 additions & 5 deletions peft_pretraining/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,14 +380,12 @@ def print_optimizer_state_size(optimizer):
first_moment_count += torch.numel(state['exp_avg'])
second_moment_count += torch.numel(state['exp_avg_sq'])

logger.info(f'Number of floats in the first moment: {first_moment_count / 1_000_000:.2f}M')
logger.info(f'Number of floats in the second moment: {second_moment_count / 1_000_000:.2f}M')
global_rank = 0
if dist.is_initialized():
global_rank = dist.get_rank()
if 0 < global_rank < 8:
print(f"(Rank {global_rank}) Number of floats in the first moment: {first_moment_count / 1_000_000:.2f}M")
print(f"(Rank {global_rank}) Number of floats in the second moment: {second_moment_count / 1_000_000:.2f}M")

print(f"(Rank {global_rank}) Number of floats in the first moment: {first_moment_count / 1_000_000:.2f}M")
print(f"(Rank {global_rank}) Number of floats in the second moment: {second_moment_count / 1_000_000:.2f}M")


def check_lr_and_alert(optimizer, max_lr):
Expand Down
71 changes: 0 additions & 71 deletions scripts/60M_relora.sh

This file was deleted.

52 changes: 0 additions & 52 deletions scripts/9m_model_in_depth.sh

This file was deleted.

52 changes: 0 additions & 52 deletions scripts/scaling_laws_full.sh

This file was deleted.

61 changes: 0 additions & 61 deletions scripts/scaling_laws_lora.sh

This file was deleted.

Loading

0 comments on commit a89c534

Please sign in to comment.