forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 349
/
Copy pathcheckpointing.py
784 lines (655 loc) · 31.3 KB
/
checkpointing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Input/output checkpointing."""
import os
import random
import sys
import numpy as np
from deepspeed.accelerator import get_accelerator
import torch
from megatron import update_num_microbatches, get_tokenizer
from megatron.core import mpu, tensor_parallel
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0,
is_rank_0)
from deepspeed.checkpoint import (
ORIGINAL_VOCAB_SIZE,
PADDED_VOCAB_SIZE,
UNIVERSAL_CHECKPOINT_INFO,
UNIVERSAL_CHECKPOINT_VERSION_KEY,
UNIVERSAL_CHECKPOINT_VERSION_VALUE,
)
_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    """Record the version of the checkpoint being loaded.

    The first call stores ``value``. Later calls must pass a value equal to
    the stored one (an ``AssertionError`` is raised otherwise); the stored
    value is then rebound to the newly passed ``value``.
    """
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is None:
        _CHECKPOINT_VERSION = value
        return
    assert _CHECKPOINT_VERSION == value, \
        "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    """Return the version stored by ``set_checkpoint_version`` (None if unset)."""
    return _CHECKPOINT_VERSION
def check_checkpoint_args(checkpoint_args):
    """Assert that model-defining arguments agree with the checkpoint.

    Each compared attribute of the current ``get_args()`` namespace must equal
    the same attribute (or a legacy-named one) of ``checkpoint_args``;
    otherwise an ``AssertionError`` describing the mismatch is raised.
    """
    args = get_args()

    def _compare(arg_name, old_arg_name=None, default=None):
        # Older checkpoints may store the value under a legacy attribute name.
        ckpt_arg_name = arg_name if old_arg_name is None else old_arg_name
        # Without a default a missing checkpoint attribute is an AttributeError.
        if default is None:
            checkpoint_value = getattr(checkpoint_args, ckpt_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default)
        args_value = getattr(args, arg_name)
        message = ('{} value from checkpoint ({}) is not equal to the '
                   'input argument value ({}).').format(
                       arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, message

    # Architecture checks are skipped when args.mos or args.kd is set.
    if not (args.mos or args.kd):
        for name in ('num_layers', 'hidden_size',
                     'num_attention_heads', 'num_key_value_heads'):
            _compare(name)
        _compare('add_position_embedding', default=True)

    if args.vocab_file:
        _compare('max_position_embeddings')
        if not args.universal_checkpoint:
            _compare('make_vocab_size_divisible_by')
            _compare('padded_vocab_size')
        _compare('tokenizer_type')

    if args.data_parallel_random_init:
        _compare('data_parallel_random_init')

    # Checkpoints older than version 3.0 stored the TP size as model_parallel_size.
    version = get_checkpoint_version()
    if version < 3.0 and not args.universal_checkpoint:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    if version >= 3.0 and not args.universal_checkpoint:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename):
    """Create the parent directory of ``filename`` if it does not already exist.

    A bare filename with no directory component refers to the current working
    directory, which always exists, so nothing is created in that case
    (``os.makedirs('')`` would otherwise raise ``FileNotFoundError``).
    """
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        pipeline_parallel=None,
                        tensor_rank=None, pipeline_rank=None):
    """Return the path of the model checkpoint file for one (tensor, pipeline) rank.

    Layout: ``<checkpoints_path>/<iter_NNNNNNN | release>/mp_rank_TT[_PPP]/model_optim_rng.pt``.
    The ``_PPP`` pipeline suffix is present only when ``pipeline_parallel`` is
    true. Any of ``pipeline_parallel``, ``tensor_rank`` and ``pipeline_rank``
    left as None is taken from this process's ``mpu`` state.
    """
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)

    if pipeline_parallel is None:
        pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

    rank_dir = f'mp_rank_{tensor_rank:02d}'
    if pipeline_parallel:
        rank_dir += f'_{pipeline_rank:03d}'
    return os.path.join(checkpoints_path, directory, rank_dir, "model_optim_rng.pt")
def get_distributed_optimizer_checkpoint_name(model_checkpoint_name):
    """Return the distributed-optimizer state path stored beside a model checkpoint.

    The file is ``distrib_optim.pt`` in the same directory as
    ``model_checkpoint_name``.
    """
    checkpoint_dir = os.path.dirname(model_checkpoint_name)
    return os.path.join(checkpoint_dir, "distrib_optim.pt")
def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
    """Find the rank-0 checkpoint file without knowing the pipeline layout.

    The checkpoint directory name differs depending on whether pipeline
    parallelism was used, so both naming schemes are tried (non-pipelined
    first).

    Returns:
        The path of the first existing candidate, or ``None`` if neither exists.
        (Previously this returned the tuple ``(None, None)`` on a miss, which
        was inconsistent with the single-filename success return.)
    """
    for pipeline_parallel in (False, True):
        filename = get_checkpoint_name(checkpoints_path, iteration, release,
                                       pipeline_parallel=pipeline_parallel,
                                       tensor_rank=0, pipeline_rank=0)
        if os.path.isfile(filename):
            return filename
    return None
def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the tracker file that records the latest checkpoint.

    The tracker file holds either the iteration number of the most recent
    checkpoint or the literal ``release``; training restarts from it.
    """
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
def read_metadata(tracker_filename):
    """Parse the checkpoint tracker file.

    The file contains either an integer iteration or the literal ``release``.
    When torch.distributed is initialized, the iteration is max-reduced
    across all ranks so every rank agrees on which checkpoint to load.

    Args:
        tracker_filename: path to the tracker file.

    Returns:
        ``(iteration, release)``; ``iteration`` is 0 for a release checkpoint.

    Exits the process with a non-zero status if the file holds neither an
    integer nor ``release``.
    """
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        iteration = int(metastring)
    except ValueError:
        release = metastring == 'release'
        if not release:
            print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                tracker_filename))
            # Non-zero status so launchers/schedulers register the failure.
            sys.exit(1)
    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Get the max iteration retrieved across the ranks.
    if torch.distributed.is_initialized():
        iters_cuda = get_accelerator().LongTensor([iteration])
        torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
        max_iter = iters_cuda[0].item()

        # We should now have all the same iteration.
        # If not, print a warning and choose the maximum
        # iteration across all ranks.
        if iteration != max_iter:
            print('WARNING: on rank {} found iteration {} in the '
                  'metadata while max iteration across the ranks '
                  'is {}, replacing it with max iteration.'.format(
                      torch.distributed.get_rank(), iteration, max_iter), flush=True)
    else:
        # When loading a checkpoint outside of training (for example,
        # when editing it), we might not have torch distributed
        # initialized, in this case, just assume we have the latest
        max_iter = iteration
    return max_iter, release
def get_rng_state():
    """Capture this process's RNG states for checkpointing.

    Returns:
        A list of state dicts, one per captured rank. When torch.distributed
        is initialized, the data-parallel world size is greater than 1 and
        ``args.data_parallel_random_init`` is set, the states of every
        data-parallel rank are gathered (list index = data-parallel rank).
        Otherwise the list holds only this rank's state.
    """
    args = get_args()
    local_state = dict(
        random_rng_state=random.getstate(),
        np_rng_state=np.random.get_state(),
        torch_rng_state=torch.get_rng_state(),
        cuda_rng_state=get_accelerator().get_rng_state(),
        rng_tracker_states=tensor_parallel.get_cuda_rng_tracker().get_states(),
    )

    gather_across_dp = (torch.distributed.is_initialized()
                        and mpu.get_data_parallel_world_size() > 1
                        and args.data_parallel_random_init)
    if not gather_across_dp:
        return [local_state]

    # all_gather_object is a collective over the data-parallel group.
    gathered = [None] * mpu.get_data_parallel_world_size()
    torch.distributed.all_gather_object(
        gathered,
        local_state,
        group=mpu.get_data_parallel_group())
    return gathered
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint.

    Args:
        iteration: training iteration the checkpoint is tagged with.
        model: list of model chunks (presumably DeepSpeed engines when
            ``args.deepspeed``; only ``model[0]`` is used on that path).
        optimizer: optimizer whose state is saved on the non-DeepSpeed path
            (may be None); also used for the distributed-optimizer state.
        opt_param_scheduler: scheduler saved on the non-DeepSpeed path (may be None).

    Without DeepSpeed, processes whose data-parallel rank is 0 (or every
    process when torch.distributed is not initialized) write the state with
    ``torch.save``. With DeepSpeed, every rank calls the engine's
    ``save_checkpoint`` and the Megatron metadata is passed as
    ``client_state``. Rank 0 then writes ``iteration`` to the tracker file.
    """
    args = get_args()

    # DeepSpeed engines are used as-is; plain Megatron models are unwrapped.
    if not args.deepspeed:
        model = unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # Collect rng state across data parallel ranks.
    rng_state = get_rng_state()

    # Checkpoint name.
    checkpoint_name = get_checkpoint_name(args.save, iteration)

    # Save distributed optimizer's custom parameter state.
    if args.use_distributed_optimizer:
        optim_checkpoint_name = \
            get_distributed_optimizer_checkpoint_name(checkpoint_name)
        ensure_directory_exists(optim_checkpoint_name)
        optimizer.save_parameter_state(optim_checkpoint_name)

    # Collect args, model, RNG. Only data-parallel rank 0 writes on the
    # non-DeepSpeed path; with DeepSpeed every rank builds the client state.
    if not torch.distributed.is_initialized() \
            or mpu.get_data_parallel_rank() == 0 or args.deepspeed:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        state_dict['tokens'] = args.consumed_train_tokens
        state_dict[UNIVERSAL_CHECKPOINT_INFO] = _universal_checkpoint_info(model)

        # DeepSpeed saves the model/optimizer/scheduler itself.
        if not args.deepspeed:
            if len(model) == 1:
                state_dict['model'] = model[0].state_dict_for_save_checkpoint()
            else:
                for i in range(len(model)):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    state_dict['model%d' % i] = \
                        model[i].state_dict_for_save_checkpoint()

            # Optimizer stuff.
            if not args.no_save_optim:
                if optimizer is not None:
                    state_dict['optimizer'] = optimizer.state_dict()
                if opt_param_scheduler is not None:
                    state_dict['opt_param_scheduler'] = \
                        opt_param_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict["rng_state"] = rng_state

        # Save.
        if not args.deepspeed:
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)

    if args.deepspeed:
        # Megatron modules serialize through state_dict_for_save_checkpoint(),
        # while DeepSpeed calls module.state_dict() when saving, so the latter
        # is temporarily redirected in the non-pipeline case.
        patch_state_dict = args.no_pipeline_parallel
        if patch_state_dict:
            original_state_dict = model[0].module.state_dict

            def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_vars=False):
                return model[0].module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

            model[0].module.state_dict = state_dict_for_save_checkpoint_deepspeed
        try:
            # Saving is a collective communication
            checkpoint_name = get_checkpoint_name(args.save, iteration)
            # Trim off the filename, the mp_rank_* directory and the
            # iter_*/release directory, leaving the checkpoint root.
            for _ in range(3):
                checkpoint_name = os.path.dirname(checkpoint_name)
            model[0].save_checkpoint(checkpoint_name, client_state=state_dict)
        finally:
            # Restore the module's own state_dict even if saving failed.
            if patch_state_dict:
                model[0].module.state_dict = original_state_dict

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration
    if is_rank_0():
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
def _transpose_first_dim(t, num_splits, num_splits_first, model):
input_shape = t.size()
# We use a self_attention module but the values extracted aren't
# specific to self attention so should work for cross attention as well
while hasattr(model, 'module'):
model = model.module
attention_module = model.language_model.encoder.layers[0].self_attention
#attention_module = model.language_model.encoder.layers[0].attention
hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
if num_splits_first:
"""[num_splits * np * hn, h]
-->(view) [num_splits, np, hn, h]
-->(tranpose) [np, num_splits, hn, h]
-->(view) [np * num_splits * hn, h] """
intermediate_shape = \
(num_splits, num_attention_heads_per_partition,
hidden_size_per_attention_head) + input_shape[1:]
t = t.view(*intermediate_shape)
t = t.transpose(0, 1).contiguous()
else:
"""[np * hn * num_splits, h]
-->(view) [np, hn, num_splits, h]
-->(tranpose) [np, num_splits, hn, h]
-->(view) [np * num_splits * hn, h] """
intermediate_shape = \
(num_attention_heads_per_partition,
hidden_size_per_attention_head, num_splits) +\
input_shape[1:]
t = t.view(*intermediate_shape)
t = t.transpose(1, 2).contiguous()
t = t.view(*input_shape)
return t
def fix_query_key_value_ordering(model, checkpoint_version):
    """Rewrite fused QKV / KV parameters saved by checkpoints older than version 2.0.

    Version 0 checkpoints have the split index outermost and version 1.0
    checkpoints have it innermost; parameters are rewritten in place via
    ``_transpose_first_dim``. Any other version below 2.0 combined with a
    fused parameter aborts the process. Versions >= 2.0 are left untouched.
    """
    if not checkpoint_version < 2.0:
        return

    if isinstance(model, list):
        assert len(model) == 1
        model = model[0]

    # (parameter-name suffixes, number of fused projections)
    fused_layouts = (
        (('.query_key_value.weight', '.query_key_value.bias'), 3),
        (('.key_value.weight', '.key_value.bias'), 2),
    )
    for name, param in model.named_parameters():
        for suffixes, num_splits in fused_layouts:
            if not name.endswith(suffixes):
                continue
            if checkpoint_version == 0:
                splits_outermost = True
            elif checkpoint_version == 1.0:
                splits_outermost = False
            else:
                print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                sys.exit()
            reordered = _transpose_first_dim(param.data, num_splits, splits_outermost, model)
            param.data.copy_(reordered)

    print_rank_0(" succesfully fixed query-key-values ordering for"
                 " checkpoint version {}".format(checkpoint_version))
def _load_base_checkpoint(load_dir, rank0=False):
    """Load the base state_dict from the given directory.

    Args:
        load_dir: checkpoint root containing the tracker file.
        rank0: if True, load the tensor/pipeline rank-0 checkpoint (trying both
            naming schemes via ``find_checkpoint_rank_0``) and suppress the
            informational prints; the arguments of this process are ignored.

    Returns:
        ``(state_dict, release)``; ``(None, False)`` when there is no tracker file.
        Exits the process if the checkpoint file cannot be loaded.
    """
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return nothing
    if not os.path.isfile(tracker_filename):
        if not rank0:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0(' will not load any checkpoints and will start from '
                         'random')
        return None, False

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    if rank0:
        checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
    else:
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
        if release:
            print_rank_0(f' loading release checkpoint from {load_dir}')
        else:
            print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

    # Load the checkpoint.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        # For backward compatibility with checkpoints that reference the old
        # ``fp16.loss_scaler`` / ``megatron.fp16.loss_scaler`` module paths:
        # importing loss_scaler registers the deprecated module in sys.modules
        # so it can be aliased under those names for the retry.
        from megatron.fp16_deprecated import loss_scaler  # noqa: F401
        if not rank0:
            print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        finally:
            # Drop the aliases even if the retry fails.
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
    except Exception as e:
        # Catch Exception (not BaseException) so Ctrl-C / SystemExit propagate
        # instead of being reported as a load failure.
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    return state_dict, release
def load_args_from_checkpoint(args, load_arg='load'):
    """Populate model-defining arguments in ``args`` from a saved checkpoint.

    Arguments the caller left as None are filled from the checkpoint. A fixed
    set of flags (position-embedding, bias, activation and layernorm options,
    and for version >= 3.0 checkpoints the tensor/pipeline/virtual-pipeline
    sizes) is overwritten with the checkpoint's value whenever the checkpoint
    provides one. ``args.iteration`` is set from the checkpoint.

    Args:
        args: argparse-style namespace, updated in place.
        load_arg: name of the ``args`` attribute holding the checkpoint directory.

    Returns:
        ``args`` alone if no load directory is set, no checkpoint is found, or
        the checkpoint has no saved arguments; otherwise the tuple
        ``(args, checkpoint_args)``.
        NOTE(review): the two return shapes differ — callers that unpack the
        result must handle the early-return case.
    """
    load_dir = getattr(args, load_arg)
    if load_dir is None:
        print_rank_0('No load directory specified, using provided arguments.')
        return args

    state_dict, _release = _load_base_checkpoint(load_dir, rank0=True)
    if not state_dict:
        print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
        return args
    if 'args' not in state_dict:
        print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
        return args

    checkpoint_args = state_dict['args']
    checkpoint_version = state_dict.get('checkpoint_version', 0)
    args.iteration = state_dict['iteration']

    # One-off conversion for foundation models that stored the negated flag.
    if hasattr(checkpoint_args, 'disable_bias_linear'):
        checkpoint_args.add_bias_linear = not checkpoint_args.disable_bias_linear

    def _set_arg(arg_name, old_arg_name=None, force=False):
        # Keep a value the caller already supplied unless ``force`` is set.
        if not force and getattr(args, arg_name, None) is not None:
            return
        source_name = arg_name if old_arg_name is None else old_arg_name
        checkpoint_value = getattr(checkpoint_args, source_name, None)
        if checkpoint_value is None:
            print_rank_0(f"Checkpoint did not provide arguments {arg_name}")
            return
        print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
        setattr(args, arg_name, checkpoint_value)

    # Filled only when unset by the caller.
    for name in ('num_layers', 'hidden_size', 'ffn_hidden_size', 'seq_length',
                 'num_attention_heads', 'num_key_value_heads', 'kv_channels',
                 'max_position_embeddings'):
        _set_arg(name)
    # Always taken from the checkpoint.
    for name in ('add_position_embedding', 'use_rotary_position_embeddings',
                 'rotary_percent', 'add_bias_linear', 'swiglu',
                 'untie_embeddings_and_output_weights', 'apply_layernorm_1p'):
        _set_arg(name, force=True)
    _set_arg('tokenizer_type')
    _set_arg('padded_vocab_size')

    # Checkpoints older than version 3.0 stored the TP size as model_parallel_size.
    if checkpoint_version < 3.0:
        _set_arg('tensor_model_parallel_size', 'model_parallel_size')
    else:
        _set_arg('tensor_model_parallel_size', force=True)
        _set_arg('pipeline_model_parallel_size', force=True)
        _set_arg('virtual_pipeline_model_parallel_size', force=True)
        _set_arg('num_layers_per_virtual_pipeline_stage')
    return args, checkpoint_args
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, load_only_weights=False):
    """Load a model checkpoint and return the iteration.

    Args:
        model: list of model chunks (presumably DeepSpeed engines when
            ``args.deepspeed``; only ``model[0].load_checkpoint`` is used there).
        optimizer: optimizer restored on the non-DeepSpeed path (may be None).
        opt_param_scheduler: scheduler restored on the non-DeepSpeed path (may be None).
        load_arg: name of the ``args`` attribute holding the load directory.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model.
        load_only_weights: if True the returned iteration is 0 and the
            checkpoint arguments / consumed-sample counters are not restored.

    Returns:
        The iteration to resume from. It is 0 when no checkpoint was found, or
        when ``args.finetune``, ``args.reset_iteration`` or
        ``load_only_weights`` is set, or for a release checkpoint.

    Exits the process when the iteration, optimizer or RNG state required by
    the current flags is missing from the checkpoint.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    if args.deepspeed:
        if args.finetune:
            loaded_dir, state_dict = model[0].load_checkpoint(load_dir,
                load_module_strict=strict, load_optimizer_states=False,
                load_lr_scheduler_states=False, load_module_only=True,
                tag=args.load_tag)
        else:
            loaded_dir, state_dict = model[0].load_checkpoint(load_dir,
                load_module_strict=strict, tag=args.load_tag)
        if loaded_dir is None:
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                load_dir))
            print_rank_0(' will not load any checkpoints and will start from '
                         'random')
            return 0
        release = False
    else:
        model = unwrap_model(model)

        state_dict, release = _load_base_checkpoint(load_dir, rank0=False)

        # Checkpoint not loaded.
        if state_dict is None:
            # Conditionally exit at this point.
            if args.exit_on_missing_checkpoint:
                print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
                # Some utilities load without torch.distributed initialized.
                if torch.distributed.is_initialized():
                    torch.distributed.barrier()
                sys.exit()

            # Iteration defaults to 0.
            return 0

    checkpoint_name = get_checkpoint_name(load_dir, state_dict['iteration'], release)

    # Set checkpoint version.
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release or args.reset_iteration or load_only_weights:
        iteration = 0
        # Make DeepSpeed engine aware of this reset of iteration
        model[0].global_steps = 0
    else:
        try:
            iteration = state_dict['iteration']
            if 'tokens' in state_dict:
                args.consumed_train_tokens = state_dict['tokens']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
    reset_train_valid_samples = args.reset_iteration
    if not load_only_weights and not reset_train_valid_samples:
        assert args.consumed_train_samples == 0
        assert args.consumed_valid_samples == 0
        if 'args' in state_dict and not args.finetune:
            checkpoint_args = state_dict['args']
            check_checkpoint_args(checkpoint_args)
            args.consumed_train_samples = getattr(checkpoint_args,
                                                  'consumed_train_samples', 0)
            update_num_microbatches(consumed_samples=args.consumed_train_samples)
            args.consumed_valid_samples = getattr(checkpoint_args,
                                                  'consumed_valid_samples', 0)
        else:
            print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if not args.deepspeed:
        if len(model) == 1:
            model[0].load_state_dict(state_dict['model'], strict=strict)
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed.
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not args.deepspeed:
        if not release and not args.finetune and not args.no_load_optim:
            try:
                # Load state dict.
                if optimizer is not None:
                    optimizer.load_state_dict(state_dict['optimizer'])

                # Load distributed optimizer's custom parameter state.
                if args.use_distributed_optimizer:
                    # Use separate names: rebinding ``iteration`` here would
                    # discard the reset to 0 applied above for
                    # args.reset_iteration / load_only_weights.
                    tracker_filename = get_checkpoint_tracker_filename(load_dir)
                    tracker_iteration, tracker_release = read_metadata(tracker_filename)
                    model_checkpoint_name = \
                        get_checkpoint_name(load_dir, tracker_iteration, tracker_release)
                    optim_checkpoint_name = \
                        get_distributed_optimizer_checkpoint_name(
                            model_checkpoint_name)
                    optimizer.load_parameter_state(optim_checkpoint_name)

                # Load scheduler.
                if opt_param_scheduler is not None:
                    if 'lr_scheduler' in state_dict:  # backward compatibility
                        opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
                    else:
                        opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
            except KeyError:
                print_rank_0('Unable to load optimizer from checkpoint {}. '
                             'Specify --no-load-optim or --finetune to prevent '
                             'attempting to load the optimizer state, '
                             'exiting ...'.format(checkpoint_name))
                sys.exit()
        else:
            if (args.fp16 or args.bf16) and optimizer is not None:
                optimizer.reload_model_params()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if 'rng_state' in state_dict:
                # access rng_state for data parallel rank
                if args.data_parallel_random_init:
                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
                else:
                    rng_state = state_dict['rng_state'][0]
                random.setstate(rng_state['random_rng_state'])
                np.random.set_state(rng_state['np_rng_state'])
                torch.set_rng_state(rng_state['torch_rng_state'])
                get_accelerator().set_rng_state(rng_state['cuda_rng_state'])
                # Check for empty states array
                if not rng_state['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    rng_state['rng_tracker_states'])
            else:  # backward compatibility
                random.setstate(state_dict['random_rng_state'])
                np.random.set_state(state_dict['np_rng_state'])
                torch.set_rng_state(state_dict['torch_rng_state'])
                get_accelerator().set_rng_state(state_dict['cuda_rng_state'])
                # Check for empty states array
                if not state_dict['rng_tracker_states']:
                    raise KeyError
                tensor_parallel.get_cuda_rng_tracker().set_states(
                    state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    if args.universal_checkpoint:
        # TLDR: unique rng is needed for dropout to be really random on TP ranks
        #
        # Each tp-rank stores its model-parallel-rng states info.
        # This is required to e.g. have different dropout patterns on different tp ranks that operate on
        # slices of attention_probs tensor.
        #
        # When loading from universal checkpoint, we use mp_rank_<mp>_model_states.pt checkpoint files
        # to restore the model-parallel-rng (<mp> is {tp-rank, pp-rank} combination).
        # However, if the loaded checkpoint mp configuration does not match the current mp configuration,
        # we can not use it to restore model-parallel-rng info.
        #
        # In the case of mp configuration change, we reconfigure the model-parallel-rng states s.t. each
        # tp-rank will have a unique state. In order to ensure that subsequent loads from universal will
        # not cause the model-parallel-rng states to be repeated, we add the iteration number to the base seed.
        ckp_args = state_dict['args']
        if ((args.tensor_model_parallel_size != ckp_args.tensor_model_parallel_size)
                or (args.pipeline_model_parallel_size != ckp_args.pipeline_model_parallel_size)):
            print_rank_0(' loading universal checkpoint with modified mp configuration '
                         '-> reconfigure tp seed')
            tensor_parallel.model_parallel_reconfigure_tp_seed(args.seed + iteration)

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Report the directory actually used (``load_arg`` may not be 'load').
    print_rank_0(f' successfully loaded checkpoint from {load_dir} '
                 f'at iteration {iteration}')

    return iteration
def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """Selectively load retrieval (biencoder) models from a saved checkpoint.

    Used for indexing/retrieving.

    Args:
        model: single-element list holding the biencoder model.
        only_query_model: drop the ``context_model`` entry before loading.
        only_context_model: drop the ``query_model`` entry before loading.
        custom_load_path: checkpoint root to use instead of ``args.load``.

    Returns:
        The unwrapped model list.
    """
    args = get_args()

    model = unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # ``release`` is get_checkpoint_name's third positional parameter, so it
    # must be passed only once. The model file path does not depend on the
    # distributed optimizer.
    checkpoint_name = get_checkpoint_name(load_path, iteration, release=False)

    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def _universal_checkpoint_info(model):
    """Build the universal-checkpoint metadata dict stored with each checkpoint.

    It records the universal-checkpoint format version, the tokenizer's
    original vocabulary size and the padded vocabulary size. These are
    followed by whatever ``model[0].universal_checkpoint_info()`` contributes,
    which takes precedence on key collisions.
    """
    args = get_args()
    tokenizer = get_tokenizer()
    info = {
        UNIVERSAL_CHECKPOINT_VERSION_KEY: UNIVERSAL_CHECKPOINT_VERSION_VALUE,
        ORIGINAL_VOCAB_SIZE: tokenizer.vocab_size,
        PADDED_VOCAB_SIZE: args.padded_vocab_size,
    }
    info.update(model[0].universal_checkpoint_info())
    return info