Fix attribute and parameter names in loggers #476

Open · wants to merge 6 commits into base: main
41 changes: 15 additions & 26 deletions src/lighteval/logging/info_loggers.py
@@ -25,7 +25,7 @@
import os
import time
from dataclasses import asdict, dataclass, field
-from typing import Optional, Union
+from typing import Union

import git
import numpy as np
@@ -251,11 +251,7 @@ class CompiledDetailOverAllTasks:
non_truncated (int): Total number of samples which did not need prompt truncation to fit the model context size across all tasks
padded (int): Number of samples which needed padding during the batching step across all tasks.
non_padded (int): Number of samples which did not need padding during the batching step across all tasks.
-effective_few_shots (float): Average effective few shots across all samples across all tasks.
-    effective few shot is the number of few shots actually used to fit the prompt in the model context
-    length while allowing model generation of the expected size.
-num_truncated_few_shots (int): Number of samples which required truncated prompts to fit the model size across all tasks.

"""

hashes: dict = field(default_factory=dict)
@@ -289,16 +285,16 @@ class CompiledHash:
Hashes the aggregated hash values for all the sample ([`Doc`]) of one task ([`LightevalTask`])

Attributes:
-example (str): Aggregated hash of all the [`Doc.query`] hashes for all samples of the current task.
-full_prompt (str): Aggregated hash of all the [`Doc.ctx`] hashes for all samples of the current task.
+examples (str): Aggregated hash of all the [`Doc.query`] hashes for all samples of the current task.
+full_prompts (str): Aggregated hash of all the [`Doc.ctx`] hashes for all samples of the current task.
input_tokens (str): Aggregated hash of the aggregated [`Doc.input_tokens`] hashes over all samples of the current task.
cont_tokens (str): Aggregated hash of the aggregated [`Doc.generated_tokens`] hashes over all samples of the current task.
"""

-hash_examples: str = ""
-hash_full_prompts: str = ""
-hash_input_tokens: str = ""
-hash_cont_tokens: str = ""
+examples: str = ""
+full_prompts: str = ""
+input_tokens: str = ""
+cont_tokens: str = ""

hashes: dict[str, list[Hash]] = field(default_factory=lambda: collections.defaultdict(list))
compiled_hashes: dict[str, CompiledHash] = field(
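For reference, after this rename the fields line up with the docstring as a plain dataclass; a minimal stand-alone sketch (the surrounding logger machinery is omitted, and the sample value is invented):

from dataclasses import dataclass

# Stand-alone mirror of CompiledHash with the field names this PR introduces.
@dataclass
class CompiledHash:
    examples: str = ""
    full_prompts: str = ""
    input_tokens: str = ""
    cont_tokens: str = ""

h = CompiledHash()
h.examples = "9d2f5c0a1b3e4d67"  # invented value; was h.hash_examples before this rename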
@@ -319,7 +315,6 @@ def log(
doc: Doc,
outputs: list[ModelResponse],
metrics: dict,
-llm_as_prompt_judgement: Optional[tuple[str, str]] = None,
) -> None:
"""Stores the relevant information for one sample of one task to the total list of samples stored in the DetailsLogger.

@@ -329,8 +324,6 @@
doc (Doc): Current sample that we want to store.
outputs (list[ModelResponse]): Model outputs for the current sample
metrics (_type_): Model scores for said sample on the current task's metrics.
-llm_as_prompt_judgement (tuple[str, str]): Tuple containing the
-    prompt passed to the judge and the judgement for the current sample when using llm-as-judge metric.
"""
detail = self.Detail()
detail.example = doc.query
@@ -415,16 +408,16 @@ def aggregate(self):

for task_name in self.hashes:
compiled_hash = self.CompiledHash()
-compiled_hash.hash_examples = xxhash.xxh64(
+compiled_hash.examples = xxhash.xxh64(
"".join(sorted(q.example for q in self.hashes[task_name]))
).hexdigest() # hash of all the hash - sorted for reproducibility
-compiled_hash.hash_full_prompts = xxhash.xxh64(
+compiled_hash.full_prompts = xxhash.xxh64(
"".join(sorted(q.full_prompt for q in self.hashes[task_name]))
).hexdigest() # hash of all the hash - sorted for reproducibility
-compiled_hash.hash_input_tokens = xxhash.xxh64(
+compiled_hash.input_tokens = xxhash.xxh64(
"".join(sorted(q.input_tokens for q in self.hashes[task_name]))
).hexdigest() # hash of all the hash - sorted for reproducibility
-compiled_hash.hash_cont_tokens = xxhash.xxh64(
+compiled_hash.cont_tokens = xxhash.xxh64(
"".join(sorted(q.cont_tokens for q in self.hashes[task_name]))
).hexdigest() # hash of all the hash - sorted for reproducibility
self.compiled_hashes[task_name] = compiled_hash
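The aggregation above is order-independent by construction: the per-sample hashes are sorted before joining, so shuffled batches still compile to the same digest. A minimal sketch of the pattern, using invented per-sample hash strings:

import xxhash

# Invented per-sample hash values for one task (stand-ins for the Hash fields).
sample_hashes = ["5f3c1a9b2e8d4c07", "0a1b2c3d4e5f6a7b", "9e8d7c6b5a493827"]

# Sorting before joining makes the compiled digest independent of sample order,
# which is what the "sorted for reproducibility" comments refer to.
compiled = xxhash.xxh64("".join(sorted(sample_hashes))).hexdigest()
print(compiled)  # same digest no matter how sample_hashes was ordered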
@@ -469,7 +462,7 @@ class MetricsLogger:
"""Logs the actual scores for each metric of each task.

Attributes:
-metrics_value (dict[str, dict[str, list[float]]]): Maps each task to its dictionary of metrics to scores for all the example of the task.
+metrics_values (dict[str, dict[str, list[float]]]): Maps each task to its dictionary of metrics to scores for all the example of the task.
Example: {"winogrande|winogrande_xl": {"accuracy": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]}}
metric_aggregated (dict[str, dict[str, float]]): Maps each task to its dictionary of metrics to aggregated scores over all the example of the task.
Example: {"winogrande|winogrande_xl": {"accuracy": 0.5}}
@@ -486,14 +479,12 @@ def log(self, task_name: str, metrics: dict) -> None:
for metric_name, metric_value in metrics.items():
self.metrics_values[task_name][metric_name].append(metric_value)

-def aggregate(self, task_dict: dict[str, LightevalTask], bootstrap_iters: int = 1000): # noqa: C901
+def aggregate(self, task_dict: dict[str, LightevalTask]): # noqa: C901
"""
Aggregate the metrics for each task and then for all tasks.

Args:
task_dict (dict[str, LightevalTask]): used to determine what aggregation function to use for each metric
-bootstrap_iters (int, optional): Number of runs used to run the statistical bootstrap. Defaults to 1000.
-
"""

for task_name, metrics in self.metrics_values.items():
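For orientation, log() simply appends into a nested mapping keyed by task and then metric; a small sketch of the resulting metrics_values shape (task name and score taken from the docstring example above):

import collections

# Nested defaultdict mirroring the logger's metrics_values attribute.
metrics_values = collections.defaultdict(lambda: collections.defaultdict(list))

# Each log(task_name, metrics) call appends one score per metric:
for metric_name, metric_value in {"accuracy": 0.5}.items():
    metrics_values["winogrande|winogrande_xl"][metric_name].append(metric_value)

# -> {"winogrande|winogrande_xl": {"accuracy": [0.5]}}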
@@ -572,8 +563,7 @@ class VersionsLogger:
Tasks can have a version number/date, which indicates what is the precise metric definition and dataset version used for an evaluation.

Attributes:
-version (dict[str, int]): Maps the task names with the task versions.
-
+versions (dict[str, int]): Maps the task names with the task versions.
"""

# the versions dict will be a dict of task_name: task_version
@@ -589,8 +579,7 @@ class TaskConfigLogger:
"""Logs the different parameters of the current [`LightevalTask`] of interest.

Attributes:
-tasks_config (dict[str, LightevalTaskConfig]): Maps each task to its associated [`LightevalTaskConfig`]
-
+tasks_configs (dict[str, LightevalTaskConfig]): Maps each task to its associated [`LightevalTaskConfig`]
"""

tasks_configs: dict[str, LightevalTaskConfig] = field(default_factory=dict)