Commit 3e1b5f0

Make updates to model evaluation, and add data processing class for extracting sentences of interest
SecroLoL committed Nov 8, 2023
1 parent cd12480 commit 3e1b5f0
Showing 4 changed files with 627 additions and 15 deletions.
77 changes: 62 additions & 15 deletions stanza/models/lemma_classifier/evaluate_models.py
@@ -10,51 +10,98 @@
import stanza
from typing import Any, List, Tuple
from models.lemma_classifier.baseline_model import BaselineModel
+import utils


-def load_doc_from_conll_file(path: str):
-    return stanza.utils.conll.CoNLL.conll2doc(path)
+def update_counts(gold_tag: str, pred_tag: str, true_pos: int, false_pos: int, false_neg: int) -> Tuple[int, int, int]:
+    """
+    Takes in a prediction along with the running counts of true positives, false positives, and false negatives,
+    and updates those counts according to the prediction.
+    We treat "be" as the positive class and "have" as the negative class.
+    """
+    if gold_tag == "be" and pred_tag == "be":
+        true_pos += 1
+    elif gold_tag == "be" and pred_tag == "have":
+        false_neg += 1
+    elif gold_tag == "have" and pred_tag == "be":
+        false_pos += 1
+    return true_pos, false_pos, false_neg
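
As a minimal illustration (not part of this commit), here is how update_counts would accumulate the three counts over a batch of hypothetical gold/predicted tag pairs; note that a ("have", "have") pair is a true negative and leaves all counts unchanged:

tp, fp, fn = 0, 0, 0
for gold, pred in [("be", "be"), ("be", "have"), ("have", "be"), ("have", "have")]:
    tp, fp, fn = update_counts(gold, pred, tp, fp, fn)
# now tp == 1, fp == 1, fn == 1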


def evaluate_models(eval_path: str, binary_classifier: Any, baseline_classifier: BaselineModel):
    """
    Evaluates both the binary classifier and the baseline classifier on a test file,
    checking the predicted lemma for each "'s" token against the gold lemma.
    Measures precision, recall, and F1:
    Precision = true positives / (true positives + false positives)
    Recall = true positives / (true positives + false negatives)
    F1 = 2 * (Precision * Recall) / (Precision + Recall)
    """
-    gold_doc = load_doc_from_conll_file(eval_path)
+    gold_doc = utils.load_doc_from_conll_file(eval_path)

-    for sentence in doc.sentences:
+    bin_tp, bin_fp, bin_fn = 0, 0, 0
+    bl_tp, bl_fp, bl_fn = 0, 0, 0  # baseline counts
+
+    for sentence in gold_doc.sentences:
        for word in sentence.words:
if word.text == "'s" and word.upos in ["VERB", "AUX"]:
if word.text == "'s" and word.upos == "AUX": # only evaluate when the UPOS tag is AUX
                gold_tag = word.lemma
                # predict binary classifier
                bin_predict = None  # TODO
                # predict baseline classifier
                baseline_predict = baseline_classifier.predict(word.text)  # TODO
-                # score
-                if gold_tag == bin_predict:
-                    pass
-                if gold_tag == baseline_predict:
-                    pass

+                # score the binary and baseline classifiers
+                bin_tp, bin_fp, bin_fn = update_counts(gold_tag, bin_predict, bin_tp, bin_fp, bin_fn)
+                bl_tp, bl_fp, bl_fn = update_counts(gold_tag, baseline_predict, bl_tp, bl_fp, bl_fn)

-    return # TODO write some kind of evaluation
+    # compute precision, recall, f1
+    # (note: these divide by zero if a classifier never predicts "be", or if the data contains no gold "be")
+    bin_precision, bin_recall = bin_tp / (bin_tp + bin_fp), bin_tp / (bin_tp + bin_fn)
+    bin_results = {"precision": bin_precision,
+                   "recall": bin_recall,
+                   "f1": 2 * (bin_precision * bin_recall) / (bin_precision + bin_recall)
+                   }
+
+    bl_precision, bl_recall = bl_tp / (bl_tp + bl_fp), bl_tp / (bl_tp + bl_fn)
+    bl_results = {"precision": bl_precision,
+                  "recall": bl_recall,
+                  "f1": 2 * (bl_precision * bl_recall) / (bl_precision + bl_recall)
+                  }
+
+    return bin_results, bl_results
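
As a worked example of the formulas above, using hypothetical counts rather than real results: with tp = 8, fp = 2, fn = 4,

    precision = 8 / (8 + 2) = 0.8
    recall    = 8 / (8 + 4) ≈ 0.667
    f1        = 2 * (0.8 * 0.667) / (0.8 + 0.667) ≈ 0.727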


def main():
    """
    Runs a test over the EN_GUM training set
    """
    coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu")
-    doc = load_doc_from_conll_file(coNLL_path)
+    print(f"Attempting to find token 's in file {coNLL_path}...")
+    doc = utils.load_doc_from_conll_file(coNLL_path)
    count = 0
+    be_count, have_count = 0, 0
    for sentence in doc.sentences:
        for word in sentence.words:
if word.text == "'s" and word.upos in ["VERB", "AUX"]:
print("Found")
if word.text == "'s" and word.upos == "AUX":
print("---------------------------")
print(word)
print("---------------------------")
if word.lemma == "have":
have_count += 1
if word.lemma == "be":
be_count += 1
count += 1

print(f"Count was {count}.")
print(f"The number of 's found was {count}.")
print(f"There were {have_count} occurrences of the lemma being 'have'.")
print(f"There were {be_count} occurrences of the lemma being 'be'.")

# bl_model = BaselineModel("'s", "be")
# bin_results, bl_results = evaluate_models(coNLL_path, None, bl_model)
# print(bin_results, bl_results)
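
A minimal sketch of how the commented-out evaluation could be driven end to end once the bin_predict TODO above is wired to the classifier; AlwaysBe is a hypothetical stand-in, not part of this codebase:

class AlwaysBe:
    # hypothetical stand-in that lemmatizes every "'s" token as "be"
    def predict(self, text: str) -> str:
        return "be"

bl_model = BaselineModel("'s", "be")
bin_results, bl_results = evaluate_models(coNLL_path, AlwaysBe(), bl_model)
print(bin_results, bl_results)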


if __name__ == "__main__":
94 changes: 94 additions & 0 deletions stanza/models/lemma_classifier/prepare_dataset.py
@@ -0,0 +1,94 @@
import stanza
import utils
import os
from typing import List, Tuple, Any

"""
The code in this file processes a CoNLL dataset by taking its sentences and filtering out all sentences that do not contain the target token.
Furthermore, it will store tuples of the Stanza document object, the position index of the target token, and its lemma.
"""


class DataProcessor():

    def __init__(self, target_word: str, target_upos: List[str]):
        self.target_word = target_word
        self.target_upos = target_upos

    def find_all_occurrences(self, sentence) -> List[int]:
        """
        Finds all occurrences of self.target_word in the sentence's words and returns the indexes of those occurrences.
        """
        occurrences = []
        for idx, token in enumerate(sentence.words):
            if token.text == self.target_word and token.upos in self.target_upos:
                occurrences.append(idx)
        return occurrences
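
For instance (a hypothetical sentence, purely illustrative): given a parsed sentence whose words are ["That", "'s", "fine", "."] with "'s" tagged AUX, a DataProcessor(target_word="'s", target_upos=["AUX"]) would return [1] from find_all_occurrences.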

    def process_document(self, doc, keep_condition: callable, save_name: str) -> None:
        """
        Takes any sentence from `doc` that meets the condition of `keep_condition` and writes its tokens, the index of the target word, and the target word's lemma to `save_name`.
        Sentences that meet `keep_condition` and contain `self.target_word` multiple times yield one example per occurrence in the output file.

        Args:
            doc (stanza.Document): Document object that represents the file to be analyzed
            keep_condition (callable): A function that returns a boolean indicating whether to analyze (True) or skip the sentence when searching for the target word
            save_name (str): Path to the file for storing output
        """
        if os.path.exists(save_name):
            raise ValueError(f"Output path {save_name} already exists. Aborting...")
        with open(save_name, "w+", encoding="utf-8") as output_f:
            for sentence in doc.sentences:
                # for each sentence, determine whether it should be added to the output file:
                # if the sentence fulfills keep_condition, save it along with the target word's index and its corresponding lemma
                if keep_condition(sentence):
                    tokens = [token.text for token in sentence.words]
                    indexes = self.find_all_occurrences(sentence)
                    for idx in indexes:
                        # for each occurrence found, write the tokens along with the target word index and lemma
                        output_f.write(f'{" ".join(tokens)} {idx} {sentence.words[idx].lemma}\n')
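
Continuing the hypothetical sentence from above, process_document would emit one line per occurrence, in the form tokens, index, lemma:

That 's fine . 1 be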

    def read_processed_data(self, file_name: str) -> List[dict]:
        """
        Reads the output file from `process_document()` and returns a list containing the sentences of interest. Each object within the list
        is a map with three (key, val) pairs:
            "words" is a list that contains the tokens of the sentence
            "index" is an integer representing which token in "words" the lemma annotation corresponds to
            "lemma" is a string that is the lemma of the target word in the sentence
        """
        output = []
        with open(file_name, "r", encoding="utf-8") as f:
            for line in f.readlines():
                obj = {}
                split = line.split()
                words, index, lemma = split[:-2], int(split[-2]), split[-1]

                obj["words"] = words
                obj["index"] = index
                obj["lemma"] = lemma

                output.append(obj)

        return output
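
Reading that hypothetical line back with read_processed_data would then yield:

[{"words": ["That", "'s", "fine", "."], "index": 1, "lemma": "be"}]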


if __name__ == "__main__":

    coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu")
    doc = utils.load_doc_from_conll_file(coNLL_path)

    processor = DataProcessor(target_word="'s", target_upos=["AUX"])
    output_path = os.path.join(os.path.dirname(__file__), "test_output.txt")

    def keep_sentence(sentence):
        for word in sentence.words:
            if word.text == "'s" and word.upos == "AUX":
                return True
        return False

    processor.process_document(doc, keep_sentence, output_path)

    print(processor.read_processed_data(output_path))
