Skip to content

Commit

Permalink
adding early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunal22shah committed Mar 29, 2024
1 parent d3862a6 commit 1064d27
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ dataset/Happy/processed
dataset/Neutral/processed
dataset/Suprised/processed
dataset/Focused/processed
__pycache__
__pycache__
venv
Binary file modified emotion_classifier_model_cnn_variant2.pth
Binary file not shown.
Binary file modified emotion_classifier_model_cnn_variant3.pth
Binary file not shown.
37 changes: 33 additions & 4 deletions model-train2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
validation_loader = DataLoader(validation_set, batch_size=32, shuffle=False)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Training on device: {device}")

model = CNNVariant2()
model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
Expand All @@ -44,6 +48,7 @@
running_loss = 0.0

for i, (images, labels) in enumerate(train_loader):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
Expand All @@ -59,8 +64,14 @@
correct_validation = 0
total_validation = 0

early_stopping_patience = 3
epochs_since_improvement = 0
min_loss_decrease = 0.001 # Minimum decrease in loss to qualify as an improvement
best_model_state = None

with torch.no_grad():
for images, labels in validation_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
validation_loss += loss.item()
Expand All @@ -69,19 +80,37 @@

avg_validation_loss = validation_loss / len(validation_loader)
validation_accuracy = 100 * correct_validation / len(validation_set)
print(f'Validation: Epoch {epoch + 1}/{num_epochs}, Loss: {avg_validation_loss:.6f}, Accuracy: {validation_accuracy:.2f}%')
print(
f'Validation: Epoch {epoch + 1}/{num_epochs}, Loss: {avg_validation_loss:.6f}, Accuracy: {validation_accuracy:.2f}%')

if avg_validation_loss < best_validation_loss:
# Check for improvement
if best_validation_loss - avg_validation_loss > min_loss_decrease:
best_validation_loss = avg_validation_loss
torch.save(model.state_dict(), "emotion_classifier_model_cnn_variant2.pth")
epochs_since_improvement = 0
best_model_state = model.state_dict()
else:
epochs_since_improvement += 1

# Early stopping condition check
if epochs_since_improvement >= early_stopping_patience:
print("Early stopping triggered. Stopping training...")
break

# Save the best model outside the training loop
if best_model_state is not None:
torch.save(best_model_state, "emotion_classifier_model_cnn_variant2.pth")
print("Best model saved.")
else:
print("No improvement over initial model. Best model not saved.")

# Test the model
model.eval()
test_correct = 0
test_total = 0

model.eval()
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
test_total += labels.size(0)
Expand Down
36 changes: 23 additions & 13 deletions model-train3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
validation_loader = DataLoader(validation_set, batch_size=32, shuffle=False)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

model = CNNVariant3()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Training on device: {device}")

model = CNNVariant3().to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
Expand All @@ -38,12 +41,15 @@
# Set up the training process
num_epochs = 10
best_validation_loss = float('inf')
early_stopping_patience = 3
early_stopping_counter = 0

for epoch in range(num_epochs):
model.train()
running_loss = 0.0

for i, (images, labels) in enumerate(train_loader):
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
Expand All @@ -52,7 +58,7 @@
running_loss += loss.item()

avg_training_loss = running_loss / len(train_loader)
print(f'Training: Epoch {epoch + 1}/{num_epochs}, Loss: {avg_training_loss:.6f}')
print(f'Training: Epoch {epoch+1}/{num_epochs}, Loss: {avg_training_loss:.6f}')

model.eval()
validation_loss = 0.0
Expand All @@ -61,25 +67,28 @@

with torch.no_grad():
for images, labels in validation_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
validation_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
correct_validation += (predicted == labels).sum().item()
total_validation += labels.size(0) # Correctly update total_validation
total_validation += labels.size(0)

avg_validation_loss = validation_loss / len(validation_loader)
if total_validation > 0: # Ensure we don't divide by zero
validation_accuracy = 100 * correct_validation / total_validation
print(
f'Validation: Epoch {epoch + 1}/{num_epochs}, Loss: {avg_validation_loss:.6f}, Accuracy: {validation_accuracy:.2f}%')
else:
print(
f'Validation: Epoch {epoch + 1}/{num_epochs}, Loss: {avg_validation_loss:.6f}, Accuracy: N/A - No validation data')
validation_accuracy = 100 * correct_validation / total_validation
print(f'Validation: Epoch {epoch+1}/{num_epochs}, Loss: {avg_validation_loss:.6f}, Accuracy: {validation_accuracy:.2f}%')

# Early Stopping
if avg_validation_loss < best_validation_loss:
best_validation_loss = avg_validation_loss
early_stopping_counter = 0
torch.save(model.state_dict(), "emotion_classifier_model_cnn_variant3.pth")
else:
early_stopping_counter += 1
if early_stopping_counter >= early_stopping_patience:
print("Early stopping triggered.")
break

# Test the model
test_correct = 0
Expand All @@ -88,10 +97,11 @@
model.eval()
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
test_total += labels.size(0)
test_correct += (predicted == labels).sum().item()

test_accuracy = 100 * test_correct / test_total
print(f'Test Accuracy: {test_accuracy:.2f}%')
test_accuracy = test_correct / test_total
print(f'Test Accuracy: {test_accuracy:.2f}')

0 comments on commit 1064d27

Please sign in to comment.