Commit

Added support for multi-GPU training
alexteua committed Jan 30, 2023
1 parent ab1e5f1 commit 76b0852
Showing 9 changed files with 110 additions and 57 deletions.
2 changes: 1 addition & 1 deletion dp/configs/autoreg_config.yaml
@@ -46,4 +46,4 @@ training:
n_generate_samples: 10 # Number of result samples to show on tensorboard.
store_phoneme_dict_in_model: true # Whether to store the raw phoneme dict in the model.
# It will be loaded by the phonemizer object.

ddp_backend: 'nccl' # Backend used by Torch DDP
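
The new ddp_backend key feeds torch.distributed.init_process_group (see the dp/train.py hunk further down). As a minimal sketch of how such a backend choice is typically applied — the init_ddp helper name and the 'gloo' fallback for CPU-only hosts are assumptions, not part of this commit:

    import os
    import torch
    from torch.distributed import init_process_group

    def init_ddp(rank: int, world_size: int, backend: str = 'nccl') -> None:
        # Hypothetical helper: 'nccl' requires CUDA, so fall back to 'gloo'
        # when no GPU is visible.
        if not torch.cuda.is_available():
            backend = 'gloo'
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')
        init_process_group(backend=backend, rank=rank, world_size=world_size)
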
9 changes: 5 additions & 4 deletions dp/configs/forward_config.yaml
@@ -37,13 +37,14 @@ training:
scheduler_plateau_patience: 10 # Number of text generations with no improvement to tolerate.
batch_size: 32 # Training batch size.
batch_size_val: 32 # Validation batch size.
epochs: 500 # Number of epochs to train.
generate_steps: 10000 # Interval of training steps to generate sample outputs. Also, at this step the phoneme and word
epochs: 2 # Number of epochs to train.
generate_steps: 100 # Interval of training steps to generate sample outputs. Also, at this step the phoneme and word
# error rates are calculated for the scheduler.
validate_steps: 10000 # Interval of training steps to validate the model
validate_steps: 10 # Interval of training steps to validate the model
# (for the autoregressive model this is teacher-forced).
checkpoint_steps: 100000 # Interval of training steps to save the model.
checkpoint_steps: 10 # Interval of training steps to save the model.
n_generate_samples: 10 # Number of result samples to show on tensorboard.
store_phoneme_dict_in_model: true # Whether to store the raw phoneme dict in the model.
# It will be loaded by the phonemizer object.
ddp_backend: 'nccl' # Backend used by Torch DDP

2 changes: 1 addition & 1 deletion dp/model/utils.py
@@ -28,7 +28,7 @@ def __init__(self, d_model: int, dropout=0.1, max_len=5000) -> None:
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
self.register_parameter('pe', torch.nn.Parameter(pe, requires_grad=False))

def forward(self, x: torch.Tensor) -> torch.Tensor: # shape: [T, N]
x = x + self.scale * self.pe[:x.size(0), :]
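
The pe tensor changes from a registered buffer to a frozen nn.Parameter — presumably so it still follows .to(device) and appears in state_dict while staying out of DDP's per-step buffer broadcast. A self-contained sketch of the surrounding module (the class name and the dropout handling in forward are inferred, not shown in the diff):

    import math
    import torch

    class PositionalEncoding(torch.nn.Module):

        def __init__(self, d_model: int, dropout=0.1, max_len=5000) -> None:
            super().__init__()
            self.dropout = torch.nn.Dropout(p=dropout)
            self.scale = torch.nn.Parameter(torch.ones(1))
            # Precompute the sinusoidal table once; it is never trained.
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float()
                                 * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, d_model)
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0).transpose(0, 1)
            # Frozen parameter instead of a buffer: still moved by .to(device)
            # and saved in state_dict, but excluded from gradient updates.
            self.register_parameter('pe', torch.nn.Parameter(pe, requires_grad=False))

        def forward(self, x: torch.Tensor) -> torch.Tensor:  # shape: [T, N]
            x = x + self.scale * self.pe[:x.size(0), :]
            return self.dropout(x)
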
28 changes: 25 additions & 3 deletions dp/train.py
@@ -1,5 +1,9 @@
import os
from pathlib import Path

import torch
from torch.distributed import init_process_group

from dp.model.model import load_checkpoint, ModelType, \
create_model
from dp.preprocessing.text import Preprocessor
@@ -10,12 +14,16 @@
logger = get_logger(__name__)


def train(config_file: str,
def train(rank: int,
num_gpus: int,
config_file: str,
checkpoint_file: str = None) -> None:
"""
Runs training of a transformer model.
Args:
rank (int): Device id
num_gpus (int): Number of devices
config_file (str): Path to the config.yaml that stores all necessary parameters.
checkpoint_file (str, optional): Path to a model checkpoint to resume training for (e.g. latest_model.pt)
@@ -25,6 +33,12 @@ def train(config_file: str,
"""

config = read_config(config_file)

if num_gpus >= 1:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
init_process_group(backend=config['training']['ddp_backend'], rank=rank, world_size=num_gpus)

if checkpoint_file is not None:
logger.info(f'Restoring model from checkpoint: {checkpoint_file}')
model, checkpoint = load_checkpoint(checkpoint_file)
@@ -53,7 +67,15 @@ def train(config_file: str,
checkpoint_dir = Path(config['paths']['checkpoint_dir'])
logger.info(f'Checkpoints will be stored at {checkpoint_dir.absolute()}')
loss_type = 'cross_entropy' if model_type.is_autoregressive() else 'ctc'
trainer = Trainer(checkpoint_dir=checkpoint_dir, loss_type=loss_type)

if num_gpus > 0:
device = torch.device('cuda:{:d}'.format(rank))
else:
device = torch.device('cpu')

use_ddp = True if num_gpus > 1 else False

trainer = Trainer(checkpoint_dir=checkpoint_dir, device=device, rank=rank, use_ddp=use_ddp, loss_type=loss_type)
trainer.train(model=model,
checkpoint=checkpoint,
store_phoneme_dict_in_model=config['training']['store_phoneme_dict_in_model'])
store_phoneme_dict_in_model=config['training']['store_phoneme_dict_in_model'])
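
Each spawned worker now sets the rendezvous address and joins the process group before building the Trainer. The commit shows no teardown; a conventional companion pattern (hypothetical ddp_worker wrapper, not part of this diff) destroys the group when the worker exits:

    import os
    from typing import Callable

    from torch.distributed import init_process_group, destroy_process_group

    def ddp_worker(rank: int, world_size: int, train_fn: Callable[[int], None],
                   backend: str = 'nccl') -> None:
        # Hypothetical wrapper around a per-rank training function.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        init_process_group(backend=backend, rank=rank, world_size=world_size)
        try:
            train_fn(rank)             # run the training loop for this rank
        finally:
            destroy_process_group()    # release NCCL/Gloo resources on exit
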
17 changes: 11 additions & 6 deletions dp/training/dataset.py
@@ -5,9 +5,9 @@
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DistributedSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import Sampler

from dp.utils.io import unpickle_binary

@@ -34,7 +34,7 @@ def __len__(self):


# From https://github.com/fatchord/WaveRNN/blob/master/utils/dataset.py
class BinnedLengthSampler(Sampler):
class BinnedLengthSampler(DistributedSampler):

def __init__(self, phoneme_lens: List[int], batch_size: int, bin_size: int, seed=42) -> None:
_, self.idx = torch.sort(torch.tensor(phoneme_lens))
@@ -83,16 +83,21 @@ def collate_dataset(batch: List[dict]) -> Dict[str, torch.Tensor]:
def new_dataloader(dataset_file: Path,
batch_size=32,
drop_last=False,
use_binning=True) -> DataLoader:
use_binning=True,
use_ddp=False) -> DataLoader:
dataset = unpickle_binary(dataset_file)
phonemizer_dataset = PhonemizerDataset(dataset)
phoneme_lens = [len(p) for _, _, p in dataset]

if use_binning:
sampler = BinnedLengthSampler(phoneme_lens=phoneme_lens,
batch_size=batch_size,
bin_size=batch_size*3)
bin_size=batch_size * 3)
else:
sampler = None
if use_ddp:
sampler = DistributedSampler(phonemizer_dataset)
else:
sampler = None

return DataLoader(phonemizer_dataset,
collate_fn=collate_dataset,
@@ -101,4 +106,4 @@ def new_dataloader(dataset_file: Path,
num_workers=0,
shuffle=False,
drop_last=drop_last,
pin_memory=True)
pin_memory=True)
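
When use_ddp is enabled, new_dataloader can hand the DataLoader a DistributedSampler so that each rank draws a disjoint shard of the dataset. The usual usage pattern, sketched here with a toy dataset and assuming an already initialized process group, re-seeds the shuffle each epoch via set_epoch:

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(1000))    # stand-in for the pickled phonemizer dataset
    sampler = DistributedSampler(dataset)          # reads rank/world size from the default group
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, shuffle=False)

    for epoch in range(3):
        sampler.set_epoch(epoch)                   # different shuffle per epoch, consistent across ranks
        for batch in loader:
            pass                                   # training step goes here
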
95 changes: 56 additions & 39 deletions dp/training/trainer.py
@@ -5,6 +5,7 @@

import torch
import tqdm
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
@@ -20,21 +21,26 @@


class Trainer:

""" Performs model training. """

def __init__(self, checkpoint_dir: Path, loss_type='ctc') -> None:
def __init__(self, checkpoint_dir: Path, device: torch.device, rank: int, use_ddp: bool, loss_type='ctc') -> None:
"""
Initializes a Trainer object.
Args:
checkpoint_dir (Path): Directory to store the model checkpoints.
device (torch.device): Device used for training
rank (int): Rank of the current device
use_ddp (bool): Flag whether DDP is used for training
loss_type (str): Type of loss: 'ctc' for forward transformer models
and 'cross_entropy' for autoregressive models.
"""

self.checkpoint_dir = checkpoint_dir
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
self.use_ddp = use_ddp
self.rank = rank
self.device = device
self.writer = SummaryWriter(log_dir=str(self.checkpoint_dir / 'logs'))
self.loss_type = loss_type
if loss_type == 'ctc':
@@ -66,28 +72,33 @@ def train(self,
config = checkpoint['config']
data_dir = Path(config['paths']['data_dir'])

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
model = model.to(self.device)
model.train()

criterion = self.criterion.to(device)
if self.use_ddp:
model = DistributedDataParallel(model, device_ids=[self.rank])

criterion = self.criterion.to(self.device)

optimizer = Adam(model.parameters())
if 'optimizer' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
for g in optimizer.param_groups:
g['lr'] = config['training']['learning_rate']

train_loader = new_dataloader(dataset_file=data_dir / 'train_dataset.pkl',
train_loader = new_dataloader(dataset_file=data_dir / 'train_dataset.pkl', use_ddp=self.use_ddp,
drop_last=True, batch_size=config['training']['batch_size'])
val_loader = new_dataloader(dataset_file=data_dir / 'val_dataset.pkl',
drop_last=False, batch_size=config['training']['batch_size_val'])

if self.rank == 0:
val_loader = new_dataloader(dataset_file=data_dir / 'val_dataset.pkl', use_ddp=False,
drop_last=False, batch_size=config['training']['batch_size_val'])

val_batches = sorted([b for b in val_loader], key=lambda x: -x['text_len'][0])

if store_phoneme_dict_in_model:
phoneme_dict = unpickle_binary(data_dir / 'phoneme_dict.pkl')
checkpoint['phoneme_dict'] = phoneme_dict

val_batches = sorted([b for b in val_loader], key=lambda x: -x['text_len'][0])

scheduler = ReduceLROnPlateau(optimizer,
factor=config['training']['scheduler_plateau_factor'],
patience=config['training']['scheduler_plateau_patience'],
@@ -98,17 +109,21 @@ def train(self,
checkpoint['step'] = 0
start_epoch = checkpoint['step'] // len(train_loader)

if self.use_ddp and self.rank == 0:
train_loader.sampler.set_epoch(start_epoch)

for epoch in range(start_epoch + 1, config['training']['epochs'] + 1):
pbar = tqdm.tqdm(enumerate(train_loader, 1), total=len(train_loader))
for i, batch in pbar:
checkpoint['step'] += 1
step = checkpoint['step']
self._set_warmup_lr(optimizer=optimizer, step=step,
config=config)
batch = to_device(batch, device)
batch = to_device(batch, self.device)
avg_loss = sum(losses) / len(losses) if len(losses) > 0 else math.inf
pbar.set_description(desc=f'Epoch: {epoch} | Step {step} '
pbar.set_description(desc=f'Rank: {self.rank} | Epoch: {epoch} | Step {step} '
f'| Loss: {avg_loss:#.4}', refresh=True)

pred = model(batch)
loss = criterion(pred, batch)

@@ -125,34 +140,36 @@ def train(self,
self.writer.add_scalar('Params/learning_rate', [g['lr'] for g in optimizer.param_groups][0],
global_step=step)

if step % config['training']['validate_steps'] == 0:
val_loss = self._validate(model, val_batches)
self.writer.add_scalar('Loss/val', val_loss, global_step=step)

if step % config['training']['generate_steps'] == 0:
lang_samples = self._generate_samples(model=model,
preprocessor=checkpoint['preprocessor'],
val_batches=val_batches)
eval_result = evaluate_samples(lang_samples=lang_samples)
self._write_summaries(lang_samples=lang_samples,
eval_result=eval_result,
n_generate_samples=config['training']['n_generate_samples'],
step=step)
if eval_result['mean_per'] is not None and eval_result['mean_per'] < best_per:
if self.rank == 0:
if step % config['training']['validate_steps'] == 0:
val_loss = self._validate(model, val_batches)
self.writer.add_scalar('Loss/val', val_loss, global_step=step)

if step % config['training']['generate_steps'] == 0:
lang_samples = self._generate_samples(model=model,
preprocessor=checkpoint['preprocessor'],
val_batches=val_batches)
eval_result = evaluate_samples(lang_samples=lang_samples)
self._write_summaries(lang_samples=lang_samples,
eval_result=eval_result,
n_generate_samples=config['training']['n_generate_samples'],
step=step)
if eval_result['mean_per'] is not None and eval_result['mean_per'] < best_per:
self._save_model(model=model, optimizer=optimizer, checkpoint=checkpoint,
path=self.checkpoint_dir / f'best_model.pt')
self._save_model(model=model, optimizer=None, checkpoint=checkpoint,
path=self.checkpoint_dir / f'best_model_no_optim.pt')
scheduler.step(eval_result['mean_per'])

if step % config['training']['checkpoint_steps'] == 0:
step = step // 1000
self._save_model(model=model, optimizer=optimizer, checkpoint=checkpoint,
path=self.checkpoint_dir / f'best_model.pt')
self._save_model(model=model, optimizer=None, checkpoint=checkpoint,
path=self.checkpoint_dir / f'best_model_no_optim.pt')
scheduler.step(eval_result['mean_per'])

if step % config['training']['checkpoint_steps'] == 0:
step = step // 1000
self._save_model(model=model, optimizer=optimizer, checkpoint=checkpoint,
path=self.checkpoint_dir / f'model_step_{step}k.pt')
path=self.checkpoint_dir / f'model_step_{step}k.pt')

losses = []
self._save_model(model=model, optimizer=optimizer, checkpoint=checkpoint,
path=self.checkpoint_dir / 'latest_model.pt')
if self.rank == 0:
self._save_model(model=model, optimizer=optimizer, checkpoint=checkpoint,
path=self.checkpoint_dir / 'latest_model.pt')

def _validate(self, model: Model, val_batches: List[dict]) -> float:
device = next(model.parameters()).device
@@ -186,7 +203,7 @@ def _generate_samples(self,

for batch in val_batches:
batch = to_device(batch, device)
generated_batch, _ = model.generate(batch)
generated_batch, _ = (model.module if self.use_ddp else model).generate(batch)
for i in range(batch['text'].size(0)):
text_len = batch['text_len'][i]
text = batch['text'][i, :text_len]
@@ -241,7 +258,7 @@ def _save_model(self,
optimizer: torch.optim,
checkpoint: Dict[str, Any],
path: Path) -> None:
checkpoint['model'] = model.state_dict()
checkpoint['model'] = (model.module if self.use_ddp else model).state_dict()
if optimizer is not None:
checkpoint['optimizer'] = optimizer.state_dict()
else:
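
Several call sites now reach through model.module when DDP is active (generate() and state_dict() above); a small helper, not present in the commit, makes that unwrapping explicit and reusable:

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def unwrap(model: torch.nn.Module) -> torch.nn.Module:
        # DistributedDataParallel wraps the user module; custom methods such as
        # generate() and the checkpoint-friendly state_dict() live on .module.
        return model.module if isinstance(model, DistributedDataParallel) else model

    # e.g. checkpoint['model'] = unwrap(model).state_dict()
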
10 changes: 9 additions & 1 deletion run_training.py
@@ -1,3 +1,6 @@
import torch
import torch.multiprocessing as mp

from dp.preprocess import preprocess
from dp.train import train

@@ -17,4 +20,9 @@
val_data=val_data,
deduplicate_train_data=False)

train(config_file=config_file)
num_gpus = torch.cuda.device_count()

if num_gpus >= 1:
mp.spawn(train, nprocs=num_gpus, args=(num_gpus, config_file))
else:
train(rank=0, num_gpus=num_gpus, config_file=config_file)
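
torch.multiprocessing.spawn prepends the process index to args, which is why train() now takes rank as its first parameter while args only carries (num_gpus, config_file). A minimal sketch of the same convention with a hypothetical worker:

    import torch
    import torch.multiprocessing as mp

    def worker(rank: int, world_size: int, message: str) -> None:
        # rank is injected by mp.spawn; world_size and message come from args.
        print(f'rank {rank}/{world_size}: {message}')

    if __name__ == '__main__':
        n = max(torch.cuda.device_count(), 1)
        mp.spawn(worker, nprocs=n, args=(n, 'hello'))
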
2 changes: 1 addition & 1 deletion tests/test_autoreg_training.py
@@ -53,7 +53,7 @@ def test_autoregtraining_happy_path(self) -> None:
val_data=val_data,
deduplicate_train_data=False)

train(config_file=config_path)
train(rank=0, num_gpus=0, config_file=config_path)

predictor = Predictor.from_checkpoint(checkpoint_dir / 'latest_model.pt')

2 changes: 1 addition & 1 deletion tests/test_forward_training.py
@@ -53,7 +53,7 @@ def test_forward_training_happy_path(self) -> None:
val_data=val_data,
deduplicate_train_data=False)

train(config_file=config_path)
train(rank=0, num_gpus=0, config_file=config_path)

predictor = Predictor.from_checkpoint(checkpoint_dir / 'latest_model.pt')

