Custom optimize for handling complex trainings (#89)
* Implement `isimplemented`
* Replace `hasattr` with `isimplemented` (see the sketch below)
* Add `torch.set_grad_enabled` toggles
* Implement custom optimize
* Update model
* Add GAN training tests
* Make style
* Make lint
* Do not expect `grad_norm` returned
* Add training examples
* Bump up to v0.0.21
Showing 8 changed files with 974 additions and 113 deletions.
@@ -0,0 +1,102 @@
""" | ||
This example shows training of a simple Conv model with MNIST dataset using Auto Training mode of 👟. | ||
""" | ||

import os
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from trainer import TrainerConfig, TrainerModel, Trainer, TrainerArgs


@dataclass
class MnistModelConfig(TrainerConfig):
    optimizer: str = "Adam"
    lr: float = 0.001
    epochs: int = 1
    print_step: int = 1
    save_step: int = 5
    plot_step: int = 5
    dashboard_logger: str = "tensorboard"


class MnistModel(TrainerModel):
    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, _, _, _ = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        x = F.log_softmax(x, dim=1)
        return x

    def train_step(self, batch, criterion):
        x, y = batch
        logits = self(x)
        loss = criterion(logits, y)
        return {"model_outputs": logits}, {"loss": loss}

    def eval_step(self, batch, criterion):
        x, y = batch
        logits = self(x)
        loss = criterion(logits, y)
        return {"model_outputs": logits}, {"loss": loss}

    @staticmethod
    def get_criterion():
        return torch.nn.NLLLoss()

    def get_data_loader(
        self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
    ):  # pylint: disable=unused-argument
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
        dataset.data = dataset.data[:256]
        dataset.targets = dataset.targets[:256]
        dataloader = DataLoader(dataset, batch_size=config.batch_size)
        return dataloader


def main():
    """Run `MNIST` model training from scratch or from a previous checkpoint."""
    # init args and config
    train_args = TrainerArgs()
    config = MnistModelConfig()

    # init the model from config
    model = MnistModel()

    # init the trainer and 🚀
    trainer = Trainer(
        train_args,
        config,
        config.output_path,
        model=model,
        train_samples=model.get_data_loader(config, None, False, None, None, None),
        eval_samples=model.get_data_loader(config, None, True, None, None, None),
        parse_command_line_args=True,
    )
    trainer.fit()


if __name__ == "__main__":
    main()
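Another bullet in the commit adds `torch.set_grad_enabled` toggles. For an example like the one above, the point is that eval steps should never build an autograd graph. A minimal sketch of the pattern, assuming a wrapper around the model's step methods; the wrapper name is illustrative, not the trainer's actual loop.

import torch


def run_step(model, batch, criterion, is_training):
    # Build a graph only for training; eval runs grad-free even if the
    # model's eval_step lacks its own @torch.no_grad() decorator.
    with torch.set_grad_enabled(is_training):
        if is_training:
            return model.train_step(batch, criterion)
        return model.eval_step(batch, criterion)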
@@ -0,0 +1,176 @@
""" | ||
This example shows training of a simple GAN model with MNIST dataset using Gradient Accumulation and Advanced | ||
Optimization where you call optimizer steps manually. | ||
""" | ||

import os
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from trainer import Trainer, TrainerConfig, TrainerModel
from trainer.trainer import TrainerArgs

is_cuda = torch.cuda.is_available()


# pylint: skip-file


class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


@dataclass
class GANModelConfig(TrainerConfig):
    epochs: int = 1
    print_step: int = 2
    training_seed: int = 666


class GANModel(TrainerModel):
    def __init__(self):
        super().__init__()
        data_shape = (1, 28, 28)
        self.generator = Generator(latent_dim=100, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

    def forward(self, x):
        ...

    def optimize(self, batch, trainer):
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], 100)
        z = z.type_as(imgs)

        # train discriminator
        imgs_gen = self.generator(z)
        logits = self.discriminator(imgs_gen.detach())
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)
        loss_fake = trainer.criterion(logits, fake)

        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        logits = self.discriminator(imgs)
        loss_real = trainer.criterion(logits, valid)
        loss_disc = (loss_real + loss_fake) / 2

        # step discriminator
        _, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])

        if trainer.total_steps_done % trainer.grad_accum_steps == 0:
            trainer.optimizer[0].step()
            trainer.optimizer[0].zero_grad()

        # train generator
        imgs_gen = self.generator(z)

        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        logits = self.discriminator(imgs_gen)
        loss_gen = trainer.criterion(logits, valid)

        # step generator
        _, _ = self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
        if trainer.total_steps_done % trainer.grad_accum_steps == 0:
            trainer.optimizer[1].step()
            trainer.optimizer[1].zero_grad()
        return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

    @torch.no_grad()
    def eval_step(self, batch, criterion):
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], 100)
        z = z.type_as(imgs)

        imgs_gen = self.generator(z)
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        logits = self.discriminator(imgs_gen)
        loss_gen = criterion(logits, valid)
        return {"model_outputs": logits}, {"loss_gen": loss_gen}

    def get_optimizer(self):
        discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
        return [discriminator_optimizer, generator_optimizer]

    def get_criterion(self):
        return nn.BCELoss()

    def get_data_loader(
        self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
    ):  # pylint: disable=unused-argument
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
        dataset.data = dataset.data[:64]
        dataset.targets = dataset.targets[:64]
        dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
        return dataloader


if __name__ == "__main__":
    config = GANModelConfig()
    config.batch_size = 64
    config.grad_clip = None

    model = GANModel()
    trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
    trainer.config.epochs = 10
    trainer.fit()
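The GAN file above is the commit's showcase for the custom `optimize` hook: the model owns backward, gradient accumulation, and both optimizer steps. A hedged sketch of how a trainer might dispatch between that hook and the automatic path; the control flow below is assumed from the commit description, not the library's verbatim loop.

def train_one_step(trainer, batch):
    model = trainer.model
    if isimplemented(model, "optimize"):  # helper sketched earlier
        # The model takes over the whole step, as GANModel.optimize does.
        return model.optimize(batch, trainer)
    # Automatic path: single loss, single optimizer.
    outputs, losses = model.train_step(batch, trainer.criterion)
    losses["loss"].backward()
    trainer.optimizer.step()
    trainer.optimizer.zero_grad()
    return outputs, losses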