-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathutils_dataset.py
35 lines (29 loc) · 1.35 KB
/
utils_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import h5py, torch, sys, os, sqlite3
import torch.utils.data.dataset
class HDF5Dataset(torch.utils.data.dataset.Dataset):
    """Map-style dataset backed by one collection inside an HDF5 file.

    The file is opened read-only when the instance is created and stays open
    until :meth:`close` is called (or the ``with`` block is left), because
    items are read from the file on demand.

    Args:
        filename: Path of the HDF5 file to open.
        collection_name: Key of the collection inside the file to serve
            items from.
    """

    def __init__(self, filename, collection_name='name'):
        self.f = h5py.File(filename, 'r')
        self.dset = self.f[collection_name]
        # NOTE(review): the handle is opened eagerly in the constructing
        # process; if this dataset is handed to multi-process loaders, confirm
        # that sharing an open h5py handle is acceptable for your setup.

    def __getitem__(self, index):
        """Return whatever the underlying collection yields for ``index``."""
        return self.dset[index]

    def __len__(self):
        """Return the number of entries in the underlying collection."""
        return len(self.dset)

    def close(self):
        """Close the underlying file; safe to call more than once.

        After closing, the dataset must not be indexed any more.
        """
        handle = getattr(self, "f", None)
        if handle is not None:
            handle.close()
            # Drop the reference so a second call is a no-op.
            self.f = None

    def __enter__(self):
        """Allow ``with HDF5Dataset(...) as ds:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file when leaving the ``with`` block."""
        self.close()
        # Returning False lets any in-flight exception propagate.
        return False
class SQLDataset(torch.utils.data.dataset.Dataset):
    """Map-style dataset that serves rows of a SQLite table as dictionaries.

    Without a ``cut`` the rows are addressed by their ``id`` column. With a
    ``cut`` only rows whose ``cut`` column equals that value are visible, and
    they are addressed by their ``cut_id`` column.

    Args:
        filename: Path of the SQLite database file.
        table_name: Table to read from. It is concatenated into the SQL text
            (identifiers cannot be bound as parameters), so it must come from
            a trusted source.
        cut: Optional value of the ``cut`` column to restrict the dataset to.
    """

    def __init__(self, filename, table_name='articles', cut=None):
        self.table_name = table_name
        self.conn = sqlite3.connect(filename, detect_types=sqlite3.PARSE_DECLTYPES)
        # sqlite3.Row lets each result row be converted to a dict by column name.
        self.conn.row_factory = sqlite3.Row
        self.cut = cut
        self.curr = self.conn.cursor()

    def __getitem__(self, index):
        """Return the row addressed by ``index`` as a ``dict``.

        Args:
            index: Value matched against ``cut_id`` (when a cut is set) or
                ``id`` (otherwise).

        Returns:
            A dict mapping column names to values for the first matching row.

        Raises:
            IndexError: If no row matches ``index``.
        """
        if self.cut is not None:
            query = "SELECT * FROM " + self.table_name + " WHERE cut_id=? and cut=? LIMIT 1"
            params = (index, self.cut)
        else:
            query = "SELECT * FROM " + self.table_name + " WHERE id= ? LIMIT 1"
            params = (index,)
        # Only the first match is ever returned, so ask SQLite for one row
        # instead of materialising every match.
        row = self.curr.execute(query, params).fetchone()
        if row is None:
            raise IndexError(
                "no row for index %r in table %r (cut=%r)"
                % (index, self.table_name, self.cut)
            )
        return dict(row)

    def __len__(self):
        """Return the number of rows visible to this dataset.

        The count is queried from the database on every call.
        """
        if self.cut is not None:
            result = self.curr.execute(
                "SELECT COUNT(*) as count FROM " + self.table_name + " WHERE cut = ?",
                (self.cut,),
            )
        else:
            result = self.curr.execute(
                "SELECT COUNT(*) as count FROM " + self.table_name
            )
        return result.fetchone()[0]

    def close(self):
        """Close the cursor and the connection; safe to call more than once."""
        conn = getattr(self, "conn", None)
        if conn is None:
            return
        cursor = getattr(self, "curr", None)
        if cursor is not None:
            cursor.close()
            self.curr = None
        conn.close()
        # Drop the reference so a second call is a no-op.
        self.conn = None

    def __enter__(self):
        """Allow ``with SQLDataset(...) as ds:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection when leaving the ``with`` block."""
        self.close()
        # Returning False lets any in-flight exception propagate.
        return False