Commit

[Enhance] Support dist_train without slurm (#791)
* add docs for distributed training

* add pytorch docs link

* update dist train commands

* update version
wangruohui authored Mar 17, 2022
1 parent e844543 commit 54ce3d7
Showing 4 changed files with 66 additions and 4 deletions.
19 changes: 19 additions & 0 deletions docs/en/quick_run.md
@@ -82,6 +82,25 @@ Difference between `resume-from` and `load-from`:
`resume-from` loads both the model weights and the optimizer state, and the iteration count is also inherited from the specified checkpoint. It is usually used to resume a training process that was interrupted accidentally.
`load-from` only loads the model weights, and training starts from iteration 0. It is usually used for fine-tuning.
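
For example (a minimal sketch assuming the usual MMCV-style `tools/train.py` interface; the config and checkpoint paths are placeholders):

```shell
# Resume an interrupted run: weights, optimizer state, and the current
# iteration are all restored from the checkpoint.
python tools/train.py configs/example_config.py --resume-from work_dirs/example/latest.pth

# Fine-tune from pretrained weights only: set load_from in the config,
# e.g. load_from = 'work_dirs/example/latest.pth'; training restarts at iteration 0.
python tools/train.py configs/example_config.py
```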

#### Train with multiple nodes

To launch distributed training on multiple machines that can reach each other via IP, run the following commands:

On the first machine:

```shell
NNODES=2 NODE_RANK=0 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR tools/dist_train.sh $CONFIG $GPUS
```

On the second machine:

```shell
NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR tools/dist_train.sh $CONFIG $GPUS
```
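
For instance, with two 8-GPU machines where the first machine is reachable at 10.1.1.10 (the IP address, port, config path, and GPU count below are placeholders):

```shell
# First machine: rank 0 also hosts the rendezvous master.
NNODES=2 NODE_RANK=0 PORT=29500 MASTER_ADDR=10.1.1.10 tools/dist_train.sh configs/example_config.py 8

# Second machine: same master address and port, but NODE_RANK=1.
NNODES=2 NODE_RANK=1 PORT=29500 MASTER_ADDR=10.1.1.10 tools/dist_train.sh configs/example_config.py 8
```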

To speed up network communication, high-speed network hardware such as InfiniBand is recommended.
Please refer to the [PyTorch docs](https://pytorch.org/docs/1.11/distributed.html#launch-utility) for more information.

### Train with Slurm

If you run MMEditing on a cluster managed with [slurm](https://slurm.schedmd.com/), you can use the script `slurm_train.sh`. (This script also supports single machine training.)
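
A typical call looks roughly like the sketch below (an assumption for illustration: the positional arguments and the `GPUS` variable follow the common OpenMMLab `slurm_train.sh` convention, and the partition, job name, and paths are placeholders):

```shell
# Request 8 GPUs on the Slurm partition "example_partition" and launch a job
# named "example_job" with the given config, writing outputs to work_dirs/example.
GPUS=8 ./tools/slurm_train.sh example_partition example_job configs/example_config.py work_dirs/example
```
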
20 changes: 20 additions & 0 deletions docs/zh_cn/quick_run.md
@@ -82,6 +82,26 @@ evaluation = dict(interval=1e4, by_epoch=False)  # evaluate once every 10,000 iterations
`resume-from` loads both the model weights and the optimizer state, and the iteration count is also inherited from the specified checkpoint. It is usually used to resume a training process that was interrupted accidentally.
`load-from` only loads the model weights, and training starts from iteration 0. It is usually used for fine-tuning.

#### Train with multiple nodes

If you have multiple compute nodes that can reach each other via IP, you can launch distributed training with the following commands:

On the first node:

```shell
NNODES=2 NODE_RANK=0 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR tools/dist_train.sh $CONFIG $GPUS
```

On the second node:

```shell
NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR tools/dist_train.sh $CONFIG $GPUS
```

To speed up network communication, high-speed network hardware such as InfiniBand is recommended.
Please refer to the [PyTorch docs](https://pytorch.org/docs/1.11/distributed.html#launch-utility) for more information.


### Train with Slurm

If you run MMEditing on a cluster managed with [slurm](https://slurm.schedmd.com/), you can use the script `slurm_train.sh`. (This script also supports single machine training.)
16 changes: 14 additions & 2 deletions tools/dist_test.sh
@@ -3,8 +3,20 @@
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
-    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
+python -m torch.distributed.launch \
+    --nnodes=$NNODES \
+    --node_rank=$NODE_RANK \
+    --master_addr=$MASTER_ADDR \
+    --nproc_per_node=$GPUS \
+    --master_port=$PORT \
+    $(dirname "$0")/test.py \
+    $CONFIG \
+    $CHECKPOINT \
+    --launcher pytorch \
+    ${@:4}
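
With these changes `dist_test.sh` can be launched across machines in the same way as `dist_train.sh`. A hedged usage sketch (the IP address, port, config, checkpoint, and GPU count are placeholders):

```shell
# Single machine: the new variables simply fall back to their defaults.
./tools/dist_test.sh configs/example_config.py work_dirs/example/latest.pth 8

# Two machines: run this on the rank-0 node, then repeat it with NODE_RANK=1
# on the other node.
NNODES=2 NODE_RANK=0 PORT=29500 MASTER_ADDR=10.1.1.10 ./tools/dist_test.sh configs/example_config.py work_dirs/example/latest.pth 8
```
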
15 changes: 13 additions & 2 deletions tools/dist_train.sh
@@ -2,8 +2,19 @@

CONFIG=$1
GPUS=$2
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
-    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
+python -m torch.distributed.launch \
+    --nnodes=$NNODES \
+    --node_rank=$NODE_RANK \
+    --master_addr=$MASTER_ADDR \
+    --nproc_per_node=$GPUS \
+    --master_port=$PORT \
+    $(dirname "$0")/train.py \
+    $CONFIG \
+    --seed 0 \
+    --launcher pytorch ${@:3}
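
Anything after the GPU count is still forwarded to `train.py` through `${@:3}`; for example (paths are placeholders, and `--resume-from` is assumed to be a `train.py` flag, as described in the docs above):

```shell
# Single-machine run on 4 GPUs; the trailing argument is passed through to
# train.py so the run resumes from an existing checkpoint.
./tools/dist_train.sh configs/example_config.py 4 --resume-from work_dirs/example/latest.pth
```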
