-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
356 lines (291 loc) · 10.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
"""Utility functions for the project.
This module is largely based on the [HALOs repo](https://github.com/ContextualAI/HALOs/blob/main/utils.py).
"""
import os
import getpass
from datetime import datetime
import torch
import random
import numpy as np
import torch.distributed as dist
import inspect
import importlib.util
import socket
import os
from typing import Dict, Union, Type, List
from collections.abc import Mapping
def deepcopy_fsdp_models(src, tgt) -> None:
    """Copy parameter values from ``src`` into ``tgt`` in place.

    Parameters are matched by the names yielded by ``named_parameters()``.
    A parameter of ``tgt`` with no same-named parameter in ``src`` is left
    untouched and reported through ``rank0_print``; parameters that exist
    only in ``src`` are ignored.

    Args:
        src: Module whose parameter values are read.
        tgt: Module whose parameter values are overwritten.
    """
    source_by_name = dict(src.named_parameters())
    with torch.no_grad():
        for name, target_param in tgt.named_parameters():
            source_param = source_by_name.get(name)
            if source_param is None:
                rank0_print(f"{name} not found")
                continue
            target_param.data.copy_(source_param.data.detach())
def get_open_port() -> int:
    """Return a TCP port number that was free at the time of the call.

    The OS chooses the port (binding to port 0). The socket is closed before
    returning, so another process could take the port before the caller
    uses it; callers should tolerate that race.

    Returns:
        The port number chosen by the OS.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Port 0 asks the OS for any free port; '' binds all interfaces.
        s.bind(('', 0))
        return s.getsockname()[1]
def get_remote_file(remote_path, local_path=None) -> str:
    """Return a local path for ``remote_path``, copying it with scp if needed.

    Args:
        remote_path: A ``host:path`` string. Only the first ``:`` separates
            the host from the path, so the path itself may contain ``:``.
        local_path: Where to store the copy. Defaults to the remote path.

    Returns:
        ``path`` unchanged if ``host`` names this machine (full hostname or
        the part before its first '.'). Otherwise ``local_path``, which is
        reused as-is if it already exists and is fetched with ``scp``
        otherwise.

    Raises:
        ValueError: if ``remote_path`` contains no ``:``.
        subprocess.CalledProcessError: if the ``scp`` copy fails.
    """
    import subprocess

    hostname, sep, path = remote_path.partition(':')
    if not sep:
        raise ValueError(f"Expected 'host:path', got {remote_path!r}")

    local_hostname = socket.gethostname()
    # split('.')[0] is the short hostname and is safe when there is no '.'
    # (slicing up to find('.') == -1 would drop the last character).
    if hostname in (local_hostname, local_hostname.split('.')[0]):
        return path

    if local_path is None:
        local_path = path
    if os.path.exists(local_path):
        return local_path

    local_dir = os.path.dirname(local_path)
    if local_dir:  # a bare filename has no directory component to create
        os.makedirs(local_dir, exist_ok=True)
    print(f'Copying {hostname}:{path} to {local_path}')
    # Argument list, no shell: paths with spaces/metacharacters are passed
    # verbatim, and a failed copy raises instead of returning a missing path.
    subprocess.run(['scp', remote_path, local_path], check=True)
    return local_path
def rank0_print(*args, **kwargs) -> None:
    """Forward all arguments to ``print``, but only from rank 0.

    If ``torch.distributed`` has not been initialized, the caller is treated
    as rank 0, so single-process runs always print.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(*args, **kwargs)
def on_rank0() -> bool:
    """Return True when this process is rank 0.

    A process in which ``torch.distributed`` is not initialized also counts
    as rank 0.
    """
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0
def slice_and_move_batch_for_device(
    batch: Dict,
    rank: int,
    world_size: int,
    device: str
) -> Dict:
    """Take this rank's contiguous shard of ``batch`` and move tensors to ``device``.

    The shard length is ``len(first value) // world_size``. Every value is
    sliced with the same ``[rank * n : (rank + 1) * n]`` window, so items past
    ``world_size * n`` are not given to any rank. Values that are not
    ``torch.Tensor`` are sliced but otherwise returned unchanged.

    Raises:
        IndexError: if ``batch`` is empty.
    """
    shard_len = len([*batch.values()][0]) // world_size
    window = slice(rank * shard_len, (rank + 1) * shard_len)
    result = {}
    for key, value in batch.items():
        shard = value[window]
        result[key] = shard.to(device) if isinstance(shard, torch.Tensor) else shard
    return result
def pad_to_length(
    tensor: torch.Tensor,
    length: int,
    pad_value: Union[int, float],
    dim: int = -1
) -> torch.Tensor:
    """Right-pad ``tensor`` along ``dim`` with ``pad_value`` up to ``length``.

    Args:
        tensor: Tensor to pad.
        length: Target size of dimension ``dim``.
        pad_value: Fill value for the padding, cast to ``tensor.dtype``.
        dim: Dimension to pad (negative indices allowed).

    Returns:
        ``tensor`` itself (no copy, never truncated) if it is already at
        least ``length`` long along ``dim``. Otherwise a new tensor with the
        same dtype and device as ``tensor``.
    """
    current = tensor.size(dim)
    if current >= length:
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[dim] = length - current
    # new_full fills directly in tensor's dtype/device. This avoids an extra
    # ones() allocation, and avoids promoting integer tensors to float when
    # pad_value is a Python float (as `pad_value * torch.ones(...)` did).
    padding = tensor.new_full(pad_shape, pad_value)
    return torch.cat([tensor, padding], dim=dim)
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    token_level: bool = False
) -> torch.Tensor:
    """Compute the log probabilities of the given labels under the given logits.

    Labels are shifted by one relative to the logits: the logits at position
    t score the label at position t + 1. Outputs therefore cover
    ``sequence_length - 1`` positions.

    Args:
        logits: Logits of the model (unnormalized).
            Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens
            with a value of -100 are ignored.
            Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per
            (non-masked) token. Otherwise, return the sum of the
            log probabilities of the (non-masked) tokens.
        token_level: If True, return the token-level log probabilities (do not
            aggregate across tokens). Takes precedence over average_log_prob.

    Returns:
        Shape (batch_size,) by default. Shape (batch_size, sequence_length - 1)
        if token_level, with ignored positions set to 0. With average_log_prob,
        a row whose shifted labels are all -100 yields 0 rather than NaN.
    """
    assert logits.shape[:-1] == labels.shape

    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = (labels != -100)

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == -100] = 0

    distribution_logps = logits.log_softmax(-1)
    per_token_logps = torch.gather(
        distribution_logps, dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    masked_logps = per_token_logps * loss_mask

    if token_level:
        return masked_logps
    if average_log_prob:
        # clamp avoids 0/0 = NaN for rows with no unmasked tokens
        return masked_logps.sum(-1) / loss_mask.sum(-1).clamp(min=1)
    return masked_logps.sum(-1)
def clip_by_value(
    x: torch.Tensor,
    tensor_min: Union[float, torch.Tensor],
    tensor_max: Union[float, torch.Tensor],
) -> torch.Tensor:
    """Clamp ``x`` elementwise to ``[tensor_min, tensor_max]``.

    Extension of ``torch.clamp`` that also accepts tensor bounds, which are
    broadcast against ``x``
    (https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713).
    The result is ``max(min(x, tensor_max), tensor_min)``, so wherever
    ``tensor_min > tensor_max`` the result is ``tensor_min``.

    Args:
        x: Input tensor.
        tensor_min: Lower bound: a Python number or a tensor broadcastable to ``x``.
        tensor_max: Upper bound: a Python number or a tensor broadcastable to ``x``.

    Returns:
        The clipped tensor.
    """
    # torch.min / torch.max only accept a Tensor as the second operand, so
    # Python scalars are wrapped in tensors matching x's dtype and device.
    if not isinstance(tensor_min, torch.Tensor):
        tensor_min = torch.as_tensor(tensor_min, dtype=x.dtype, device=x.device)
    if not isinstance(tensor_max, torch.Tensor):
        tensor_max = torch.as_tensor(tensor_max, dtype=x.dtype, device=x.device)
    return torch.max(torch.min(x, tensor_max), tensor_min)
def masked_mean(
    values: torch.Tensor,
    mask: torch.Tensor,
    axis=None
) -> torch.Tensor:
    """Average ``values`` over the entries selected by ``mask``.

    Computed as ``sum(values * mask) / sum(mask)``, either over all elements
    (``axis=None``) or along ``axis``. No guard is applied, so an all-zero
    mask gives a 0/0 result.
    """
    weighted = values * mask
    if axis is None:
        return weighted.sum() / mask.sum()
    return weighted.sum(axis=axis) / mask.sum(axis=axis)
def masked_var(
    values: torch.Tensor,
    mask: torch.Tensor,
    unbiased=True
) -> torch.Tensor:
    """Compute the variance of ``values`` over the entries selected by ``mask``.

    Args:
        values: Tensor of values.
        mask: Tensor weighting each entry of ``values`` (typically 0/1).
        unbiased: If True (default), apply Bessel's correction ``n / (n - 1)``
            with ``n = mask.sum()``. Otherwise return the population (biased)
            variance. When ``n <= 1`` the correction is undefined and the
            uncorrected value is returned.

    Returns:
        A 0-dim tensor holding the masked variance.
    """
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values ** 2, mask)
    if unbiased:
        n = mask.sum()
        corrected = variance * n / (n - 1)
        # torch.where (instead of a Python `if`) avoids a host sync and
        # keeps the n <= 1 case from turning into inf/NaN.
        variance = torch.where(n > 1, corrected, variance)
    return variance
def rowwise_product(
    mat: torch.Tensor,
    mask: torch.Tensor
) -> torch.Tensor:
    """Multiply together, per row, the entries of ``mat`` that ``mask`` keeps.

    Entries where ``mask == 0`` are replaced by the multiplicative identity 1
    before the product, so a fully masked row yields 1. ``mat`` itself is
    not modified.

    Args:
        mat: tensor of shape (batch_size, sequence length)
        mask: tensor of shape (batch_size, sequence length)

    Returns:
        Tensor of shape (batch_size,).
    """
    keep = mask != 0
    return torch.where(keep, mat, torch.ones_like(mat)).prod(dim=1)
def entropy_from_logits(
    logits: torch.Tensor,
    mask: torch.Tensor
) -> torch.Tensor:
    """Mean per-token entropy of the softmax distributions given by ``logits``.

    For each token the entropy is computed as
    ``logsumexp(logits) - sum(softmax(logits) * logits)``, which equals
    ``-sum(p * log p)`` without materializing ``log p``.

    Args:
        logits: tensor of shape (batch_size, sequence length, vocab)
        mask: tensor of shape (batch_size, sequence length)

    Returns:
        A 0-dim tensor: the ``masked_mean`` (over all positions) of the
        per-token entropies.
    """
    probs = torch.nn.functional.softmax(logits, dim=-1)
    expected_logit = (probs * logits).sum(dim=-1)
    log_partition = logits.logsumexp(dim=-1)
    return masked_mean(log_partition - expected_logit, mask)
def flatten_dict(nested, sep="/") -> Dict:
    """Collapse a nested mapping into a single-level dict.

    Keys of nested mappings are joined to their parents' keys with ``sep``;
    for example, ``{"a": {"b": 1}}`` becomes ``{"a/b": 1}``. Non-mapping
    values are leaves. Output order follows a depth-first walk of
    ``nested``.

    Raises:
        ValueError: if any key already contains ``sep``.
    """
    flat = {}

    def _walk(node, path):
        for key, value in node.items():
            if sep in key:
                raise ValueError(
                    f"separator '{sep}' not allowed to be in key '{key}'"
                )
            full_key = path + key
            if isinstance(value, Mapping):
                _walk(value, full_key + sep)
            else:
                flat[full_key] = value

    _walk(nested, "")
    return flat
def all_gather_if_needed(
    values: torch.Tensor,
    rank: int,
    world_size: int
) -> torch.Tensor:
    """Gather ``values`` from every process and combine them.

    Args:
        values: This process's tensor. Presumably every rank passes a tensor
            of the same shape and dtype on a device the process group's
            backend supports (required by ``dist.all_gather``).
        rank: This process's rank. Kept for API compatibility; the receive
            buffers are allocated on ``values.device``.
        world_size: Number of processes in the group.

    Returns:
        ``values`` unchanged when ``world_size == 1``. Otherwise the gathered
        tensors concatenated along dim 0, or stacked into shape
        ``(world_size,)`` when ``values`` is 0-dim.
    """
    if world_size == 1:
        return values
    # Allocate receive buffers directly on values' device rather than on
    # torch.device('cuda', rank): that assumed global rank == local GPU
    # index (false on multi-node runs), broke CPU/gloo usage, and
    # allocated each buffer twice (empty_like followed by .to).
    all_values = [torch.empty_like(values) for _ in range(world_size)]
    dist.all_gather(all_values, values)
    cat_function = torch.cat if values.dim() > 0 else torch.stack
    return cat_function(all_values, dim=0)
def formatted_dict(d: Dict) -> Dict:
    """Format a dictionary for printing.

    Float values (including float subclasses such as ``numpy.float64``) are
    rendered as strings with 5 significant digits; all other values,
    including ints and bools, are passed through unchanged.

    Args:
        d: Dictionary to format. It is not modified.

    Returns:
        A new dict with the same keys.
    """
    return {k: (f"{v:.5g}" if isinstance(v, float) else v) for k, v in d.items()}
def disable_dropout(model: torch.nn.Module) -> None:
    """Set the drop probability to 0 for dropout layers in ``model``, in place.

    Only modules that are instances of ``torch.nn.Dropout`` are changed;
    every other module is left untouched.
    """
    dropout_layers = (
        layer for layer in model.modules()
        if isinstance(layer, torch.nn.Dropout)
    )
    for layer in dropout_layers:
        layer.p = 0
def delete_dict(d: Dict) -> None:
    """Empty ``d`` in place by deleting its entries one key at a time.

    The keys are snapshotted first so the dict is not mutated while being
    iterated.
    """
    keys = tuple(d)
    for key in keys:
        del d[key]
def print_gpu_memory(rank: int = None, message: str = '') -> None:
    """Print the allocated CUDA memory, in MB, for each visible GPU that has any.

    GPUs with zero allocated bytes are skipped. Each reported GPU is preceded
    by a line of 40 '*', and one more such line is printed after the loop.
    Nothing is printed when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    banner = '*' * 40
    for index in range(torch.cuda.device_count()):
        used = torch.cuda.memory_allocated(torch.device(f'cuda:{index}'))
        if used == 0:
            continue
        print(banner)
        print(f'[{message} rank {rank} ] GPU {index}: {used / 1024**2:.2f} MB')
    print(banner)
def get_block_class_from_model(
    model: torch.nn.Module,
    block_class_name: str
) -> torch.nn.Module:
    """Return the class of the first module in ``model.modules()`` with a matching class name.

    Raises:
        ValueError: if no module's class is named ``block_class_name``.
    """
    matching = (
        type(sub) for sub in model.modules()
        if type(sub).__name__ == block_class_name
    )
    block_class = next(matching, None)
    if block_class is None:
        raise ValueError(
            f"Could not find block class {block_class_name} in model {model}"
        )
    return block_class
def get_block_class_from_model_class_and_block_name(
    model_class: Type,
    block_class_name: str
) -> Type:
    """Look up a class by name in the module that defines ``model_class``.

    The defining module is obtained through the import system
    (``importlib.import_module(model_class.__module__)``), which returns the
    already-imported module when there is one. The returned class is
    therefore the same object that existing code and instances use, so
    identity and ``isinstance`` checks against live instances work.

    Args:
        model_class: A class defined in a Python source module.
        block_class_name: Name of the class to look up in that module.

    Returns:
        The class named ``block_class_name``.

    Raises:
        AttributeError: if the module has no attribute ``block_class_name``.
    """
    filepath = inspect.getfile(model_class)
    module_name = model_class.__module__
    print(
        f'Searching in file {filepath}, ' +
        f'module {module_name} for class {block_class_name}'
    )
    # Importing by __module__ instead of re-executing the source file:
    # exec_module on a fresh spec would build a second copy of every class
    # (never identical to the classes live instances were built from). The
    # old path-derived name also broke when 'transformers' appeared earlier
    # in the path, e.g. in a '.../transformers_env/...' virtualenv.
    module = importlib.import_module(module_name)
    class_ = getattr(module, block_class_name)
    print(f"Found class {class_} in module {module_name}")
    return class_
def init_distributed(
    rank: int,
    world_size: int,
    master_addr: str = 'localhost',
    port: int = 12355,
    backend: str = 'nccl'
) -> None:
    """Configure this process's environment and join the default process group.

    Sets MASTER_ADDR / MASTER_PORT (read by ``dist.init_process_group`` when
    no ``init_method`` is given), sets CUDA_VISIBLE_DEVICES to ``rank``,
    selects CUDA device ``rank``, then initializes the process group with
    the given backend, rank and world size.

    NOTE(review): if CUDA has not been initialized in this process yet,
    CUDA_VISIBLE_DEVICES=str(rank) leaves only one GPU visible (as cuda:0),
    so torch.cuda.set_device(rank) would presumably fail for rank > 0. If
    CUDA was already initialized, the env var is a no-op for this process.
    Confirm which behavior is intended.
    """
    print(rank, 'initializing distributed')
    os.environ.update({
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(port),
        "CUDA_VISIBLE_DEVICES": str(rank),
    })
    torch.cuda.set_device(rank)
    dist.init_process_group(backend, rank=rank, world_size=world_size)