accuracy metric #72

Open
zhuzihan728 opened this issue Apr 20, 2024 · 3 comments

Comments


zhuzihan728 commented Apr 20, 2024

The accuracy in metrics.py is defined as

def accuracy(preds, labels):
    match_count = 0
    for pred, label in zip(preds, labels):
        target = label[0]
        if pred == target:
            match_count += 1

    return 100 * (match_count / len(preds))

While in run_short_form.py, the accuracy is calculated per data instance:

if args.metric == "accuracy":
            metric_result = accuracy(pred, row["output"])

where pred is a single string, and row["output"] is neither present in any short-form dataset nor defined anywhere in your code.
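For what it's worth, a minimal repro of what that call ends up computing (values below are made up, not from the repo): because pred is a single string, zip pairs its characters with the gold answers, and target = label[0] takes the first character of each gold string.

# Illustrative repro, not repo code: accuracy() fed one prediction string
# and one per-instance gold-answer list compares characters, not answers.
pred = "Paris"                    # a single prediction string
answers = ["Paris", "paris"]      # gold answers for one instance
pairs = list(zip(pred, answers))  # [('P', 'Paris'), ('a', 'paris')]
match_count = sum(p == label[0] for p, label in pairs)
print(100 * match_count / len(pred))  # 20.0 -- neither 0 nor 100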

@leeds1219

changed "output" to "answers" and it kind of fixed the problem but
0it [00:00, ?it/s]

...

Processed prompts: 100%|██████████| 10/10 [00:00<00:00, 15.28it/s]

19it [00:53, 2.81s/it]
Traceback (most recent call last):
  File "run_short_form.py", line 378, in <module>
    main()
  File "run_short_form.py", line 349, in main
    metric_result = accuracy(pred, row["answers"])
  File "/workspace/rag/rag/self-rag/retrieval_lm/metrics.py", line 21, in accuracy
    target = label[0]
IndexError: string index out of range

I think the code isn't complete yet...
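From the traceback, the crash seems to come from empty strings in the gold-answer lists: zip pairs each character of pred with one answer, and label[0] on an empty string raises. A tiny repro with made-up values:

# Tiny repro of the IndexError (made-up values, not from the eval data):
pred = "true"
answers = ["", "true"]        # an empty gold answer slips in
for p, label in zip(pred, answers):
    target = label[0]         # IndexError: string index out of range on ""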

@leeds1219

preds = []
prompts = []
golds = []
metric_results = []
scores = []
all_results = []
count = 0
for i, row in tqdm(enumerate(input_data)):
    results = {}
    prompt = PROMPT_DICT["prompt_no_input"].format_map(row)
    _, evidences = process_data_evidences(row, top_n=args.ndocs)
    pred, results, do_retrieve = generate(
        prompt, evidences, max_new_tokens=args.max_new_tokens,)
    if type(pred) is str and pred[0] == "#" or pred[0] == ":":
        pred = pred[1:]
    prompts.append(prompt)
    preds.append(pred)
    all_results.append(results)
    if do_retrieve is True:
        count += 1
    if "answers" not in row and "answer" in row:
        row["answers"] = [row["answer"]] if type(
            row["answer"]) is str else row["answer"]
    ######################################################################
    # 2024-05-22 fixed index out of range error
    row["answers"] = [answer for answer in row["answers"] if answer != ""] 
    ######################################################################
    if args.metric == "accuracy":

#############################################################################################
# 2024-05-22 fixed key error "output" doesnt exist
# metric_result = accuracy(pred, row["output"])
metric_result = accuracy(pred, row["answers"])
##############################################################################################
elif args.metric == "match":
if "SUPPORTS" in pred:
pred = "true"
elif "REFUTES" in pred:
pred = "false"
metric_result = match(pred, row["answers"])
else:
raise NotImplementedError

I modified the code like this and it works.
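If it helps, a per-instance variant that avoids both the zip-over-a-string behavior and the empty-string crash could look like the sketch below (accuracy_single is my own name, not something defined in metrics.py):

# Sketch only, not the repo's metrics.accuracy: compare the whole
# prediction against every non-empty gold answer for one instance.
def accuracy_single(pred, answers):
    answers = [a for a in answers if a]   # drop empty gold strings
    if not answers:
        return 0
    pred = pred.strip().lower()
    return 100 if any(pred == a.strip().lower() for a in answers) else 0

# accuracy_single("Paris", ["Paris", ""])  -> 100
# accuracy_single("Paris", ["London"])     -> 0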

@zhuzihan728 (Author)

Thx for the reply :D
I see now that the accuracy calculation checks whether the first letter of the prediction matches the first letter of a gold answer, because it zips a string with a list. That only really makes sense if the gold-answer list has length one, and perhaps only for multiple-choice datasets?

And about the index-out-of-range error you hit: the author seems to have mistakenly put empty strings ("") in some gold-answer lists of the eval_data they provide in this link. Not sure why, but it definitely doesn't look right, and it only inflates the final score when the match method in metrics.py is used for the metric calculation.
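To see the inflation (assuming match() in metrics.py does substring-style matching, i.e. returns 1 if any gold answer appears inside the prediction; I haven't double-checked the exact implementation): an empty gold string is a substring of every prediction, so it always counts as a hit.

# Assumed substring-style match for illustration, not the repo's code:
def match_like(prediction, ground_truths):
    return 1 if any(gt in prediction for gt in ground_truths) else 0

match_like("no idea", ["Paris", ""])  # -> 1, since "" is in every string
match_like("no idea", ["Paris"])      # -> 0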
@AkariAsai
