estimation.py (forked from pytorch/torchtitan), 220 lines (188 loc)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gc
import os

import torch

from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims


def estimate_memory(job_config: JobConfig):
    init_logger()
    logger.info("Estimating memory usage...")
    gc.disable()
    gc.collect(1)

    # Get the world size
    world_size = int(os.environ["WORLD_SIZE"])

    # if tp > 1 or pp > 1, we exit
    if (
        job_config.training.tensor_parallel_degree > 1
        or job_config.experimental.pipeline_parallel_degree > 1
    ):
        logger.info(
            "Tensor parallelism and pipeline parallelism are not supported yet."
        )
        return

    # fake tensor doesn't work with fused rmsnorm
    if (
        job_config.model.norm_type == "fused_rmsnorm"
        and not job_config.memory_estimation.disable_fake_mode
    ):
        logger.info(
            "Fused RMSNorm is not supported yet under fake estimation mode. "
            "Switching to rmsnorm."
        )
        job_config.model.norm_type = "rmsnorm"

    if job_config.model.norm_type == "compiled_rmsnorm":
        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
        job_config.model.norm_type = "rmsnorm"

    if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
        logger.info("Compile mode is not supported yet. Switching to eager mode.")
        job_config.training.compile = False
        job_config.experimental.enable_compiled_autograd = False

    parallel_dims = ParallelDims(
        dp_shard=job_config.training.data_parallel_shard_degree,
        dp_replicate=job_config.training.data_parallel_replicate_degree,
        cp=job_config.experimental.context_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )

    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)

    # init fake pg
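    # The fake process group backend skips real communication, so a job of
    # `world_size` ranks can be modeled from a single local process.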
    store = FakeStore()
    torch.distributed.init_process_group(
        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
    )

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    if not parallel_dims.dp_enabled:
        logger.info("Data parallelism is not enabled. Skipping memory estimation.")
        return

    model_name = job_config.model.name

    # build tokenizer
    tokenizer_type = model_name_to_tokenizer[model_name]
    tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

    train_context = utils.get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )

    # loss fn can be shared by pipeline-parallel or non-pp execution
    def loss_fn(pred, labels):
        return torch.nn.functional.cross_entropy(
            pred.flatten(0, 1).float(), labels.flatten(0, 1)
        )

    # build model (using meta init)
    model_cls = model_name_to_cls[model_name]
    model_config = models_config[model_name][job_config.model.flavor]
    # set the model configs from training inputs:
    # 1. norm type to decide which norm layer to use
    # 2. vocab size from tokenizer
    # 3. max_seq_len based on inputs
    model_config.norm_type = job_config.model.norm_type
    model_config.vocab_size = tokenizer.n_words
    model_config.max_seq_len = job_config.training.seq_len
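    # Under FakeTensorMode, tensors carry only metadata (shape, dtype, device) and
    # allocate no real GPU memory, so the build/forward/backward below can be
    # traced cheaply; setting memory_estimation.disable_fake_mode falls back to a
    # real run on the local GPU.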
    with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():
        logger.info(
            f"Building {model_name} {job_config.model.flavor} with {model_config}"
        )
        with torch.device("meta"):
            model = model_cls.from_model_args(model_config)

        # a no-op handler if float8 is not enabled
        float8_handler = Float8Handler(job_config, parallel_dims)
        # swap to Float8Linear based on float8 configs
        float8_handler.convert_to_float8_training(model)

        # apply PT-D DP/TP parallelisms and activation checkpointing
        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
        model.to_empty(device="cuda")
        if not active_fake_mode():
            model.init_weights()
        model.train()

        # build optimizer after applying parallelisms to the model
        optimizers = build_optimizers([model], job_config)
        lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

        logger.info(f"Vocab size: {model_config.vocab_size}")
        # Create a dummy batch instead of loading from a dataset
        batch = (
            torch.randint(
                0,
                model_config.vocab_size,
                (job_config.training.batch_size, model_config.max_seq_len),
                device="cuda",
            ),
            torch.randint(
                0,
                model_config.vocab_size,
                (job_config.training.batch_size, model_config.max_seq_len),
                device="cuda",
            ),
        )
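        # FSDPMemTracker (torch.distributed._tools) hooks the FSDP2-wrapped model
        # and its optimizer to attribute memory to parameters, gradients,
        # activations, and optimizer states on a per-module basis.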
        fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
        fsdp_memtracker.track_inputs(batch)

        with fsdp_memtracker:
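            # Two iterations: optimizer states are only materialized after the
            # first step, so iteration 1 reflects steady-state peak memory
            # (iteration 0 stats are reset below).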
            for iter_idx in range(2):
                input_ids, labels = batch
                # train step
                with train_context():
                    pred = model(input_ids)
                    loss = loss_fn(pred, labels)
                    del pred
                    loss.backward()

                # clip gradients
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), job_config.training.max_norm, foreach=True
                )
                # sync float8 amaxes and scales
                float8_handler.sync_float8_amax_and_scale_history(model)
                # optimizer step
                optimizers.step()
                lr_schedulers.step()

                # calculate float8 dynamic amax/scale for all parameters for FSDP2;
                # it issues a single all-reduce for all parameters at once for better performance
                float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
                optimizers.zero_grad()

                print(f"Peak Memory at iter: {iter_idx}")
                fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
                if iter_idx == 0:
                    fsdp_memtracker.reset_mod_stats()  # iter 0 does not have optimizer state
                    gc.collect(1)

        fsdp_memtracker.display_modulewise_snapshots(
            depth=3, units="MiB", tabulate=True
        )
        mem_stats = torch.cuda.memory_stats()
        peak_active = mem_stats["active_bytes.all.peak"]
        peak_reserved = mem_stats["reserved_bytes.all.peak"]
        num_retries = mem_stats["num_alloc_retries"]
        dev = torch.device(torch.cuda.current_device())
        tracker_peak = fsdp_memtracker.get_tracker_snapshot("peak")[dev]["Total"]
        gib = 1024**3
        print(
            f"peak active: {peak_active / gib} GiB | peak reserved:"
            f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
        )
        print(f"Tracker Max: {tracker_peak / gib} GiB")
        if job_config.memory_estimation.disable_fake_mode and peak_active > 0:
            print(f"Tracker Accuracy: {tracker_peak/peak_active}")
        gc.enable()


if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    try:
        estimate_memory(config)
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
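

# Example invocation (a sketch, assuming the standard torchtitan config layout;
# the config path depends on your checkout). Because the process group is fake,
# the script only needs WORLD_SIZE and LOCAL_RANK in the environment and runs as
# a single process:
#
#   WORLD_SIZE=8 LOCAL_RANK=0 python estimation.py --job.config_file train_configs/llama3_8b.toml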