-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path3_rule_filter.py
117 lines (104 loc) · 3.96 KB
/
3_rule_filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import json

import ftfy
import opencc
import regex
from langdetect import detect
from langdetect import LangDetectException
from tqdm import tqdm
def load_jsonl(path):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        path: Path of a UTF-8 encoded file holding one JSON value per line.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='UTF-8') as f:
        # Blank or whitespace-only lines (e.g. a stray newline at the end of
        # the file) carry no record; skipping them avoids a JSONDecodeError
        # that would otherwise abort the whole load.
        return [json.loads(line) for line in f if line.strip()]
class RuleFilter:
    """Rule-based cleaner and filter for one text field of JSONL records.

    ``handle`` and ``filter`` read their settings (``language``,
    ``text_column``, ``punctuation_ratio_threshold`` and
    ``text_length_threshold``) from the module-level ``args`` namespace that
    the ``__main__`` section of this file creates.
    """

    def __init__(self):
        # OpenCC converter loaded with the 't2s.json' configuration
        # (Traditional Chinese -> Simplified Chinese).
        self.OPENCC_CONVERTER = opencc.OpenCC('t2s.json')

        # Non-ASCII punctuation -> ASCII replacement.  Keys are written as
        # \u escapes on purpose: several of them are fullwidth forms that are
        # visually identical to (and easily replaced by) their ASCII
        # counterparts, which would turn an entry into a no-op or, worse,
        # make it rewrite ordinary ASCII characters.
        self.punctuation_unicode = {
            '\uff0c': ',',    # FULLWIDTH COMMA
            '\u3002': '.',    # IDEOGRAPHIC FULL STOP
            '\u3001': ',',    # IDEOGRAPHIC COMMA
            '\u201e': '"',    # DOUBLE LOW-9 QUOTATION MARK
            '\u201d': '"',    # RIGHT DOUBLE QUOTATION MARK
            '\u201c': '"',    # LEFT DOUBLE QUOTATION MARK
            '\u00ab': '"',    # LEFT-POINTING DOUBLE ANGLE QUOTATION MARK
            '\u00bb': '"',    # RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK
            # NOTE(review): mapping a digit to a quote looks unintended.  It
            # is kept in its fullwidth form so that ASCII '1' is never
            # rewritten -- confirm whether this entry should be dropped.
            '\uff11': '"',    # FULLWIDTH DIGIT ONE
            '\u300d': '"',    # RIGHT CORNER BRACKET
            '\u300c': '"',    # LEFT CORNER BRACKET
            '\u300a': '"',    # LEFT DOUBLE ANGLE BRACKET
            '\u300b': '"',    # RIGHT DOUBLE ANGLE BRACKET
            '\u00b4': "'",    # ACUTE ACCENT
            '\u2236': ':',    # RATIO
            '\uff1a': ':',    # FULLWIDTH COLON
            '\uff1f': '?',    # FULLWIDTH QUESTION MARK
            '\uff01': '!',    # FULLWIDTH EXCLAMATION MARK
            '\uff08': '(',    # FULLWIDTH LEFT PARENTHESIS
            '\uff09': ')',    # FULLWIDTH RIGHT PARENTHESIS
            '\uff1b': ';',    # FULLWIDTH SEMICOLON
            '\u2013': '-',    # EN DASH
            '\u2014': ' - ',  # EM DASH
            '\uff0e': '. ',   # FULLWIDTH FULL STOP
            '\uff5e': '~',    # FULLWIDTH TILDE
            '\u2019': "'",    # RIGHT SINGLE QUOTATION MARK
            '\u2026': '...',  # HORIZONTAL ELLIPSIS
            '\u2501': '-',    # BOX DRAWINGS HEAVY HORIZONTAL
            '\u3008': '<',    # LEFT ANGLE BRACKET
            '\u3009': '>',    # RIGHT ANGLE BRACKET
            '\u3010': '[',    # LEFT BLACK LENTICULAR BRACKET
            '\u3011': ']',    # RIGHT BLACK LENTICULAR BRACKET
            '\uff05': '%',    # FULLWIDTH PERCENT SIGN
            '\u25ba': '-',    # BLACK RIGHT-POINTING POINTER
        }

        # Characters that are replaced by a plain ASCII space.
        # NOTE(review): these 23 entries are invisible in an editor, so they
        # are spelled as escapes.  The list is reconstructed as the Unicode
        # space separators plus zero-width / format characters -- confirm it
        # matches the intended set.
        self.various_whitespaces = {
            ' ', '\u00a0', '\u1680',
            '\u2000', '\u2001', '\u2002', '\u2003', '\u2004', '\u2005',
            '\u2006', '\u2007', '\u2008', '\u2009', '\u200a',
            '\u202f', '\u205f', '\u3000',
            '\u200b', '\u200c', '\u200d', '\u2060', '\ufffc', '\u0084',
        }

        # The two key sets are disjoint and every punctuation replacement is
        # plain ASCII, so both normalisations can run as one translate() pass.
        self._char_table = str.maketrans({
            **dict.fromkeys(self.various_whitespaces, ' '),
            **self.punctuation_unicode,
        })

        # Compiled once here instead of on every handle() call.
        self._repeated_punct = regex.compile(r'(\p{P})\1+')
        self._any_punct = regex.compile(r'\p{P}')

    def handle(self, text):
        """Clean one text and decide whether it is kept.

        Steps: fix the text with ftfy (NFC normalisation), reject it when the
        detected language differs from ``args.language``, normalise
        punctuation and whitespace, collapse runs of one repeated punctuation
        mark, reject texts that are too short or too punctuation-heavy, and
        finally convert the text with the OpenCC converter.

        Args:
            text: The raw text of one record.

        Returns:
            The cleaned text, or ``None`` when the text is rejected.  Values
            that are not strings, blank texts, and texts for which language
            detection fails are rejected rather than raising.
        """
        # A missing or null column must not crash the run.
        if not isinstance(text, str) or not text.strip():
            return None

        # unicode
        text = ftfy.fix_text(text, normalization="NFC")

        # language filter; detect() raises when the text has no usable
        # features (for example digits or punctuation only).
        try:
            detected_language = detect(text)
        except LangDetectException:
            return None
        if detected_language != args.language:
            return None

        # Standardization of punctuation and whitespace
        text = text.translate(self._char_table)

        # Replace all matched consecutive punctuation with a single punctuation
        text = self._repeated_punct.sub(r'\1', text).strip()

        # Filter out texts that are too short.  The emptiness check also
        # protects the ratio below from dividing by zero.
        if not text or len(text) < args.text_length_threshold:
            return None

        # Filter out texts with too high a punctuation ratio
        punctuation_ratio = len(self._any_punct.findall(text)) / len(text)
        if punctuation_ratio > args.punctuation_ratio_threshold:
            return None

        # Convert Traditional Chinese Characters to Simplified Chinese
        return self.OPENCC_CONVERTER.convert(text)

    def filter(self, input_file_path, output_file_path):
        """Filter a JSONL file record by record.

        Each input line is parsed as JSON, the ``args.text_column`` field is
        passed through ``handle``, and records whose text is kept are written
        to the output file with that field replaced by the cleaned text.

        Args:
            input_file_path: Path of the UTF-8 JSONL file to read.
            output_file_path: Path of the UTF-8 JSONL file to (over)write.
        """
        text_column = args.text_column
        with open(input_file_path, 'r', encoding='utf-8') as input_file, \
                open(output_file_path, 'w', encoding='utf-8') as output_file:
            for line in input_file:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # Ignore lines with parsing errors
                if not isinstance(data, dict):
                    continue  # Valid JSON, but not a record with columns
                result = self.handle(data.get(text_column, ''))
                if result:
                    data[text_column] = result
                    output_file.write(
                        json.dumps(data, ensure_ascii=False) + '\n')
if __name__ == '__main__':
    # Command-line entry point: filter one JSONL file into another.
    parser = argparse.ArgumentParser(
        description='Rule-based cleaning and filtering of a JSONL text corpus.')
    # The default input and output are jsonl files
    parser.add_argument('--input_path', type=str, required=True,
                        help='JSONL file to read')
    parser.add_argument('--output_path', type=str, required=True,
                        help='JSONL file to write (overwritten)')
    parser.add_argument('--text_column', type=str, required=True,
                        help='name of the record field that holds the text')
    parser.add_argument('--language', type=str, required=True,
                        help='language code a text must be detected as to be kept')
    parser.add_argument('--punctuation_ratio_threshold', type=float, default=0.5,
                        help='drop texts whose punctuation ratio exceeds this value')
    parser.add_argument('--text_length_threshold', type=int, default=128,
                        help='drop texts shorter than this many characters')
    # `args` has to stay a module-level name: RuleFilter reads it directly.
    args = parser.parse_args()

    # Named `rule_filter` so the builtin `filter` is not shadowed.
    rule_filter = RuleFilter()
    rule_filter.filter(args.input_path, args.output_path)