-
Notifications
You must be signed in to change notification settings - Fork 64
/
utils.py
68 lines (44 loc) · 1.41 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#! /usr/bin/env python
# -*- coding:utf-8 -*-
import os
# Vocabulary size and two fixed token ids.  The names suggest a separator
# token and a padding token -- confirm against the callers.
VOCAB_SIZE = 6000
SEP_TOKEN = 0
# Numerically equal to VOCAB_SIZE - 1.
PAD_TOKEN = 5999

# Directory layout; all paths are relative to the current working directory.
DATA_RAW_DIR = 'data/raw'
DATA_PROCESSED_DIR = 'data/processed'
DATA_SAMPLES_DIR = 'data/samples'
MODEL_DIR = 'model'
LOG_DIR = 'log'

# Import-time side effect: make sure the processed-data and model directories
# exist (the other directories are not created here).  os.makedirs is used
# rather than os.mkdir because DATA_PROCESSED_DIR is nested: os.mkdir raises
# OSError when the parent 'data' directory does not exist yet.
if not os.path.exists(DATA_PROCESSED_DIR):
    os.makedirs(DATA_PROCESSED_DIR)
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)
def embed_w2v(embedding, data_set):
    """Replace every item of every sample with its entry in *embedding*.

    Args:
        embedding: Any object supporting ``embedding[x]`` lookups for the
            items found in the samples (for example a dict or an array).
        data_set: Iterable of samples, each an iterable of lookup keys.

    Returns:
        A list with one inner list per sample, holding ``embedding[x]`` for
        each item ``x`` of that sample, in order.

    Raises:
        Whatever ``embedding[x]`` raises for a missing key or index
        (``KeyError`` for a dict, ``IndexError`` for a sequence).
    """
    # A comprehension instead of map(): on Python 3 map() returns a lazy,
    # one-shot iterator, so the result could be neither indexed nor iterated
    # twice.  Real lists behave the same on Python 2 and 3.
    return [[embedding[x] for x in sample] for sample in data_set]
def apply_one_hot(data_set, num_classes=None):
    """Replace every index of every sample with a one-hot vector.

    The previous implementation called ``to_categorical``, a name this module
    never imports or defines, so any call raised ``NameError``.  The vectors
    are now built directly with NumPy.

    Args:
        data_set: Iterable of samples, each an iterable of integer indices.
        num_classes: Length of each one-hot vector.  Defaults to the
            module-level ``VOCAB_SIZE`` when omitted.

    Returns:
        A list with one inner list per sample; each element is a 1-D NumPy
        float array of length ``num_classes`` that is 1.0 at the given index
        and 0.0 elsewhere.

    Raises:
        IndexError: If an index is outside ``[-num_classes, num_classes)``.
    """
    # Local import so the module can still be imported without NumPy as long
    # as this function is not used.
    import numpy as np

    if num_classes is None:
        num_classes = VOCAB_SIZE

    def one_hot(index):
        vector = np.zeros(num_classes)
        # int() accepts any integer-like value, e.g. NumPy integer scalars.
        vector[int(index)] = 1.0
        return vector

    return [[one_hot(x) for x in sample] for sample in data_set]
def apply_sparse(data_set):
    """Wrap every item of every sample in its own single-element list.

    Args:
        data_set: Iterable of samples, each an iterable of items.

    Returns:
        A list with one inner list per sample, in which each original item
        ``x`` has become ``[x]``.
    """
    # A comprehension instead of map(): on Python 3 map() returns a lazy,
    # one-shot iterator rather than a list.  Real lists behave the same on
    # Python 2 and 3.
    return [[[x] for x in sample] for sample in data_set]
def pad_to(lst, length, value):
    """Pad *lst* in place with *value* until it holds *length* items.

    A list that is already *length* items long or longer is left untouched;
    it is never truncated.

    Args:
        lst: The list to extend; it is mutated.
        length: Target minimum length.
        value: Item appended for every missing position.

    Returns:
        The same list object that was passed in.
    """
    # The range is empty when the list is already long enough.
    lst.extend(value for _ in range(len(lst), length))
    return lst
def uprint(x):
    """Write a readable ``repr`` of *x* to stdout, followed by one space.

    Non-ASCII characters are shown as themselves rather than as escape
    sequences.  No newline is written, so successive calls stay on one line.

    The previous body used the Python 2 print statement and ``str.decode``,
    which is a syntax error on Python 3; this version runs on both.  The
    print statement's trailing comma (softspace) is replaced by an explicit
    trailing space.

    Args:
        x: Any object.
    """
    import sys

    text = repr(x)
    # Python 2: repr() yields a byte string with non-ASCII characters
    # escaped, so decode the escapes back into characters.  Python 3: repr()
    # already returns text and needs no decoding.
    if isinstance(text, bytes):
        text = text.decode('unicode-escape')
    sys.stdout.write(text + u' ')
def uprintln(x):
    """Print a readable ``repr`` of *x* to stdout, followed by a newline.

    Non-ASCII characters are shown as themselves rather than as escape
    sequences.

    The previous body used the Python 2 print statement and ``str.decode``,
    which is a syntax error on Python 3; this version runs on both.

    Args:
        x: Any object.
    """
    text = repr(x)
    # Python 2: repr() yields a byte string with non-ASCII characters
    # escaped, so decode the escapes back into characters.  Python 3: repr()
    # already returns text and needs no decoding.
    if isinstance(text, bytes):
        text = text.decode('unicode-escape')
    # With a single parenthesised argument this is a plain print statement
    # on Python 2 and a print() call on Python 3; the output is the same.
    print(text)
def is_CN_char(ch):
    """Tell whether *ch* falls in the code-point range U+4E00..U+9FA5.

    Both ends of the range are inclusive.

    Args:
        ch: A single-character unicode string.

    Returns:
        True if ``ch`` compares within the range, False otherwise.
    """
    return u'\u4e00' <= ch <= u'\u9fa5'
def split_sentences(line):
    """Split *line* at punctuation and keep only the Chinese characters.

    The line is cut at every delimiter character.  Each non-empty piece is
    then reduced to the characters accepted by ``is_CN_char``.

    Delimiters are the ASCII comma, exclamation mark and question mark, their
    full-width forms (U+FF0C, U+FF01, U+FF1F), the ideographic full stop and
    the ideographic comma.  The full-width forms were added because the
    previous delimiter set mixed ASCII and CJK punctuation and so never split
    text punctuated with full-width marks.

    Args:
        line: A unicode string.

    Returns:
        A list of unicode strings, one per non-empty piece, in order.  Pieces
        that are empty before filtering (adjacent delimiters, or a delimiter
        at either end of the line) are skipped.  A piece that has characters
        but none accepted by ``is_CN_char`` still contributes an empty
        string.
    """
    delimiters = (
        u',', u'。', u'!', u'?', u'、',
        u'\uff0c',  # full-width comma
        u'\uff01',  # full-width exclamation mark
        u'\uff1f',  # full-width question mark
    )

    # Cut the line into raw pieces; every delimiter closes the current piece.
    pieces = []
    current = []
    for ch in line:
        if ch in delimiters:
            pieces.append(current)
            current = []
        else:
            current.append(ch)
    # The end of the line closes the last piece.
    pieces.append(current)

    return [
        u''.join(ch for ch in piece if is_CN_char(ch))
        for piece in pieces
        if piece
    ]