[Wandb] Refine wandb logging function (#416)
* Refine wandb logging function

* Address comments

* Enable user to specify wandb local save dir

* Update and fix comments

* Update
billishyahao authored Jul 16, 2024
1 parent 0971e68 commit 73029ed
Showing 6 changed files with 204 additions and 95 deletions.
13 changes: 12 additions & 1 deletion examples_deepspeed/pretrain_llama2_distributed.sh
@@ -41,6 +41,17 @@ GRAD_CLIP=1
# activation_checkpoint="true"
activation_checkpoint="false"

LOG_TO_WANDB=0
WANDB_ARGS=
if [ $LOG_TO_WANDB -eq 1 ]
then
WANDB_ARGS="\
--wandb-project pretrain-llama2 \
--wandb-exp-name exp0 \
--wandb-save-dir ${BASE_PATH}/wandb \
"
fi

# Below configuration required for llama model as per llama paper
# --no-query-key-layer-scaling \
# --attention-dropout 0 \
@@ -53,7 +64,6 @@ activation_checkpoint="false"
######################################



cat <<EOT > $DS_CONFIG
{
"train_batch_size" : $GLOBAL_BATCH_SIZE,
@@ -132,4 +142,5 @@ torchrun $DISTRIBUTED_ARGS \
--normalization rmsnorm \
--disable-bias-linear \
--num-key-value-heads $NUM_KV_HEADS \
$WANDB_ARGS \
$ds_args
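For reference, a small Python sketch (not part of the commit; the project and experiment names are copied from the script, and '/path/to/base' stands in for ${BASE_PATH}) showing which extra tokens the LOG_TO_WANDB toggle adds to the torchrun command line:

import shlex

# With LOG_TO_WANDB=0 the script leaves WANDB_ARGS empty, so no wandb flags
# reach Megatron and --wandb-project keeps its empty default
# (see megatron/arguments.py below).
log_to_wandb = True
wandb_args = ''
if log_to_wandb:
    wandb_args = ('--wandb-project pretrain-llama2 '
                  '--wandb-exp-name exp0 '
                  '--wandb-save-dir /path/to/base/wandb')

# Tokens appended to the torchrun invocation via $WANDB_ARGS:
print(shlex.split(wandb_args))
# ['--wandb-project', 'pretrain-llama2', '--wandb-exp-name', 'exp0',
#  '--wandb-save-dir', '/path/to/base/wandb']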
1 change: 1 addition & 0 deletions megatron/__init__.py
@@ -9,6 +9,7 @@
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_wandb_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron
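With the accessor exported here, downstream code can fetch the writer through the megatron package. A minimal usage sketch, assuming set_global_variables(args) has already run; the None check matters because the writer stays unset when wandb logging is not configured, and when it is set it is the wandb module itself (see _set_wandb_writer in global_vars.py below):

from megatron import get_wandb_writer

def log_scalar(name, value, iteration):
    """Log a scalar to wandb only if the wandb writer was initialized."""
    writer = get_wandb_writer()
    if writer is not None:               # None when wandb logging is disabled
        writer.log({name: value}, step=iteration)

log_scalar('lm-loss', 2.31, iteration=100)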
6 changes: 6 additions & 0 deletions megatron/arguments.py
@@ -748,6 +748,12 @@ def _add_logging_args(parser):
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
group.add_argument('--wandb-project', type=str, default='',
help='The wandb project name. Ignore wandb by default.')
group.add_argument('--wandb-exp-name', type=str, default='',
help='The wandb experiment name.')
group.add_argument('--wandb-save-dir', type=str, default='',
help='Path to save the wandb results locally.')

return parser

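For quick reference, the three new options can be exercised in isolation with a plain argparse parser; this is only a sketch mirroring the defaults above, not the real parser (the actual flags live in Megatron's _add_logging_args group):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--wandb-project', type=str, default='',
                    help='The wandb project name. Ignore wandb by default.')
parser.add_argument('--wandb-exp-name', type=str, default='',
                    help='The wandb experiment name.')
parser.add_argument('--wandb-save-dir', type=str, default='',
                    help='Path to save the wandb results locally.')

# Defaults are empty strings; the launcher script only passes these flags
# when LOG_TO_WANDB=1.
args = parser.parse_args(['--wandb-project', 'pretrain-llama2',
                          '--wandb-exp-name', 'exp0'])
print(args.wandb_project, args.wandb_exp_name, repr(args.wandb_save_dir))
# pretrain-llama2 exp0 ''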
47 changes: 47 additions & 0 deletions megatron/global_vars.py
@@ -16,6 +16,7 @@
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_WANDB_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None
@@ -56,6 +57,12 @@ def get_tensorboard_writer():
return _GLOBAL_TENSORBOARD_WRITER


def get_wandb_writer():
"""Return wandb writer. It can be None so no need
to check if it is initialized."""
return _GLOBAL_WANDB_WRITER


def get_adlr_autoresume():
"""ADLR autoresume object. It can be None so no need
to check if it is initialized."""
@@ -91,6 +98,7 @@ def set_global_variables(args):
_build_num_microbatches_calculator(args)
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_wandb_writer(args)
_set_adlr_autoresume(args)
_set_timers(args)

@@ -152,6 +160,45 @@ def _set_tensorboard_writer(args):
'no TensorBoard logs will be written.', flush=True)


def _set_wandb_writer(args):
"""Set wandb writer."""
global _GLOBAL_WANDB_WRITER
_ensure_var_is_not_initialized(_GLOBAL_WANDB_WRITER,
'wandb writer')
getattr(args, 'wandb_project', '')
getattr(args, 'wandb_exp_name', '')

if args.rank == (args.world_size - 1):
if args.wandb_project == '' or \
args.wandb_exp_name == '':
print('WARNING: WANDB writing requested but no legit wandb '
'project or experiment name provided, '
'therefore WANDB logs will be written '
'according to random generated project or experiment name.', flush=True)

try:
import wandb
except (ImportError, ModuleNotFoundError):
print('WARNING: WANDB writing requested but is not '
'available (try to pip install wandb to solve it), '
'no WANDB logs will be written.', flush=True)
return

if args.wandb_save_dir:
save_dir = args.wandb_save_dir
else:
# Defaults to the save dir.
save_dir = os.path.join(args.save, 'wandb')
wandb_kwargs = {
'dir': save_dir,
'name': args.wandb_exp_name,
'project': args.wandb_project,
'config': vars(args)}
os.makedirs(wandb_kwargs['dir'], exist_ok=True)
wandb.init(**wandb_kwargs)
_GLOBAL_WANDB_WRITER = wandb


def _set_adlr_autoresume(args):
"""Initialize ADLR autoresume."""
global _GLOBAL_ADLR_AUTORESUME
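Two behaviors in _set_wandb_writer are worth restating: only the last rank (args.rank == args.world_size - 1) initializes wandb, and the run directory falls back to <args.save>/wandb when --wandb-save-dir is empty. A standalone sketch of the directory fallback (paths are illustrative):

import os

def resolve_wandb_dir(wandb_save_dir, save):
    """Prefer the explicit --wandb-save-dir; otherwise use <save>/wandb,
    mirroring the fallback in _set_wandb_writer above."""
    return wandb_save_dir if wandb_save_dir else os.path.join(save, 'wandb')

print(resolve_wandb_dir('', '/checkpoints/llama2'))
# /checkpoints/llama2/wandb
print(resolve_wandb_dir('/data/wandb_runs', '/checkpoints/llama2'))
# /data/wandb_runs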
2 changes: 1 addition & 1 deletion megatron/timers.py
@@ -303,7 +303,7 @@ def write(self, names, writer, iteration, normalizer=1.0,
assert normalizer > 0.0
name_to_min_max_time = self._get_global_min_max_time(
names, reset, barrier, normalizer)
if writer is not None:
if writer.is_enabled():
for name in name_to_min_max_time:
_, max_time = name_to_min_max_time[name]
writer.add_scalar(name + '-time', max_time, iteration)
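The timers change implies that the writer handed to Timers.write() is now a small wrapper exposing is_enabled() in addition to add_scalar(), rather than a bare SummaryWriter checked against None. A hypothetical shape of such a wrapper (names here are illustrative, not taken from the commit):

class ScalarWriterWrapper:
    """Hypothetical sketch: whatever is passed to Timers.write() must
    provide is_enabled() and add_scalar()."""

    def __init__(self, backend=None):
        # backend could be a TensorBoard SummaryWriter, or None if disabled.
        self._backend = backend

    def is_enabled(self):
        return self._backend is not None

    def add_scalar(self, name, value, iteration):
        if self.is_enabled():
            self._backend.add_scalar(name, value, iteration)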