Skip to content

Commit

Permalink
Added amp scale to vision classification template (#342)
Browse files Browse the repository at this point in the history
* Added amp scale to vision classification template

* removed tralining spaces
  • Loading branch information
vfdev-5 authored Nov 27, 2023
1 parent c9dd782 commit 1dd2ba1
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/templates/template-vision-classification/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine, Events
from torch.cuda.amp import autocast
from torch.cuda.amp import autocast, GradScaler
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler, Sampler
Expand All @@ -17,6 +17,8 @@ def setup_trainer(
device: Union[str, torch.device],
train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
scaler = GradScaler(enabled=config.use_amp)

def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
model.train()

Expand All @@ -27,9 +29,10 @@ def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
outputs = model(samples)
loss = loss_fn(outputs, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

train_loss = loss.item()
engine.state.metrics = {
Expand Down

0 comments on commit 1dd2ba1

Please sign in to comment.