-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
30 lines (26 loc) · 820 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
@author : Hyunwoong
@when : 2019-10-29
@homepage : https://github.com/gusdnd852
"""
from conf import *
from util.data_loader import DataLoader
from util.tokenizer import Tokenizer
# ---------------------------------------------------------------------------
# Module-level data pipeline: everything below runs once at import time and
# leaves the resulting objects as module globals for other modules to use.
# ---------------------------------------------------------------------------

# Tokenizer whose bound methods are handed to the loader as the English and
# German tokenization callables.
tokenizer = Tokenizer()

# Loader configured for the (".en", ".de") extension pair, with "<sos>" and
# "<eos>" passed as the init / end-of-sequence tokens.
loader = DataLoader(
    ext=(".en", ".de"),
    tokenize_en=tokenizer.tokenize_en,
    tokenize_de=tokenizer.tokenize_de,
    init_token="<sos>",
    eos_token="<eos>",
)

# Three dataset splits; the vocabulary is built from the training split only.
# NOTE(review): min_freq=2 presumably drops tokens seen fewer than twice --
# confirm against DataLoader.build_vocab.
train_, valid, test = loader.make_dataset()
loader.build_vocab(train_data=train_, min_freq=2)

# One iterator per split. `batch_size` and `device` are not defined in this
# file; presumably they come from the `conf` star import -- verify.
train_iter, valid_iter, test_iter = loader.make_iter(
    train_,
    valid,
    test,
    batch_size=batch_size,
    device=device,
)

# Bind each vocabulary once; the underscore prefix keeps these helper names
# out of `from data import *`.
_source_vocab = loader.source.vocab
_target_vocab = loader.target.vocab

# Indices of the special tokens, looked up through each vocabulary's `stoi`
# mapping.
src_pad_idx = _source_vocab.stoi["<pad>"]
trg_pad_idx = _target_vocab.stoi["<pad>"]
trg_sos_idx = _target_vocab.stoi["<sos>"]

# Vocabulary sizes for the source and target sides (named enc_/dec_,
# presumably for the encoder and decoder of the model -- confirm with users
# of these names).
enc_voc_size = len(_source_vocab)
dec_voc_size = len(_target_vocab)