-
Notifications
You must be signed in to change notification settings - Fork 9
/
dist_helper.py
41 lines (35 loc) · 1.31 KB
/
dist_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import subprocess
import torch
import torch.distributed as dist
def setup_distributed(backend="nccl", port=None):
    """Initialize the default ``torch.distributed`` process group.

    Supports two launch modes, selected by the environment:

    * SLURM (``SLURM_JOB_ID`` is set): rank and world size are read from
      ``SLURM_PROCID`` / ``SLURM_NTASKS`` and the rendezvous variables
      ``MASTER_PORT``, ``MASTER_ADDR``, ``WORLD_SIZE``, ``LOCAL_RANK`` and
      ``RANK`` are exported into ``os.environ``.
    * Any other launcher: ``RANK`` and ``WORLD_SIZE`` must already be set in
      the environment; nothing is written to ``os.environ``.

    When CUDA devices are visible, the process is bound to device
    ``rank % device_count``. On hosts without CUDA devices the device
    selection is skipped so that CPU backends (e.g. ``"gloo"``) can be used.

    Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    Originally licensed MIT, Copyright (c) 2020 Wei Li

    Args:
        backend: Backend name forwarded to ``dist.init_process_group``.
        port: Master port to export under SLURM. When ``None``, an existing
            ``MASTER_PORT`` is kept, otherwise ``"10685"`` is used. Ignored
            outside of SLURM.

    Returns:
        Tuple ``(rank, world_size)`` of this process.

    Raises:
        KeyError: A required environment variable is missing.
        subprocess.CalledProcessError: ``scontrol`` failed while resolving
            the master address under SLURM.
        RuntimeError: ``scontrol`` returned no hostname.
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]

        # An explicit port wins; otherwise keep a pre-set MASTER_PORT and only
        # fall back to the default when nothing is configured.
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "10685"

        # Only ask SLURM for the head node when the caller did not pin one.
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = _first_slurm_hostname(node_list)

        if num_gpus > 0:
            local_rank = rank % num_gpus
        else:
            # No GPUs to map onto: use SLURM's node-local task id if present.
            local_rank = int(os.environ.get("SLURM_LOCALID", 0))

        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["RANK"] = str(rank)
    else:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    # Guard against a zero device count: the modulo would raise
    # ZeroDivisionError and there is no CUDA device to select anyway.
    if num_gpus > 0:
        torch.cuda.set_device(rank % num_gpus)

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size


def _first_slurm_hostname(node_list):
    """Return the first hostname that ``scontrol`` expands from *node_list*."""
    # Pass an argument list (no shell) so bracketed ranges such as
    # "node[01-04]" are never glob-expanded or otherwise shell-interpreted.
    completed = subprocess.run(
        ["scontrol", "show", "hostname", node_list],
        capture_output=True,
        text=True,
        check=True,
    )
    hostnames = completed.stdout.split()
    if not hostnames:
        raise RuntimeError(
            f"scontrol returned no hostnames for node list {node_list!r}"
        )
    return hostnames[0]