-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathData_Retriever_Inference_Real_Time.py
59 lines (39 loc) · 1.11 KB
/
Data_Retriever_Inference_Real_Time.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import pdb
import os
import cv2
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset
from albumentations import (Normalize, Compose, Resize)
from albumentations.pytorch import ToTensor
import torch.utils.data as data
# In[2]:
class TestDataset(Dataset):
    """Single-image dataset used for real-time inference.

    The dataset always holds exactly one sample: the image stored at
    ``os.path.join(root, fname)``. Each access reads that file with OpenCV
    and passes it through ``Normalize -> Resize -> ToTensor``.

    Args:
        root: Directory containing the image.
        fname: File name of the image, relative to ``root``.
        mean: Per-channel mean forwarded to ``albumentations.Normalize``.
        std: Per-channel std forwarded to ``albumentations.Normalize``.
        size: ``(height, width)`` the normalized image is resized to.
            Defaults to ``(224, 224)``, the previously hard-coded value.
    """

    def __init__(self, root, fname, mean, std, size=(224, 224)):
        self.root = root
        self.fname = fname
        # Exactly one image is served per dataset instance.
        self.num_samples = 1
        height, width = size
        self.transform = Compose(
            [
                Normalize(mean=mean, std=std, p=1),
                Resize(height, width),
                ToTensor(),
            ]
        )

    def __getitem__(self, idx):
        """Return ``(fname, image_tensor)`` for the single image.

        ``idx`` is ignored because the dataset contains only one sample.

        Raises:
            FileNotFoundError: If the image path does not exist.
            ValueError: If OpenCV cannot decode the file as an image.
        """
        path = os.path.join(self.root, self.fname)
        # cv2.imread returns None rather than raising on a missing or
        # unreadable file; fail here with a clear message instead of
        # letting the transform choke on None further down.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Image not found: {path}")
        image = cv2.imread(path)
        if image is None:
            raise ValueError(f"Could not decode image: {path}")
        # NOTE(review): OpenCV loads channels in BGR order; confirm that
        # mean/std (and the downstream model) expect BGR rather than RGB.
        tensor = self.transform(image=image)["image"]
        return self.fname, tensor

    def __len__(self):
        """Return the number of samples (always 1 for this dataset)."""
        return self.num_samples
# In[ ]:
# In[ ]: