-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add plotting script to analyze training logs
- Loading branch information
Akram
authored and
Akram
committed
Aug 7, 2024
1 parent
e66ae2f
commit 93d2b39
Showing
2 changed files
with
100 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" Plotting script with Matplotlib """ | ||
|
||
# Imports | ||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
|
||
# Define the path to the CSV file | ||
LOG_FILE = "logs/training_logs.csv" | ||
|
||
# Read the CSV file | ||
data = pd.read_csv(LOG_FILE) | ||
|
||
# Define length | ||
data["Mini Batch"] = range(1, len(data) + 1) | ||
|
||
# Plotting Actor Loss | ||
plt.figure(figsize=(10, 5)) | ||
plt.plot(data["Mini Batch"], data["Actor Loss"], label="Actor Loss") | ||
plt.xlabel("Mini Batch") | ||
plt.ylabel("Loss") | ||
plt.title("Actor Loss Over Mini Batches") | ||
plt.legend() | ||
plt.grid(True) | ||
plt.savefig("logs/actor_loss_plot.png") | ||
plt.show() | ||
|
||
# Plotting Critic Loss | ||
plt.figure(figsize=(10, 5)) | ||
plt.plot(data["Mini Batch"], data["Critic Loss"], label="Critic Loss") | ||
plt.xlabel("Mini Batch") | ||
plt.ylabel("Loss") | ||
plt.title("Critic Loss Over Mini Batches") | ||
plt.legend() | ||
plt.grid(True) | ||
plt.savefig("logs/critic_loss_plot.png") | ||
plt.show() | ||
|
||
# Plotting Total Rewards | ||
plt.figure(figsize=(10, 5)) | ||
plt.plot(data["Mini Batch"], data["Reward"], label="Reward") | ||
plt.xlabel("Mini Batch") | ||
plt.ylabel("Total Reward") | ||
plt.title("Total Rewards Over Mini Batches") | ||
plt.legend() | ||
plt.grid(True) | ||
plt.savefig("logs/total_rewards_plot.png") | ||
plt.show() |