From 9d4d1b8c753a49416865846fe721beea56ac1aaf Mon Sep 17 00:00:00 2001 From: Alex Shan Date: Thu, 26 Oct 2023 21:56:29 -0700 Subject: [PATCH] Update scripts to do model eval and baseline model --- .../models/lemma_classifier/baseline_model.py | 11 ++++ .../lemma_classifier/evaluate_models.py | 57 ++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/stanza/models/lemma_classifier/baseline_model.py b/stanza/models/lemma_classifier/baseline_model.py index 75c5f12539..133b4ee0cc 100644 --- a/stanza/models/lemma_classifier/baseline_model.py +++ b/stanza/models/lemma_classifier/baseline_model.py @@ -2,7 +2,18 @@ Baseline model for the existing lemmatizer which always predicts "be" and never "have" on the "'s" token. """ +import stanza +class BaselineModel: + def __init__(self, token_to_lemmatize, prediction_lemma): + self.token_to_lemmatize = token_to_lemmatize + self.prediction_lemma = prediction_lemma + + + def predict(self, token): + if token == self.token_to_lemmatize: + return self.prediction_lemma + diff --git a/stanza/models/lemma_classifier/evaluate_models.py b/stanza/models/lemma_classifier/evaluate_models.py index 04f9b87f79..5fc4bd8dcb 100644 --- a/stanza/models/lemma_classifier/evaluate_models.py +++ b/stanza/models/lemma_classifier/evaluate_models.py @@ -1,8 +1,61 @@ # TODO: Figure out how to load in the UD files into Stanza objects to get the features from them. +import os +import sys + +parentdir = os.path.dirname(__file__) +parentdir = os.path.dirname(parentdir) +parentdir = os.path.dirname(parentdir) +sys.path.append(parentdir) import stanza +from typing import Any, List, Tuple +from models.lemma_classifier.baseline_model import BaselineModel + + +def load_doc_from_conll_file(path: str): + return stanza.utils.conll.CoNLL.conll2doc(path) + + +def evaluate_models(eval_path: str, binary_classifier: Any, baseline_classifier: BaselineModel): + """ + Evaluates both the binary classifier and baseline classifier on a test file, + checking the predicted lemmas for each "'s" token against the gold lemma. + """ + + gold_doc = load_doc_from_conll_file(eval_path) + for sentence in doc.sentences: + for word in sentence.words: + if word.text == "'s": + gold_tag = word.lemma + # predict binary classifier + bin_predict = None # TODO + # predict baseline classifier + baseline_predict = baseline_classifier.predict(word.text) # TODO + # score + if gold_tag == bin_predict: + pass + if gold_tag == baseline_predict: + pass + + return # TODO write some kind of evaluation + + +def main(): + """ + Runs a test on the EN_GUM test set + """ + coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-test.conllu") + doc = load_doc_from_conll_file(coNLL_path) + count = 0 + for sentence in doc.sentences: + for word in sentence.words: + if word.text == "'s": + print("Found") + print(word) + count += 1 -doc = stanza.utils.CoNLL.conll2doc("/u/scr/corpora/Universal_Dependencies/Universal_Dependencies_2.12/ud-treebanks-v2.12/UD_English-GUM/en_gum-ud-test.conllu") + print(f"Count was {count}.") -print(doc.sentences) +if __name__ == "__main__": + main()