Custom `optimize` for handling complex training setups (#89)
* Implement `isimplemented` (a sketch of the idea follows this list)

* Replace `hasattr` with `isimplemented`

* Add `torch.set_grad_enabled` toggles

* Implement custom optimize

* Update model

* Add GAN training tests

* Make style

* Make lint

* Do not expect `grad_norm` returned

* Add training examples

* Bump up to v0.0.21
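
The `isimplemented` helper itself is not shown in this diff. As rough orientation only, here is a minimal sketch of the idea the commit message describes — telling a real override apart from the base-class stub, which a bare `hasattr` check cannot do because the stub also exists as an attribute. The names and exact logic below are assumptions, not the repository's implementation:

```python
class TrainerModelStub:
    """Stand-in for an abstract base whose optional hooks only raise NotImplementedError."""

    def optimize(self, batch, trainer):
        raise NotImplementedError


def isimplemented(obj, method_name: str) -> bool:
    """Return True when obj's class overrides `method_name` instead of
    inheriting the NotImplementedError stub from the base class."""
    resolved = getattr(type(obj), method_name, None)
    if not callable(resolved):
        return False
    # The last class in the MRO that defines the method is the abstract base;
    # if the resolved method is a different function object, a subclass overrode it.
    for klass in reversed(type(obj).__mro__):
        if method_name in vars(klass):
            return resolved is not vars(klass)[method_name]
    return True


class MyModel(TrainerModelStub):
    def optimize(self, batch, trainer):
        return {"model_outputs": None}, {"loss": 0.0}


assert isimplemented(MyModel(), "optimize")                # overridden -> True
assert not isimplemented(TrainerModelStub(), "optimize")   # only the stub exists -> False
```

Under that assumption, the trainer can gate optional code paths with `if isimplemented(model, "optimize"): ...` rather than `hasattr`, which would also return True for the unimplemented stub.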
erogol authored Jan 9, 2023
1 parent 91d83a1 commit b68ac29
Showing 8 changed files with 974 additions and 113 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -24,8 +24,14 @@ Prefer installing from Github as it is more stable.
## Implementing a model
Subclass [```TrainerModel()```](trainer/model.py) and overload the functions you need.
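
As a quick orientation, here is a minimal sketch condensed from the MNIST example added in this commit (your model will typically also override `get_data_loader`; see the full examples below):

```python
import torch
from torch import nn
from torch.nn import functional as F

from trainer import TrainerModel


class MyModel(TrainerModel):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten (b, 1, 28, 28) images and classify
        return F.log_softmax(self.net(x.view(x.size(0), -1)), dim=1)

    def train_step(self, batch, criterion):
        x, y = batch
        logits = self(x)
        return {"model_outputs": logits}, {"loss": criterion(logits, y)}

    def eval_step(self, batch, criterion):
        return self.train_step(batch, criterion)

    @staticmethod
    def get_criterion():
        return torch.nn.NLLLoss()
```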

## Training a model
See the test script [here](tests/test_train_mnist.py) that trains a basic MNIST model.

## Training a model with auto optimization
See the [MNIST example](examples/train_mnist.py).


## Training a model with advanced optimization
See the [GAN training example](examples/train_simple_gan.py) with Gradient Accumulation.


## Training with Batch Size Finder
See the test script [here](tests/test_train_batch_size_finder.py) for training with the batch size finder.
@@ -95,6 +101,6 @@ trainer.fit()
To add a new logger, you must subclass [BaseDashboardLogger](trainer/logging/base_dash_logger.py) and overload its functions.

## Anonymized Telemetry
We constantly seek to improve 🐸 for the community. To understand the community's needs better and address them accordingly, we collect stripped-down anonymized usage stats when you run the trainer.

Of course, if you don't want this, you can opt out by setting the environment variable `TRAINER_TELEMETRY=0`.
102 changes: 102 additions & 0 deletions examples/train_mnist.py
@@ -0,0 +1,102 @@
"""
This example shows training of a simple fully-connected model on the MNIST dataset using the auto-optimization mode of 👟.
"""

import os
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from trainer import TrainerConfig, TrainerModel, Trainer, TrainerArgs


@dataclass
class MnistModelConfig(TrainerConfig):
optimizer: str = "Adam"
lr: float = 0.001
epochs: int = 1
print_step: int = 1
save_step: int = 5
plot_step: int = 5
dashboard_logger: str = "tensorboard"


class MnistModel(TrainerModel):
def __init__(self):
super().__init__()

# mnist images are (1, 28, 28) (channels, height, width)
self.layer_1 = nn.Linear(28 * 28, 128)
self.layer_2 = nn.Linear(128, 256)
self.layer_3 = nn.Linear(256, 10)

def forward(self, x):
batch_size, _, _, _ = x.size()

# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
x = self.layer_1(x)
x = F.relu(x)
x = self.layer_2(x)
x = F.relu(x)
x = self.layer_3(x)

x = F.log_softmax(x, dim=1)
return x

def train_step(self, batch, criterion):
x, y = batch
logits = self(x)
loss = criterion(logits, y)
return {"model_outputs": logits}, {"loss": loss}

def eval_step(self, batch, criterion):
x, y = batch
logits = self(x)
loss = criterion(logits, y)
return {"model_outputs": logits}, {"loss": loss}

@staticmethod
def get_criterion():
return torch.nn.NLLLoss()

def get_data_loader(
self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
): # pylint: disable=unused-argument
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
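        # keep only a small subset of MNIST so the example trains in seconds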
dataset.data = dataset.data[:256]
dataset.targets = dataset.targets[:256]
dataloader = DataLoader(dataset, batch_size=config.batch_size)
return dataloader


def main():
"""Run `MNIST` model training from scratch or from previous checkpoint."""
# init args and config
train_args = TrainerArgs()
config = MnistModelConfig()

# init the model from config
model = MnistModel()

# init the trainer and 🚀
trainer = Trainer(
train_args,
config,
config.output_path,
model=model,
train_samples=model.get_data_loader(config, None, False, None, None, None),
eval_samples=model.get_data_loader(config, None, True, None, None, None),
parse_command_line_args=True,
)
trainer.fit()


if __name__ == "__main__":
main()
176 changes: 176 additions & 0 deletions examples/train_simple_gan.py
@@ -0,0 +1,176 @@
"""
This example shows training of a simple GAN model on the MNIST dataset using gradient accumulation and advanced
optimization, where the model calls the optimizer steps manually.
"""

import os
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from trainer import Trainer, TrainerConfig, TrainerModel
from trainer.trainer import TrainerArgs

is_cuda = torch.cuda.is_available()


# pylint: skip-file


class Generator(nn.Module):
def __init__(self, latent_dim, img_shape):
super().__init__()
self.img_shape = img_shape

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh(),
)

def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_shape)
return img


class Discriminator(nn.Module):
def __init__(self, img_shape):
super().__init__()

self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)

def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)

return validity


@dataclass
class GANModelConfig(TrainerConfig):
epochs: int = 1
print_step: int = 2
training_seed: int = 666


class GANModel(TrainerModel):
def __init__(self):
super().__init__()
data_shape = (1, 28, 28)
self.generator = Generator(latent_dim=100, img_shape=data_shape)
self.discriminator = Discriminator(img_shape=data_shape)

def forward(self, x):
...

def optimize(self, batch, trainer):
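        # Advanced-optimization mode: the model runs its own backward passes via
        # `scaled_backward` and steps/zeros the trainer's optimizers itself instead
        # of returning a single loss. With gradient accumulation, each optimizer is
        # only stepped and zeroed every `grad_accum_steps` iterations.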
imgs, _ = batch

# sample noise
z = torch.randn(imgs.shape[0], 100)
z = z.type_as(imgs)

# train discriminator
imgs_gen = self.generator(z)
logits = self.discriminator(imgs_gen.detach())
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(imgs)
loss_fake = trainer.criterion(logits, fake)

valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)
logits = self.discriminator(imgs)
loss_real = trainer.criterion(logits, valid)
loss_disc = (loss_real + loss_fake) / 2

        # step discriminator
_, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])

if trainer.total_steps_done % trainer.grad_accum_steps == 0:
trainer.optimizer[0].step()
trainer.optimizer[0].zero_grad()

# train generator
imgs_gen = self.generator(z)

valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)

logits = self.discriminator(imgs_gen)
loss_gen = trainer.criterion(logits, valid)

# step generator
_, _ = self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
if trainer.total_steps_done % trainer.grad_accum_steps == 0:
trainer.optimizer[1].step()
trainer.optimizer[1].zero_grad()
return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

@torch.no_grad()
def eval_step(self, batch, criterion):
imgs, _ = batch

# sample noise
z = torch.randn(imgs.shape[0], 100)
z = z.type_as(imgs)

imgs_gen = self.generator(z)
valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)

logits = self.discriminator(imgs_gen)
        loss_gen = criterion(logits, valid)
return {"model_outputs": logits}, {"loss_gen": loss_gen}

def get_optimizer(self):
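        # The returned order defines the indices used in optimize():
        # optimizer[0] steps the discriminator, optimizer[1] the generator.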
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
return [discriminator_optimizer, generator_optimizer]

def get_criterion(self):
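        # plain BCE, matching the Sigmoid output of the Discriminator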
return nn.BCELoss()

def get_data_loader(
self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
): # pylint: disable=unused-argument
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
dataset.data = dataset.data[:64]
dataset.targets = dataset.targets[:64]
dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
return dataloader


if __name__ == "__main__":

config = GANModelConfig()
config.batch_size = 64
config.grad_clip = None

model = GANModel()
trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
trainer.config.epochs = 10
trainer.fit()