[Flamingo] Fix the memory usage of 2x checkpoint size after loading (pytorch#1201)

Co-authored-by: Martin Yuan <[email protected]>
iseeyuan and Martin Yuan authored Sep 25, 2024
1 parent 6fd90bc commit f0a03a7
Showing 1 changed file with 18 additions and 3 deletions.
torchchat/cli/builder.py: 18 additions & 3 deletions
@@ -31,6 +31,8 @@
 
 from torchchat.model import Model, ModelArgs, ModelType
 
+from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
+
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -387,9 +389,23 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         with set_default_dtype(builder_args.precision), torch.device(
             builder_args.device
         ):
-            model = Model.from_params(builder_args.params_path)
+            # Re-creating the model here would double the memory footprint with redundant initialized weights.
+            # model = Model.from_params(builder_args.params_path)
+
+            # Buffers in the rotary embedding are not included in the checkpoint;
+            # they are computed at initialization. Since buffers on the meta
+            # device do not hold actual values, they must be reinitialized on
+            # the actual device. Only those buffers are initialized here,
+            # without initializing the entire model.
+            decoder_config = model.config.transformer_args['decoder']
+            head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
+            max_seq_len = decoder_config['max_seq_len']
+            rope_base = decoder_config['rope_base']
+            for submodule in model.modules():
+                if isinstance(submodule, RotaryPositionalEmbeddings):
+                    submodule.__init__(head_dim, max_seq_len, rope_base)
             state_dict = flamingo_meta_to_tune(checkpoint)
-            model.model.load_state_dict(state_dict)
+            model.model.load_state_dict(state_dict, assign=True, strict=False)
     else:
         checkpoint = {"model." + k: v for k, v in checkpoint.items()}
         model.load_state_dict(checkpoint, assign=True, strict=True)
@@ -472,7 +488,6 @@ def _load_model(builder_args: BuilderArgs) -> Model:
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
-
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
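The buffer trick in the second hunk generalizes: any module whose buffers are computed in __init__ rather than stored in the checkpoint can be materialized the same way. Below is a minimal sketch of the pattern, using a hypothetical ToyRoPE module (not torchtune's actual class) to stand in for RotaryPositionalEmbeddings:

import torch
import torch.nn as nn

class ToyRoPE(nn.Module):
    # Hypothetical stand-in for torchtune's RotaryPositionalEmbeddings:
    # its cache buffer is computed in __init__ and never checkpointed.
    def __init__(self, head_dim: int, max_seq_len: int, base: int = 10000):
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        cache = torch.arange(max_seq_len).float().unsqueeze(1) * theta
        self.register_buffer("cache", cache, persistent=False)

# Built under the meta device, the buffer has a shape but no storage.
with torch.device("meta"):
    rope = ToyRoPE(head_dim=64, max_seq_len=2048)
assert rope.cache.is_meta

# Re-running __init__ on a real device materializes just this module's
# buffers, without reinitializing the rest of the model.
with torch.device("cpu"):
    rope.__init__(head_dim=64, max_seq_len=2048)
assert not rope.cache.is_meta

This mirrors the loop in the diff, which walks model.modules() and re-runs __init__ only on the RotaryPositionalEmbeddings submodules.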

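The other half of the fix is load_state_dict(state_dict, assign=True). Without assign=True, load_state_dict copies checkpoint tensors into freshly allocated parameter storage, so both copies are briefly resident, which is the 2x checkpoint size in the commit title; with it, the module adopts the checkpoint tensors directly. A minimal sketch of the overall meta-device loading pattern, with an illustrative TinyModel and checkpoint path rather than torchchat's real ones:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Illustrative module; any nn.Module works the same way.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096)

# 1) Build the skeleton on the meta device: no weight memory is allocated.
with torch.device("meta"):
    model = TinyModel()

# 2) Load the checkpoint; this is the only full copy of the weights in RAM.
checkpoint = torch.load("tiny_model.pt", map_location="cpu", weights_only=True)

# 3) assign=True swaps the checkpoint tensors into the module in place of
#    the meta parameters, instead of copying into a second allocation.
model.load_state_dict(checkpoint, assign=True)
assert not model.proj.weight.is_meta

Note that the assign keyword to load_state_dict requires PyTorch 2.1 or later.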