# infer.py
# NOTE: web-page residue from the code-hosting UI (navigation text and the
# 1-68 line-number gutter) was removed here; the actual source starts below.
import json
import time
import random
import torch
from torchvision import datasets, transforms
from pathlib import Path
from PIL import Image
from model import Net
def _to_grayscale_image(tensor):
    """Convert a single-image tensor into an 8-bit grayscale PIL image.

    The tensor is not assumed to lie in [0, 1]: datasets are commonly
    normalized before being fed to the model, which pushes values outside
    that range. Multiplying such values by 255 and converting to mode "L"
    would clip them, so the pixel values are min-max rescaled to 0..255
    instead (i.e. contrast is stretched to the full range).
    """
    pixels = tensor.detach().cpu().squeeze().to(torch.float32)
    low = pixels.min()
    span = pixels.max() - low
    if span > 0:
        pixels = (pixels - low) / span
    else:
        # Constant image: nothing to stretch, and dividing by zero must be avoided.
        pixels = torch.zeros_like(pixels)
    as_bytes = (pixels * 255).round().to(torch.uint8).numpy()
    return Image.fromarray(as_bytes).convert("L")


def infer(model, dataset, save_dir, num_samples: int = 5) -> None:
    """Run the model on randomly chosen dataset items and save them as PNGs.

    Each selected image is written to ``<save_dir>/results`` under the name
    ``<predicted_class>_<position>.png``, where ``position`` is the item's
    position within this run's random selection (which keeps names unique
    even when several images receive the same prediction).

    Args:
        model: Callable model exposing ``eval()``; it is called with a batch
            of one image and must return per-class scores along dim 1.
        dataset: Sized, indexable collection whose items are
            ``(image_tensor, label)`` pairs; the label is ignored.
        save_dir: Directory under which the ``results`` folder is created.
        num_samples: How many images to sample. Values larger than the
            dataset are clamped to the dataset size instead of raising.
    """
    model.eval()

    results_dir = Path(save_dir) / "results"
    results_dir.mkdir(parents=True, exist_ok=True)

    # random.sample raises ValueError when asked for more items than exist
    # (or for a negative count), so clamp the request to what is available.
    count = max(0, min(num_samples, len(dataset)))
    chosen = random.sample(range(len(dataset)), count)

    with torch.no_grad():
        for position, dataset_index in enumerate(chosen):
            image, _ = dataset[dataset_index]
            # The model expects a batch dimension; add one for the single image.
            scores = model(image.unsqueeze(0))
            pred = scores.argmax(dim=1, keepdim=True).item()

            picture = _to_grayscale_image(image)
            picture.save(results_dir / f"{pred}_{position}.png")

    print(f"Saved {count} inference result images.")
def main():
    """Load the trained MNIST model and save a few sample predictions.

    Builds ``Net`` (imported from ``model``), restores its weights from the
    checkpoint under ``/opt/mount/model``, loads the MNIST test split
    (downloading it into ``/opt/mount`` if needed) and hands both to
    ``infer``, which writes the result images below ``/opt/mount/results``.
    """
    save_dir = "/opt/mount"
    checkpoint_path = "/opt/mount/model/mnist_cnn.pt"  # Saved model checkpoint

    model = Net()
    # Net() is constructed on the CPU and never moved, so map the checkpoint
    # tensors to the CPU as well; without this, a checkpoint saved from a GPU
    # fails to load on a machine that has no CUDA device.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Same preprocessing as used for MNIST: tensor conversion followed by
    # normalization with the dataset's mean and standard deviation.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # MNIST test split, stored alongside the other mounted artifacts.
    dataset = datasets.MNIST(
        save_dir, train=False, download=True, transform=transform
    )

    infer(model, dataset, save_dir)
    print("Inference completed. Results saved in the 'results' folder.")
# Script entry point: run inference only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()