Skip to content

Commit

Permalink
Merge pull request #153 from thearyadev/150-train-script-interactive
Browse files Browse the repository at this point in the history
150 train script interactive
  • Loading branch information
thearyadev authored Feb 15, 2024
2 parents 0149068 + e3d46ea commit 5c25406
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
45 changes: 31 additions & 14 deletions classifier/train.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
from shutil import copy
import sys


sys.path.append(".")
sys.path.append(".") # need access to heroes module


from model import NNModel, transformer
import torch.nn as nn
import torch
import torchvision # type: ignore
from pathlib import Path
from typing import Final
from heroes import Heroes
import uuid
from torch.utils.data import DataLoader
import shutil
CLASS_SIZE: Final[int] = 250
NUM_EPOCH: Final[int] = 5
from rich.prompt import Prompt
import datetime


def get_key_by_value(inputDict: dict[int, str], value: str) -> int:
    """Return the first key in *inputDict* whose value equals *value*.

    Dict insertion order determines which key wins if several map to the
    same value.

    Raises:
        KeyError: if no entry maps to *value*.  Failing here is deliberate:
            the previous implementation returned ``None`` (contradicting the
            ``-> int`` annotation), and a caller doing ``str(result)`` would
            silently produce the string ``"None"`` instead of surfacing the
            missing label.
    """
    for key, val in inputDict.items():
        if val == value:
            return key
    raise KeyError(f"no key maps to value {value!r}")


def make_fs_dataset(assets_directory: Path) -> Path:
def make_fs_dataset(assets_directory: Path, duplicates: int) -> Path:
hero_dir = assets_directory / "heroes"
dataset_dir = assets_directory / "dataset"

Expand All @@ -41,14 +40,18 @@ def make_fs_dataset(assets_directory: Path) -> Path:
for hero_image in hero_dir.iterdir():
target_dir = dataset_dir / str(get_key_by_value(labels, hero_image.name.replace(".png", "")))
target_dir.mkdir()
for _ in range(CLASS_SIZE):
for _ in range(duplicates):
shutil.copy(str(hero_image), str(target_dir /(uuid.uuid4().hex + ".png")))

return dataset_dir


def main():
dataset_path = make_fs_dataset(Path("./assets"))

def main() -> int:
model_name = f"{Prompt.ask('Model Name')}-{datetime.datetime.now().strftime('%d-%m-%Y')}"
epochs = Prompt.ask("Epochs", default="10")
class_size = Prompt.ask("Class Size", default="250")
dataset_path = make_fs_dataset(Path("./assets"), duplicates=int(class_size))

dataset = torchvision.datasets.ImageFolder(root=str(dataset_path), transform=transformer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
Expand All @@ -59,8 +62,10 @@ def main():
model = NNModel(num_classes=len(dataset.classes)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(NUM_EPOCH):
info_file: list[str] = list()
info_file.append(f"Model Name: {model_name}")
info_file.append(f"Class Size: {class_size}")
for epoch in range(int(epochs)):
model.train()

for images, labels in dataloader:
Expand All @@ -82,17 +87,29 @@ def main():
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels.to(device)).sum().item()
val_accuracy = correct / len(test_dataset)

info_file.append(
f"Epoch {epoch+1}: Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}"
)

print(f"Epoch {epoch+1}: Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}")

model_dir = Path("./models/tm")
model_dir = Path(f"./models/{model_name}")
model_dir.mkdir()

torch.save(model.state_dict(), str(model_dir / "model.pth"))
with open(model_dir / "classes", "w+") as f:
f.write("\n".join(dataset.classes))

shutil.copy(Path("./classifier/model.py"), model_dir / "model.py")
shutil.copy(Path("./classifier/model.py"), model_dir / "frozen_model.py")

with open(model_dir / "detail", "w+") as file:
file.writelines(info_file)

with open(model_dir / "__init__.py", "w+") as file:
file.write("from .frozen_model import transformer, NNModel as FrozenNeuralNetworkModel")

return 0

if __name__ == "__main__":
main()
raise SystemExit(main())
4 changes: 2 additions & 2 deletions heroes/hero_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def __init__(self):

def predict_hero_name(self, image: Image, model_directory: Path) -> str:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NNModel = importlib.import_module(f"models.{model_directory.name}.model").NNModel
transformer = importlib.import_module(f"models.{model_directory.name}.model").transformer
NNModel = importlib.import_module(f"models.{model_directory.name}").FrozenNeuralNetworkModel
transformer = importlib.import_module(f"models.{model_directory.name}").transformer

st_dict = torch.load(model_directory / "model.pth")
model = NNModel(num_classes=40)
Expand Down
4 changes: 2 additions & 2 deletions leaderboards/leaderboard_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def crop_split_column(pil_image: ImageType) -> list[ImageType]:
break
return list(reversed(result)) # Hero placement has been reversed in Season 9 (blizzard moment.)

def parse_leaderboard_to_leaderboard_entries(leaderboard_image: ImageType, region: Region, role: Role) -> list[LeaderboardEntry]:
def parse_leaderboard_to_leaderboard_entries(leaderboard_image: ImageType, region: Region, role: Role, model_name) -> list[LeaderboardEntry]:
hero_comp = Heroes()
hero_section = crop_to_hero_section(leaderboard_image)
row_entries = crop_split_row(hero_section)
Expand All @@ -109,7 +109,7 @@ def parse_leaderboard_to_leaderboard_entries(leaderboard_image: ImageType, regio
for row in split_column_entries: # each record (10)

results.append(LeaderboardEntry(
heroes=[hero_comp.predict_hero_name(hero_image, Path("./models/tm")) for hero_image in row],
heroes=[hero_comp.predict_hero_name(hero_image, Path(f"./models/{model_name}")) for hero_image in row],
role=role,
region=region

Expand Down

0 comments on commit 5c25406

Please sign in to comment.