Commit

Clean up BERTScore code.

anicolson committed Aug 28, 2024
1 parent 2150a75 commit 820607a
Showing 1 changed file with 1 addition and 42 deletions.
43 changes: 1 addition & 42 deletions tools/metrics/bertscore.py
@@ -7,9 +7,6 @@
 from bert_score import BERTScorer
 from torchmetrics import Metric
 
-# from torchmetrics.text import BERTScore
-from transformers import AutoModel, AutoTokenizer
-
 
 class BERTScoreRoBERTaLarge(Metric):
     """
@@ -94,13 +91,8 @@ def compute(self, epoch):
             lang='en',
             device=self.device,
             rescale_with_baseline=True,
-            # baseline_path=os.path.join(self.ckpt_dir, 'bert_score', 'rescale_baseline', 'en', 'roberta-large.tsv'),
         )
 
-        # RoBERTa tokenizer:
-        tokenizer = AutoTokenizer.from_pretrained(os.path.join(self.ckpt_dir, 'roberta-large'))
-        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-
         y_hat = [j['prediction'] for j in self.reports]
         y = [j['label'] for j in self.reports]
         study_ids = [j['study_id'] for j in self.reports]
@@ -112,41 +104,8 @@ def compute(self, epoch):
             assert len(j) == 1
         y = [j[0] for j in y]
 
-        # To fix the issue caused by DistilBERT adding too many special tokens:
-        y_hat_trimmed = tokenizer.batch_decode(
-            [i[:511] for i in tokenizer(y_hat).input_ids], skip_special_tokens=True,
-        )
-        y_trimmed = tokenizer.batch_decode(
-            [i[:511] for i in tokenizer(y).input_ids], skip_special_tokens=True,
-        )
-
-        # precision = [0.0] * len(y_trimmed)
-        # recall = [0.0] * len(y_trimmed)
-        # f1 = [0.0] * len(y_trimmed)
-        #
-        # # Drop pairs and track indices:
-        # y_hat_checked, y_checked, indices = [], [], []
-        # for i, (j, k) in enumerate(zip(y_hat_trimmed, y_trimmed)):
-        #     if j and k:
-        #         y_hat_checked.append(j)
-        #         y_checked.append(k)
-        #         indices.append(i)
-        #     elif not y_hat_trimmed and y_trimmed:
-        #         precision[i] = -1.0
-        #         recall[i] = -1.0
-        #         f1[i] = -1.0
-        #
-        # with torch.no_grad():
-        #     bert_scores, hash_code = bert_scorer.score(y_hat_checked, y_checked, batch_size=self.mbatch_size, return_hash=True)
-        #     print(hash_code)
-        #
-        # for i, x, y, z in zip(indices, bert_scores[0].tolist(), bert_scores[1].tolist(), bert_scores[2].tolist()):
-        #     precision[i] = x
-        #     recall[i] = y
-        #     f1[i] = z
-
         with torch.no_grad():
-            bert_scores, hash_code = bert_scorer.score(y_hat_trimmed, y_trimmed, batch_size=self.mbatch_size, return_hash=True)
+            bert_scores, hash_code = bert_scorer.score(y_hat, y, batch_size=self.mbatch_size, return_hash=True)
             print(hash_code)
 
         precision = bert_scores[0].tolist()
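For context, the code path that survives this commit amounts to scoring the raw predictions against the labels with bert_score's BERTScorer, without the tokenizer-based trimming removed above. The sketch below is a minimal, self-contained illustration of that usage, not the repository's code: the example strings, the model_type='roberta-large' argument, and the batch size are assumptions (the class resolves its model, checkpoint directory, and device from its own configuration).

# Minimal sketch of the retained scoring path; values below are assumptions.
import torch
from bert_score import BERTScorer

y_hat = ['No focal consolidation is seen.']  # hypothetical predictions
y = ['There is no focal consolidation.']     # hypothetical reference labels

bert_scorer = BERTScorer(
    model_type='roberta-large',  # assumed from the class name BERTScoreRoBERTaLarge
    lang='en',
    rescale_with_baseline=True,  # rescale scores against the packaged 'en' baseline
)

with torch.no_grad():
    bert_scores, hash_code = bert_scorer.score(y_hat, y, batch_size=16, return_hash=True)
print(hash_code)  # encodes the model, layer, and baseline configuration

precision = bert_scores[0].tolist()
recall = bert_scores[1].tolist()
f1 = bert_scores[2].tolist()

With return_hash=True, score returns the (precision, recall, F1) tensors first and a hash string second, which is why the diff indexes bert_scores[0] through bert_scores[2].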
