import torch
import yaml
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from models import FlowerNet
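# FlowerNet lives in models.py (not shown here); from its use below it is
# assumed to be a classification model accepting num_classes and pretrained
# keyword arguments.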
# Load training hyperparameters; safe_load avoids executing arbitrary YAML tags.
with open('configs/train.yaml', 'r') as file:
    configs = yaml.safe_load(file)
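# The YAML file itself is not shown; from the lookups below it is assumed to
# define at least these keys (values here are illustrative, not from the repo):
#   use-augment: true
#   batch-size: 32
#   num-workers: 4
#   device: cuda
#   num-classes: 102
#   load-pretrained: true
#   load-checkpoint: false
#   load-path: checkpoints/last.pth
#   best-path: checkpoints/best.pth
#   last-path: checkpoints/last.pth
#   learning-rate: 0.001
#   weight-decay: 0.0001
#   num-epochs: 100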
# Augmentation pipeline. RandomErasing operates on tensors only, so it must
# come after ToTensor(); ColorJitter and RandomRotation also accept tensor
# images on recent torchvision releases.
augment_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.RandomRotation(10),
    transforms.RandomErasing(scale=(0.05, 0.2), ratio=(0.5, 2.0)),
])
normal_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Validation always uses the plain pipeline; augmentation is opt-in for training.
train_transform = normal_transform
valid_transform = normal_transform
if configs['use-augment']:
    train_transform = augment_transform
# ImageFolder expects one subdirectory per class, e.g. datasets/train/<class>/.
train_dataset = datasets.ImageFolder('datasets/train', transform=train_transform)
valid_dataset = datasets.ImageFolder('datasets/valid', transform=valid_transform)
train_dataset_size = len(train_dataset)
valid_dataset_size = len(valid_dataset)
train_dataloader = data.DataLoader(train_dataset, batch_size=configs['batch-size'], shuffle=True, num_workers=configs['num-workers'])
# Shuffling is unnecessary for evaluation, so the validation loader keeps a fixed order.
valid_dataloader = data.DataLoader(valid_dataset, batch_size=configs['batch-size'], shuffle=False, num_workers=configs['num-workers'])
train_dataloader_size = len(train_dataloader)
valid_dataloader_size = len(valid_dataloader)
device = torch.device(configs['device'])
model = FlowerNet(num_classes=configs['num-classes'], pretrained=configs['load-pretrained'])
model = model.to(device)
if configs['load-checkpoint']:
    model.load_state_dict(torch.load(configs['load-path'], map_location=device, weights_only=True))
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=configs['learning-rate'], weight_decay=configs['weight-decay'])
max_accuracy = 0.0
print(f'\n---------- Training Started On: {str(device).upper()} ----------\n')
for epoch in range(configs['num-epochs']):
    model.train()
    training_loss = 0.0
    for index, (inputs, labels) in enumerate(train_dataloader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        training_loss += loss.item()
        print(f'\rBatch Loss: {loss.item():.3f} [{index}/{train_dataloader_size}]', end='')
    # Evaluate on the held-out set; no_grad disables autograd bookkeeping.
    model.eval()
    training_loss /= train_dataloader_size
    with torch.no_grad():
        valid_accuracy = 0.0
        for inputs, labels in valid_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            valid_accuracy += (torch.argmax(outputs, dim=1) == labels).sum().item()
        valid_accuracy /= valid_dataset_size
    # Keep the best-scoring weights separately from the rolling last checkpoint.
    if valid_accuracy > max_accuracy:
        max_accuracy = valid_accuracy
        torch.save(model.state_dict(), configs['best-path'])
    torch.save(model.state_dict(), configs['last-path'])
    print(f'\tEpoch: {epoch:<6} Loss: {training_loss:<10.5f} Accuracy: {valid_accuracy:.3f}')
print('\n---------- Training Finished ----------\n')
print(f'Max Accuracy: {max_accuracy:.3f}')
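# Usage sketch (assumptions, not confirmed by the repo): arrange images as
# datasets/train/<class>/*.jpg and datasets/valid/<class>/*.jpg, edit
# configs/train.yaml, then run `python train.py`.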