Skip to content

Commit

Permalink
Update to VLSP_2023, including making TP and WHADV legal
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Nov 13, 2023
1 parent 2f20353 commit 1a44c92
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
6 changes: 4 additions & 2 deletions stanza/utils/datasets/constituency/prepare_con_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,11 @@ def process_vlsp22(paths, dataset_name, *args):
if dataset_name == 'vi_vlsp22':
default_subdir = 'VLSP_2022'
default_make_test_split = False
updated_tagset = False
elif dataset_name == 'vi_vlsp23':
default_subdir = os.path.join('VLSP_2023', 'TrainingDataset')
default_subdir = os.path.join('VLSP_2023', 'Trainingdataset')
default_make_test_split = True
updated_tagset = True

parser = argparse.ArgumentParser()
parser.add_argument('--subdir', default=default_subdir, type=str, help='Where to find the data - allows for using previous versions, if needed')
Expand All @@ -276,7 +278,7 @@ def process_vlsp22(paths, dataset_name, *args):
print("Loading training files from {}".format(vlsp_dir))
print("Procesing training files:\n {}".format("\n ".join(vlsp_train_files)))
with tempfile.TemporaryDirectory() as train_output_path:
vtb_convert.convert_files(vlsp_train_files, train_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets)
vtb_convert.convert_files(vlsp_train_files, train_output_path, verbose=True, fix_errors=True, convert_brackets=args.convert_brackets, updated_tagset=updated_tagset)
# This produces a 0 length test set, just as a placeholder until the actual test set is released
if args.n_splits:
test_size = 0.1 if args.test_split else 0.0
Expand Down
16 changes: 10 additions & 6 deletions stanza/utils/datasets/constituency/vtb_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,19 @@ def is_valid_line(line):

# not clear if TP is supposed to be NP or PP - needs a native speaker to decode
WEIRD_LABELS = sorted(set(["WP", "YP", "SNP", "STC", "UPC", "(TP", "Xp", "XP", "WHVP", "WHPR", "NO", "WHADV", "(SC (", "(VOC (", "(Adv (", "(SP (", "ADV-MDP", "(SPL", "(ADV (", "(V-MWE ("] + list(REMAPPING.keys())))
# the 2023 dataset has TP and WHADV as actual labels
WEIRD_LABELS_2023 = sorted(set(["WP", "YP", "SNP", "STC", "UPC", "Xp", "XP", "WHVP", "WHPR", "NO", "(SC (", "(VOC (", "(Adv (", "(SP (", "ADV-MDP", "(SPL", "(ADV (", "(V-MWE ("] + list(REMAPPING.keys())))


def convert_file(orig_file, new_file, fix_errors=True, convert_brackets=False):
def convert_file(orig_file, new_file, fix_errors=True, convert_brackets=False, updated_tagset=False):
"""
:param orig_file: original directory storing original trees
:param new_file: new directory storing formatted constituency trees
This function writes new trees to the corresponding files in new_file
"""
if updated_tagset:
weird_labels = WEIRD_LABELS_2023
else:
weird_labels = WEIRD_LABELS
errors = defaultdict(list)
with open(orig_file, 'r', encoding='utf-8') as reader, open(new_file, 'w', encoding='utf-8') as writer:
content = reader.readlines()
Expand Down Expand Up @@ -166,7 +171,7 @@ def convert_file(orig_file, new_file, fix_errors=True, convert_brackets=False):
# TODO: this block eliminates 3 trees from VLSP-22
# maybe those trees can be salvaged?
bad_label = False
for weird_label in WEIRD_LABELS:
for weird_label in weird_labels:
if tree.find(weird_label) >= 0:
bad_label = True
errors[weird_label].append("Weird label {} from {} line {}: {}".format(weird_label, orig_file, line_idx, tree))
Expand All @@ -190,14 +195,14 @@ def convert_file(orig_file, new_file, fix_errors=True, convert_brackets=False):

return errors

def convert_files(file_list, new_dir, verbose=False, fix_errors=True, convert_brackets=False):
def convert_files(file_list, new_dir, verbose=False, fix_errors=True, convert_brackets=False, updated_tagset=False):
errors = defaultdict(list)
for filename in file_list:
base_name, _ = os.path.splitext(os.path.split(filename)[-1])
new_path = os.path.join(new_dir, base_name)
new_file_path = f'{new_path}.mrg'
# Convert the tree and write to new_file_path
new_errors = convert_file(filename, new_file_path, fix_errors, convert_brackets)
new_errors = convert_file(filename, new_file_path, fix_errors, convert_brackets, updated_tagset)
for e in new_errors:
errors[e].extend(new_errors[e])

Expand Down Expand Up @@ -239,7 +244,6 @@ def main():
'new_dir',
help='The location of new directory storing the new formatted trees'
)

args = parser.parse_args()

org_dir = args.org_dir
Expand Down

0 comments on commit 1a44c92

Please sign in to comment.