Skip to content

Commit

Permalink
Update scripts to do model eval and baseline model
Browse files Browse the repository at this point in the history
  • Loading branch information
SecroLoL committed Oct 27, 2023
1 parent 5bd3492 commit 9d4d1b8
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
11 changes: 11 additions & 0 deletions stanza/models/lemma_classifier/baseline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@
Baseline model for the existing lemmatizer, which always predicts "be" and never "have" on the "'s" token.
"""

import stanza

class BaselineModel:
    """
    Trivial lemma predictor: answers with one fixed lemma for one fixed token.

    The token to react to and the lemma to answer with are both supplied
    at construction time.
    """

    def __init__(self, token_to_lemmatize, prediction_lemma):
        """
        Store the token this model responds to and the lemma it answers with.
        """
        self.prediction_lemma = prediction_lemma
        self.token_to_lemmatize = token_to_lemmatize

    def predict(self, token):
        """
        Return the stored lemma when `token` equals the configured token.

        Any other token yields None.
        """
        is_target = token == self.token_to_lemmatize
        return self.prediction_lemma if is_target else None



57 changes: 55 additions & 2 deletions stanza/models/lemma_classifier/evaluate_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,61 @@
# TODO: Figure out how to load in the UD files into Stanza objects to get the features from them.
import os
import sys

# Climb from this file's directory up two more levels and put that directory
# on sys.path, so the `models.lemma_classifier...` import below can resolve
# when this file is run directly as a script.
parentdir = os.path.dirname(__file__)
parentdir = os.path.dirname(parentdir)
parentdir = os.path.dirname(parentdir)
sys.path.append(parentdir)

import stanza
from typing import Any, List, Tuple
from models.lemma_classifier.baseline_model import BaselineModel


def load_doc_from_conll_file(path: str):
    """
    Read the CoNLL-U file at `path` and return the document produced by
    `CoNLL.conll2doc`.
    """
    # Import the converter explicitly: a bare `import stanza` does not
    # guarantee that the `stanza.utils.conll` submodule is loaded, so reaching
    # it by attribute access only works if something else imported it first.
    from stanza.utils.conll import CoNLL

    return CoNLL.conll2doc(path)


def evaluate_models(eval_path: str, binary_classifier: Any, baseline_classifier: BaselineModel):
    """
    Evaluates both the binary classifier and baseline classifier on a test file,
    checking the predicted lemmas for each "'s" token against the gold lemma.

    Args:
        eval_path: path to the CoNLL-U file holding the gold annotations.
        binary_classifier: classifier under evaluation; it is not queried yet
            (see the TODO below).
        baseline_classifier: baseline whose `predict` is called with the token text.

    Returns:
        None for now; the actual scoring is still a TODO.
    """
    gold_doc = load_doc_from_conll_file(eval_path)
    # Walk the document that was just loaded. The loop previously referenced
    # an undefined name `doc`, which raised NameError on every call.
    for sentence in gold_doc.sentences:
        for word in sentence.words:
            if word.text != "'s":
                continue

            gold_tag = word.lemma
            # predict binary classifier
            bin_predict = None  # TODO
            # predict baseline classifier
            baseline_predict = baseline_classifier.predict(word.text)  # TODO
            # score
            if gold_tag == bin_predict:
                pass
            if gold_tag == baseline_predict:
                pass

    return  # TODO write some kind of evaluation


def main():
    """
    Runs a test on the EN_GUM test set

    Counts and prints every "'s" token in the copy of the test file that sits
    next to this script, then loads the treebank copy at the absolute corpus
    path and prints its sentences.
    """
    coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-test.conllu")
    doc = load_doc_from_conll_file(coNLL_path)
    count = 0
    for sentence in doc.sentences:
        for word in sentence.words:
            if word.text == "'s":
                print("Found")
                print(word)
                count += 1

    # Load through the shared helper instead of `stanza.utils.CoNLL`, which
    # disagreed with the helper's `stanza.utils.conll.CoNLL` spelling.
    # NOTE(review): this absolute path is machine-specific; confirm whether it
    # should become a command-line argument.
    doc = load_doc_from_conll_file("/u/scr/corpora/Universal_Dependencies/Universal_Dependencies_2.12/ud-treebanks-v2.12/UD_English-GUM/en_gum-ud-test.conllu")
    print(f"Count was {count}.")

    print(doc.sentences)

# Only run main() when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 comments on commit 9d4d1b8

Please sign in to comment.