From 5909d4aac1db15fa17998cfbb3f5f3de5ad7d31e Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Wed, 11 Dec 2024 16:28:26 +0100 Subject: [PATCH 01/21] Added first version of custom model. --- src/lighteval/__main__.py | 2 + src/lighteval/main_custom.py | 144 +++++++++++++++++++++++++++ src/lighteval/models/model_config.py | 6 ++ src/lighteval/models/model_loader.py | 33 ++++++ src/lighteval/pipeline.py | 1 + 5 files changed, 186 insertions(+) create mode 100644 src/lighteval/main_custom.py diff --git a/src/lighteval/__main__.py b/src/lighteval/__main__.py index 4484f7812..25b0ec880 100644 --- a/src/lighteval/__main__.py +++ b/src/lighteval/__main__.py @@ -27,6 +27,7 @@ import lighteval.main_accelerate import lighteval.main_baseline +import lighteval.main_custom import lighteval.main_endpoint import lighteval.main_nanotron import lighteval.main_tasks @@ -63,6 +64,7 @@ app.command(rich_help_panel="Evaluation Utils")(lighteval.main_baseline.baseline) app.command(rich_help_panel="Evaluation Backends")(lighteval.main_nanotron.nanotron) app.command(rich_help_panel="Evaluation Backends")(lighteval.main_vllm.vllm) +app.command(rich_help_panel="Evaluation Backends")(lighteval.main_custom.custom) app.add_typer( lighteval.main_endpoint.app, name="endpoint", diff --git a/src/lighteval/main_custom.py b/src/lighteval/main_custom.py new file mode 100644 index 000000000..4ea0291dc --- /dev/null +++ b/src/lighteval/main_custom.py @@ -0,0 +1,144 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+import os +from typing import Optional + +import typer +from typer import Argument, Option +from typing_extensions import Annotated + + +app = typer.Typer() + + +TOKEN = os.getenv("HF_TOKEN") +CACHE_DIR: str = os.getenv("HF_HOME", "/scratch") + +HELP_PANNEL_NAME_1 = "Common Paramaters" +HELP_PANNEL_NAME_2 = "Logging Parameters" +HELP_PANNEL_NAME_3 = "Debug Paramaters" +HELP_PANNEL_NAME_4 = "Modeling Paramaters" + + +@app.command(rich_help_panel="Evaluation Backends") +def custom( + # === general === + model_name: Annotated[str, Argument(help="The model name to evaluate")], + model_definition_file_path: Annotated[str, Argument(help="The model definition file path to evaluate")], + tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")], + # === Common parameters === + use_chat_template: Annotated[ + bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4) + ] = False, + system_prompt: Annotated[ + Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4) + ] = None, + dataset_loading_processes: Annotated[ + int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = 1, + custom_tasks: Annotated[ + Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = None, + cache_dir: Annotated[ + str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = CACHE_DIR, + num_fewshot_seeds: Annotated[ + int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = 1, + # === saving === + output_dir: Annotated[ + str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = "results", + push_to_hub: Annotated[ + bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + push_to_tensorboard: Annotated[ + bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + public_run: Annotated[ + bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + results_org: Annotated[ + Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = None, + save_details: Annotated[ + bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + # === debug === + max_samples: Annotated[ + Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANNEL_NAME_3) + ] = None, + override_batch_size: Annotated[ + int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANNEL_NAME_3) + ] = None, + job_id: Annotated[ + int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3) + ] = 0, +): + """ + Evaluate custom models (can be anything). 
+ """ + from lighteval.logging.evaluation_tracker import EvaluationTracker + from lighteval.models.model_config import CustomModelConfig + from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters + + env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir) + evaluation_tracker = EvaluationTracker( + output_dir=output_dir, + save_details=save_details, + push_to_hub=push_to_hub, + push_to_tensorboard=push_to_tensorboard, + public=public_run, + hub_results_org=results_org, + ) + + parallelism_manager = ParallelismManager.CUSTOM + model_config = CustomModelConfig(model=model_name, model_definition_file_path=model_definition_file_path) + + pipeline_params = PipelineParameters( + launcher_type=parallelism_manager, + env_config=env_config, + job_id=job_id, + dataset_loading_processes=dataset_loading_processes, + custom_tasks_directory=custom_tasks, + override_batch_size=override_batch_size, + num_fewshot_seeds=num_fewshot_seeds, + max_samples=max_samples, + use_chat_template=use_chat_template, + system_prompt=system_prompt, + ) + pipeline = Pipeline( + tasks=tasks, + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model_config=model_config, + ) + + pipeline.evaluate() + + pipeline.show_results() + + results = pipeline.get_results() + + pipeline.save_and_push_results() + + return results diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index 1eda1e029..e3f703777 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -244,6 +244,12 @@ class VLLMModelConfig: temperature: float = 0.6 # will be used for multi sampling tasks, for tasks requiring no sampling, this will be ignored and set to 0. +@dataclass +class CustomModelConfig: + model: str + model_definition_file_path: str + + @dataclass class OpenAIModelConfig: model: str diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 1a409746c..dfc244730 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -23,6 +23,7 @@ import logging from typing import Union +from lighteval.models.abstract_model import LightevalModel from lighteval.models.adapter_model import AdapterModel from lighteval.models.base_model import BaseModel from lighteval.models.delta_model import DeltaModel @@ -31,6 +32,7 @@ from lighteval.models.model_config import ( AdapterModelConfig, BaseModelConfig, + CustomModelConfig, DeltaModelConfig, DummyModelConfig, InferenceEndpointModelConfig, @@ -64,6 +66,7 @@ def load_model( # noqa: C901 InferenceEndpointModelConfig, DummyModelConfig, VLLMModelConfig, + CustomModelConfig, OpenAIModelConfig, ], env_config: EnvConfig, @@ -99,6 +102,9 @@ def load_model( # noqa: C901 if isinstance(config, VLLMModelConfig): return load_model_with_accelerate_or_default(config=config, env_config=env_config) + if isinstance(config, CustomModelConfig): + return load_custom_model(config=config, env_config=env_config) + if isinstance(config, OpenAIModelConfig): return load_openai_model(config=config, env_config=env_config) @@ -114,6 +120,33 @@ def load_model_with_tgi(config: TGIModelConfig): return model +def load_custom_model(config: CustomModelConfig, env_config: EnvConfig): + import importlib.util + + # Load the Python file + spec = importlib.util.spec_from_file_location("custom_model_module", config.model_definition_file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load file: 
{config.model_definition_file_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find the first class that inherits from LightevalModel + model_class = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, LightevalModel) and attr != LightevalModel: + model_class = attr + break + + if model_class is None: + raise ValueError(f"No class inheriting from LightevalModel found in {config.model_definition_file_path}") + + model = model_class(config, env_config) + + return model + + def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig): if not is_openai_available(): raise ImportError() diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index facecd8ec..88b262829 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -76,6 +76,7 @@ class ParallelismManager(Enum): TGI = auto() OPENAI = auto() VLLM = auto() + CUSTOM = auto() NONE = auto() From 60960421a1a2deb80e474db70ead856e7ecf5bb9 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 15:44:46 +0100 Subject: [PATCH 02/21] Moved custom model config. --- src/lighteval/main_custom.py | 9 ++----- src/lighteval/models/custom/custom_model.py | 29 +++++++++++++++++++++ src/lighteval/models/model_loader.py | 2 +- 3 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 src/lighteval/models/custom/custom_model.py diff --git a/src/lighteval/main_custom.py b/src/lighteval/main_custom.py index 80fc80dc5..84c953826 100644 --- a/src/lighteval/main_custom.py +++ b/src/lighteval/main_custom.py @@ -20,13 +20,14 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import os -from dataclasses import dataclass from typing import Optional import typer from typer import Argument, Option from typing_extensions import Annotated +from lighteval.models.custom.custom_model import CustomModelConfig + app = typer.Typer() @@ -40,12 +41,6 @@ HELP_PANNEL_NAME_4 = "Modeling Paramaters" -@dataclass -class CustomModelConfig: - model: str - model_definition_file_path: str - - @app.command(rich_help_panel="Evaluation Backends") def custom( # === general === diff --git a/src/lighteval/models/custom/custom_model.py b/src/lighteval/models/custom/custom_model.py new file mode 100644 index 000000000..f0d08dd96 --- /dev/null +++ b/src/lighteval/models/custom/custom_model.py @@ -0,0 +1,29 @@ +# MIT License +# +# Copyright (c) 2024 The HuggingFace Team +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from dataclasses import dataclass + + +@dataclass +class CustomModelConfig: + model: str + model_definition_file_path: str diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index fcb460ad7..c73e13a46 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -23,8 +23,8 @@ import logging from typing import Union -from lighteval.main_custom import CustomModelConfig from lighteval.models.abstract_model import LightevalModel +from lighteval.models.custom.custom_model import CustomModelConfig from lighteval.models.dummy.dummy_model import DummyModel, DummyModelConfig from lighteval.models.endpoints.endpoint_model import ( InferenceEndpointModel, From a7e1fe58592549a146ca49368403e7e5a889da5d Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 15:46:52 +0100 Subject: [PATCH 03/21] Added warning. --- src/lighteval/models/model_loader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index c73e13a46..b0d480846 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -114,6 +114,8 @@ def load_model_with_tgi(config: TGIModelConfig): def load_custom_model(config: CustomModelConfig, env_config: EnvConfig): + logger.warning(f"Executing custom model code loaded from {config.model_definition_file_path}.") + import importlib.util # Load the Python file From 24b8bd35e31334050d2704274092e1348bafbfa5 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 15:49:30 +0100 Subject: [PATCH 04/21] Added custom model example for google translate. --- .../custom_models/google_translate_model.py | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 examples/custom_models/google_translate_model.py diff --git a/examples/custom_models/google_translate_model.py b/examples/custom_models/google_translate_model.py new file mode 100644 index 000000000..4d79cf2d2 --- /dev/null +++ b/examples/custom_models/google_translate_model.py @@ -0,0 +1,151 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import logging +from typing import Optional + +from tqdm import tqdm +from transformers import AutoTokenizer + +from lighteval.data import GenerativeTaskDataset +from lighteval.models.abstract_model import LightevalModel, ModelInfo +from lighteval.models.model_output import ( + GenerativeResponse, + LoglikelihoodResponse, + LoglikelihoodSingleTokenResponse, +) +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, +) + + +logger = logging.getLogger(__name__) + + +class GoogleTranslateClient(LightevalModel): + def __init__(self, config, env_config) -> None: + self.model = config.model + self.model_definition_file_path = config.model_definition_file_path + + self.model_info = ModelInfo( + model_name=config.model, + model_sha="", + model_dtype=None, + model_size="", + ) + + self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility + + import httpcore + + # Needed to fix some googletrans bug + # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618 + setattr(httpcore, "SyncHTTPTransport", "AsyncHTTPProxy") + from googletrans import Translator + + self.translator = Translator() + + def greedy_until( + self, + requests: list[GreedyUntilRequest], + override_bs: Optional[int] = None, + ) -> list[GenerativeResponse]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerativeResponse]: list of generated responses. + """ + for request in requests: + request.tokenized_context = self.tok_encode(request.context) + + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) + results = [] + + for _ in tqdm( + dataset.splits_start_end_iterator(), + total=dataset.num_dataset_splits, + desc="Splits", + position=0, + disable=False, # self.disable_tqdm, + ): + for r in tqdm(dataset, desc="Batch", position=1, disable=False): + context = r.context.replace("French phrase: ", "") + # TODO: Get src and dest from request + translation = self.translator.translate(context, src="fr", dest="de") + + result = translation.text + cur_response = GenerativeResponse( + result=result, + logits=None, + generated_tokens=[], + input_tokens=[], + ) + results.append(cur_response) + + return dataset.get_original_order(results) + + @property + def tokenizer(self): + return self._tokenizer + + def tok_encode(self, text: str): + return self.tokenizer.encode(text) + + @property + def add_special_tokens(self) -> bool: + return False + + @property + def max_length(self) -> int: + """Return the maximum sequence length of the model.""" + return 4096 + + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. 
+ """ + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + raise NotImplementedError + + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodSingleTokenResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + raise NotImplementedError From c177a8e46ee39eb889816a41c64a37da366cc9c3 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 15:53:40 +0100 Subject: [PATCH 05/21] Added documentation for custom model config. --- src/lighteval/models/custom/custom_model.py | 49 +++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/lighteval/models/custom/custom_model.py b/src/lighteval/models/custom/custom_model.py index f0d08dd96..1b48a807b 100644 --- a/src/lighteval/models/custom/custom_model.py +++ b/src/lighteval/models/custom/custom_model.py @@ -25,5 +25,54 @@ @dataclass class CustomModelConfig: + """ + Configuration class for loading custom model implementations in Lighteval. + + This config allows users to define and load their own model implementations by specifying + a Python file containing a custom model class that inherits from LightevalModel. + + The custom model file should contain exactly one class that inherits from LightevalModel. + This class will be automatically detected and instantiated when loading the model. + + Args: + model (str): + An identifier for the model. This can be used to track which model was evaluated + in the results and logs. + + model_definition_file_path (str): + Path to a Python file containing the custom model implementation. This file must + define exactly one class that inherits from LightevalModel. The class should + implement all required methods from the LightevalModel interface. + + Example usage: + ```python + # Define config + config = CustomModelConfig( + model="my-custom-model", + model_definition_file_path="path/to/my_model.py" + ) + + # Example custom model file (my_model.py): + from lighteval.models.abstract_model import LightevalModel + + class MyCustomModel(LightevalModel): + def __init__(self, config, env_config): + super().__init__(config, env_config) + # Custom initialization... + + def greedy_until(self, *args, **kwargs): + # Custom generation logic... + pass + ``` + + An example of a custom model can be found in `examples/custom_models/google_translate_model.py`. + + Notes: + - The custom model class must inherit from LightevalModel and implement all required methods + - Only one class inheriting from LightevalModel should be defined in the file + - The model file is dynamically loaded at runtime, so ensure all dependencies are available + - Exercise caution when loading custom model files as they can execute arbitrary code + """ + model: str model_definition_file_path: str From d712cdb41acede58e3f624d4fc72d03d3cfc46e7 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 16:03:13 +0100 Subject: [PATCH 06/21] Added docs. 
--- docs/source/_toctree.yml | 2 + docs/source/evaluating-a-custom-model.mdx | 129 ++++++++++++++++++++++ docs/source/package_reference/models.mdx | 4 + 3 files changed, 135 insertions(+) create mode 100644 docs/source/evaluating-a-custom-model.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 9ad55466a..beedaff60 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -15,6 +15,8 @@ title: Add a custom task - local: adding-a-new-metric title: Add a custom metric + - local: evaluating-a-custom-model + title: Evaluate a custom model - local: use-vllm-as-backend title: Use VLLM as backend - local: evaluate-the-model-on-a-server-or-container diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx new file mode 100644 index 000000000..751de0bef --- /dev/null +++ b/docs/source/evaluating-a-custom-model.mdx @@ -0,0 +1,129 @@ +# Evaluating a Custom Model + +Lighteval allows you to evaluate custom model implementations by creating a custom model class that inherits from `LightevalModel`. This is useful when you want to evaluate models that aren't directly supported by the standard backends (transformers, vllm, etc). + +## Creating a Custom Model + +1. Create a Python file containing your custom model implementation. The model must inherit from `LightevalModel` and implement all required methods. + +Here's a basic example: + +```python +from lighteval.models.abstract_model import LightevalModel + +class MyCustomModel(LightevalModel): + def __init__(self, config, env_config): + super().__init__(config, env_config) + # Initialize your model here... + + def greedy_until(self, requests, max_tokens=None, stop_sequences=None): + # Implement generation logic + pass + + def loglikelihood(self, requests, log=True): + # Implement loglikelihood computation + pass + + def loglikelihood_rolling(self, requests): + # Implement rolling loglikelihood computation + pass + + def loglikelihood_single_token(self, requests): + # Implement single token loglikelihood computation + pass +``` + +2. The custom model file should contain exactly one class that inherits from `LightevalModel`. This class will be automatically detected and instantiated when loading the model. + +> [!TIP] +> You can find a complete example of a custom model implementation in `examples/custom_models/google_translate_model.py`. + +## Running the Evaluation + +You can evaluate your custom model using either the command line interface or the Python API. 
+
+### Using the Command Line
+
+```bash
+lighteval custom \
+    "google-translate" \
+    "examples/custom_models/google_translate_model.py" \
+    "leaderboard|wmt20:fr-de|0|0" \
+    --output-dir results \
+    --max-samples 10
+```
+
+The command takes three required arguments:
+- The model name (used for tracking in results/logs)
+- The path to your model implementation file
+- The tasks to evaluate on (same format as other backends)
+
+### Using the Python API
+
+```python
+from lighteval.logging.evaluation_tracker import EvaluationTracker
+from lighteval.models.custom.custom_model import CustomModelConfig
+from lighteval.pipeline import Pipeline, PipelineParameters, EnvConfig, ParallelismManager
+
+# Set up evaluation tracking
+evaluation_tracker = EvaluationTracker(
+    output_dir="results",
+    save_details=True
+)
+
+# Configure the pipeline
+pipeline_params = PipelineParameters(
+    launcher_type=ParallelismManager.CUSTOM,
+    env_config=EnvConfig(cache_dir="tmp/")
+)
+
+# Configure your custom model
+model_config = CustomModelConfig(
+    model="my-custom-model",
+    model_definition_file_path="path/to/my_model.py"
+)
+
+# Create and run the pipeline
+pipeline = Pipeline(
+    tasks="leaderboard|truthfulqa:mc|0|0",
+    pipeline_parameters=pipeline_params,
+    evaluation_tracker=evaluation_tracker,
+    model_config=model_config
+)
+
+pipeline.evaluate()
+pipeline.save_and_push_results()
+```
+
+## Required Methods
+
+Your custom model must implement these core methods:
+
+- `greedy_until`: For generating text until a stop sequence or max tokens is reached
+- `loglikelihood`: For computing log probabilities of specific continuations
+- `loglikelihood_rolling`: For computing rolling log probabilities of sequences
+- `loglikelihood_single_token`: For computing log probabilities of single tokens
+
+See the `LightevalModel` base class documentation for detailed method signatures and requirements.
+
+## Best Practices
+
+1. **Error Handling**: Implement robust error handling in your model methods to gracefully handle edge cases.
+
+2. **Batching**: Consider implementing efficient batching in your model methods to improve performance.
+
+3. **Resource Management**: Properly manage any resources (e.g., API connections, model weights) in your model's `__init__` and `__del__` methods.
+
+4. **Documentation**: Add clear docstrings to your model class and methods explaining any specific requirements or limitations.
+
+## Example Use Cases
+
+Custom models are particularly useful for:
+
+- Evaluating models accessed through custom APIs
+- Wrapping models with specialized preprocessing/postprocessing
+- Testing novel model architectures
+- Evaluating ensemble models
+- Integrating with external services or tools
+
+For a complete example of a custom model that wraps the Google Translate API, see `examples/custom_models/google_translate_model.py`.
diff --git a/docs/source/package_reference/models.mdx b/docs/source/package_reference/models.mdx index 096ce7be3..5d15fff9a 100644 --- a/docs/source/package_reference/models.mdx +++ b/docs/source/package_reference/models.mdx @@ -28,6 +28,10 @@ [[autodoc]] models.endpoints.tgi_model.TGIModelConfig [[autodoc]] models.endpoints.tgi_model.ModelClient +### Custom Model +[[autodoc]] models.custom.custom_model.CustomModelConfig +[[autodoc]] models.custom.custom_model.CustomModel + ### Open AI Models [[autodoc]] models.endpoints.openai_model.OpenAIClient From b41949c2dfe9b291d6a3348ff22ab20c7bd9ae2b Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 16:05:02 +0100 Subject: [PATCH 07/21] Fixed path error. --- docs/source/evaluating-a-custom-model.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx index 751de0bef..168a56e69 100644 --- a/docs/source/evaluating-a-custom-model.mdx +++ b/docs/source/evaluating-a-custom-model.mdx @@ -48,7 +48,7 @@ You can evaluate your custom model using either the command line interface or th lighteval custom \ "google-translate" \ "examples/custom_models/google_translate_model.py" \ - "leaderboard|wmt20:fr-de|0|0" \ + "lighteval|wmt20:fr-de|0|0" \ --output-dir results \ --max-samples 10 ``` From aaaadb0aa9063556289f189ac0e5ee79c9106bc5 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 16:34:26 +0100 Subject: [PATCH 08/21] Fixed doc error. --- docs/source/evaluating-a-custom-model.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx index 168a56e69..d4ff33cfa 100644 --- a/docs/source/evaluating-a-custom-model.mdx +++ b/docs/source/evaluating-a-custom-model.mdx @@ -45,7 +45,7 @@ You can evaluate your custom model using either the command line interface or th ### Using the Command Line ```bash -lighteval custom \ +python -m lighteval custom \ "google-translate" \ "examples/custom_models/google_translate_model.py" \ "lighteval|wmt20:fr-de|0|0" \ From c85065f59091176869f310419ace020123fd8473 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 16:39:06 +0100 Subject: [PATCH 09/21] Added requirements file for google translate. 
--- .../google-translate-requirements-freeze.txt | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 examples/custom_models/google-translate-requirements-freeze.txt diff --git a/examples/custom_models/google-translate-requirements-freeze.txt b/examples/custom_models/google-translate-requirements-freeze.txt new file mode 100644 index 000000000..bf1cd7200 --- /dev/null +++ b/examples/custom_models/google-translate-requirements-freeze.txt @@ -0,0 +1,113 @@ +absl-py==2.1.0 +accelerate==1.2.0 +aenum==3.1.15 +aiohappyeyeballs==2.4.4 +aiohttp==3.11.10 +aiosignal==1.3.1 +annotated-types==0.7.0 +anyio==4.7.0 +attrs==24.2.0 +blis==0.7.11 +catalogue==2.0.10 +certifi==2024.8.30 +chardet==3.0.4 +charset-normalizer==3.4.0 +click==8.1.7 +cloudpathlib==0.16.0 +colorama==0.4.6 +colorlog==6.9.0 +confection==0.1.5 +cymem==2.0.10 +DataProperty==1.0.1 +datasets==3.2.0 +dill==0.3.8 +filelock==3.16.1 +frozenlist==1.5.0 +fsspec==2024.9.0 +gitdb==4.0.11 +GitPython==3.1.43 +googletrans==4.0.0rc1 +h11==0.14.0 +h2==3.2.0 +hpack==3.0.0 +hstspreload==2024.12.1 +httpcore==1.0.7 +httpx==0.28.1 +huggingface-hub==0.26.5 +hyperframe==5.2.0 +idna==2.10 +Jinja2==3.1.4 +joblib==1.4.2 +langcodes==3.5.0 +language_data==1.3.0 +lighteval==0.6.0.dev0 +lxml==5.3.0 +marisa-trie==1.2.1 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +mbstrdecoder==1.1.3 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.1.0 +multiprocess==0.70.16 +murmurhash==1.0.11 +networkx==3.4.2 +nltk==3.9.1 +numpy==1.26.4 +packaging==24.2 +pandas==2.2.3 +pathvalidate==3.2.1 +portalocker==3.0.0 +preshed==3.0.9 +propcache==0.2.1 +protobuf==3.20.3 +psutil==6.1.0 +pyarrow==18.1.0 +pycountry==24.6.1 +pydantic==2.10.3 +pydantic_core==2.27.1 +Pygments==2.18.0 +pytablewriter==1.2.0 +python-dateutil==2.9.0.post0 +pytz==2024.2 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +rfc3986==1.5.0 +rich==13.9.4 +rouge_score==0.1.2 +sacrebleu==2.4.3 +safetensors==0.4.5 +scikit-learn==1.6.0 +scipy==1.14.1 +sentencepiece==0.2.0 +setuptools==75.1.0 +six==1.17.0 +smart-open==6.4.0 +smmap==5.0.1 +sniffio==1.3.1 +spacy==3.7.2 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.5.0 +sympy==1.13.3 +tabledata==1.3.3 +tabulate==0.9.0 +tcolorpy==0.1.6 +termcolor==2.3.0 +thinc==8.2.5 +threadpoolctl==3.5.0 +tokenizers==0.21.0 +torch==2.4.1 +tqdm==4.67.1 +transformers==4.47.0 +typepy==1.3.2 +typer==0.9.4 +typing_extensions==4.12.2 +tzdata==2024.2 +urllib3==2.2.3 +wasabi==1.1.3 +weasel==0.3.4 +wheel==0.44.0 +xxhash==3.5.0 +yarl==1.18.3 From f1103da9f425dfdd80835823a1ff5d763425929d Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 16:50:37 +0100 Subject: [PATCH 10/21] Moved model loading function to reduce merge conflicts with litellm inference. 
--- src/lighteval/models/model_loader.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index b0d480846..3d19178e4 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -113,6 +113,15 @@ def load_model_with_tgi(config: TGIModelConfig): return model +def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig): + if not is_openai_available(): + raise ImportError() + + model = OpenAIClient(config, env_config) + + return model + + def load_custom_model(config: CustomModelConfig, env_config: EnvConfig): logger.warning(f"Executing custom model code loaded from {config.model_definition_file_path}.") @@ -142,15 +151,6 @@ def load_custom_model(config: CustomModelConfig, env_config: EnvConfig): return model -def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig): - if not is_openai_available(): - raise ImportError() - - model = OpenAIClient(config, env_config) - - return model - - def load_model_with_inference_endpoints(config: InferenceEndpointModelConfig, env_config: EnvConfig): logger.info("Spin up model using inference endpoint.") model = InferenceEndpointModel(config=config, env_config=env_config) From 71f871e73cad38b2e7153e423ae96710252bf6d5 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 17:19:12 +0100 Subject: [PATCH 11/21] Added diskcache and get source and target language from the task name. --- .../custom_models/google_translate_model.py | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/examples/custom_models/google_translate_model.py b/examples/custom_models/google_translate_model.py index 4d79cf2d2..158ee8418 100644 --- a/examples/custom_models/google_translate_model.py +++ b/examples/custom_models/google_translate_model.py @@ -20,9 +20,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import hashlib import logging +import os from typing import Optional +import diskcache from tqdm import tqdm from transformers import AutoTokenizer @@ -67,6 +70,29 @@ def __init__(self, config, env_config) -> None: self.translator = Translator() + # Initialize disk cache + cache_dir = os.path.join(os.getcwd(), ".translation_cache") + self.cache = diskcache.Cache(cache_dir) + + def _get_cache_key(self, context: str, src_lang: str, tgt_lang: str) -> str: + """Generate a unique cache key for the translation request.""" + key_string = f"{context}|{src_lang}|{tgt_lang}" + return hashlib.md5(key_string.encode()).hexdigest() + + def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> str: + """Translate text using cache if available, otherwise call Google Translate.""" + cache_key = self._get_cache_key(context, src_lang, tgt_lang) + + # Try to get from cache + if cache_key in self.cache: + return self.cache[cache_key] + + # If not in cache, translate and store + translation = self.translator.translate(context, src=src_lang, dest=tgt_lang) + result = translation.text + self.cache[cache_key] = result + return result + def greedy_until( self, requests: list[GreedyUntilRequest], @@ -74,10 +100,10 @@ def greedy_until( ) -> list[GenerativeResponse]: """ Generates responses using a greedy decoding strategy until certain ending conditions are met. + Results are cached to disk to avoid repeated translations. Args: requests (list[Request]): list of requests containing the context and ending conditions. 
- disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False. override_bs (int, optional): Override the batch size for generation. Defaults to None. Returns: @@ -97,11 +123,13 @@ def greedy_until( disable=False, # self.disable_tqdm, ): for r in tqdm(dataset, desc="Batch", position=1, disable=False): + # Extract source and target languages from task name + # Format is like "community|sdst-text_level:de-fr|0" + src_lang, tgt_lang = r.task_name.split("|")[1].split(":")[-1].split("-") + context = r.context.replace("French phrase: ", "") - # TODO: Get src and dest from request - translation = self.translator.translate(context, src="fr", dest="de") + result = self._translate_with_cache(context, src_lang, tgt_lang) - result = translation.text cur_response = GenerativeResponse( result=result, logits=None, From d1af518ff39e0f730e23addf1c3c33724a5ba803 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 12 Dec 2024 18:37:16 +0100 Subject: [PATCH 12/21] Fixed problem with removing languages in the context. --- examples/custom_models/google_translate_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/custom_models/google_translate_model.py b/examples/custom_models/google_translate_model.py index 158ee8418..1d67b04e1 100644 --- a/examples/custom_models/google_translate_model.py +++ b/examples/custom_models/google_translate_model.py @@ -127,7 +127,7 @@ def greedy_until( # Format is like "community|sdst-text_level:de-fr|0" src_lang, tgt_lang = r.task_name.split("|")[1].split(":")[-1].split("-") - context = r.context.replace("French phrase: ", "") + context = r.context.replace(f"{src_lang.upper()}: ", "").replace(f"\n{tgt_lang.upper()}: ", "") result = self._translate_with_cache(context, src_lang, tgt_lang) cur_response = GenerativeResponse( From 25111589eab4d693e2df6fe91bb356864fc31f60 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Fri, 13 Dec 2024 10:00:49 +0100 Subject: [PATCH 13/21] Added retry logic. 
--- .../custom_models/google_translate_model.py | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/examples/custom_models/google_translate_model.py b/examples/custom_models/google_translate_model.py index 1d67b04e1..d22d97bae 100644 --- a/examples/custom_models/google_translate_model.py +++ b/examples/custom_models/google_translate_model.py @@ -23,9 +23,13 @@ import hashlib import logging import os +import time from typing import Optional import diskcache +import httpcore +import tenacity +from googletrans import Translator from tqdm import tqdm from transformers import AutoTokenizer @@ -61,12 +65,9 @@ def __init__(self, config, env_config) -> None: self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility - import httpcore - # Needed to fix some googletrans bug # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618 setattr(httpcore, "SyncHTTPTransport", "AsyncHTTPProxy") - from googletrans import Translator self.translator = Translator() @@ -74,24 +75,39 @@ def __init__(self, config, env_config) -> None: cache_dir = os.path.join(os.getcwd(), ".translation_cache") self.cache = diskcache.Cache(cache_dir) + self.max_retries = 3 + self.retry_delay = 1 + def _get_cache_key(self, context: str, src_lang: str, tgt_lang: str) -> str: """Generate a unique cache key for the translation request.""" key_string = f"{context}|{src_lang}|{tgt_lang}" return hashlib.md5(key_string.encode()).hexdigest() + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(multiplier=1, min=4, max=10), + retry=tenacity.retry_if_exception_type((Exception)), + before_sleep=lambda retry_state: time.sleep(1), + ) def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> str: - """Translate text using cache if available, otherwise call Google Translate.""" + """Translate text using cache if available, otherwise call Google Translate with retry logic.""" cache_key = self._get_cache_key(context, src_lang, tgt_lang) # Try to get from cache if cache_key in self.cache: return self.cache[cache_key] - # If not in cache, translate and store - translation = self.translator.translate(context, src=src_lang, dest=tgt_lang) - result = translation.text - self.cache[cache_key] = result - return result + try: + # If not in cache, translate and store + translation = self.translator.translate(context, src=src_lang, dest=tgt_lang) + result = translation.text + self.cache[cache_key] = result + return result + except Exception as e: + logger.warning(f"Translation error: {str(e)}. Retrying...") + # Re-initialize translator on error + self.translator = Translator() + raise # Let tenacity handle the retry def greedy_until( self, From 743a2846ea3f460e471b9aba462942b0d05556af Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Mon, 16 Dec 2024 16:30:24 +0100 Subject: [PATCH 14/21] Update google-translate requirements. 
--- .../google-translate-requirements-freeze.txt | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/custom_models/google-translate-requirements-freeze.txt b/examples/custom_models/google-translate-requirements-freeze.txt index bf1cd7200..1e6389f8d 100644 --- a/examples/custom_models/google-translate-requirements-freeze.txt +++ b/examples/custom_models/google-translate-requirements-freeze.txt @@ -21,18 +21,19 @@ cymem==2.0.10 DataProperty==1.0.1 datasets==3.2.0 dill==0.3.8 +diskcache==5.6.3 filelock==3.16.1 frozenlist==1.5.0 fsspec==2024.9.0 gitdb==4.0.11 GitPython==3.1.43 googletrans==4.0.0rc1 -h11==0.14.0 +h11==0.9.0 h2==3.2.0 hpack==3.0.0 hstspreload==2024.12.1 -httpcore==1.0.7 -httpx==0.28.1 +httpcore==0.9.1 +httpx==0.13.3 huggingface-hub==0.26.5 hyperframe==5.2.0 idna==2.10 @@ -40,7 +41,6 @@ Jinja2==3.1.4 joblib==1.4.2 langcodes==3.5.0 language_data==1.3.0 -lighteval==0.6.0.dev0 lxml==5.3.0 marisa-trie==1.2.1 markdown-it-py==3.0.0 @@ -94,6 +94,7 @@ sympy==1.13.3 tabledata==1.3.3 tabulate==0.9.0 tcolorpy==0.1.6 +tenacity==9.0.0 termcolor==2.3.0 thinc==8.2.5 threadpoolctl==3.5.0 From 1a37f715b94929abc9c3f04b28b4a57d927b0ab2 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Tue, 17 Dec 2024 10:28:20 +0100 Subject: [PATCH 15/21] Added another example for a custom model. --- examples/custom_models/seamless4mt_model.py | 191 ++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 examples/custom_models/seamless4mt_model.py diff --git a/examples/custom_models/seamless4mt_model.py b/examples/custom_models/seamless4mt_model.py new file mode 100644 index 000000000..0dc858d27 --- /dev/null +++ b/examples/custom_models/seamless4mt_model.py @@ -0,0 +1,191 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import logging +from typing import Optional + +import pycountry +from tqdm import tqdm +from transformers import AutoProcessor, SeamlessM4Tv2ForTextToText + +from lighteval.data import GenerativeTaskDataset +from lighteval.models.abstract_model import LightevalModel, ModelInfo, TokenSequence +from lighteval.models.model_output import ( + GenerativeResponse, + LoglikelihoodResponse, + LoglikelihoodSingleTokenResponse, +) +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, +) + + +logger = logging.getLogger(__name__) + + +class Seamless4MTClient(LightevalModel): + def __init__(self, config, env_config) -> None: + self.model = config.model + self.model_definition_file_path = config.model_definition_file_path + + self.model_info = ModelInfo( + model_name=config.model, + model_sha="", + model_dtype=None, + model_size="", + ) + self._tokenizer = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large") + self._model = SeamlessM4Tv2ForTextToText.from_pretrained("facebook/seamless-m4t-v2-large") + + def _convert_to_iso3(self, lang_code: str) -> str: + """Convert 2-letter ISO code to 3-letter ISO code.""" + try: + return pycountry.languages.get(alpha_2=lang_code.lower()).alpha_3 + except AttributeError: + # If conversion fails, return the original code + return lang_code + + def greedy_until( + self, + requests: list[GreedyUntilRequest], + override_bs: Optional[int] = None, + ) -> list[GenerativeResponse]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + Results are cached to disk to avoid repeated translations. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerativeResponse]: list of generated responses. 
+ """ + + def get_langs(task_name: str) -> tuple[str, str]: + src, tgt = task_name.split("|")[1].split(":")[-1].split("-") + return self._convert_to_iso3(src), self._convert_to_iso3(tgt) + + # Prepare all inputs first + prepared_requests = [] + for request in requests: + src_lang, tgt_lang = get_langs(request.task_name) + request.context = request.context.replace(f"{src_lang.upper()}: ", "").replace( + f"\n{tgt_lang.upper()}: ", "" + ) + request.tokenized_context = self._tokenizer( + text=request.context, src_lang=src_lang, return_tensors="pt", padding=True + ) + prepared_requests.append(request) + + # Create dataset after preparation + dataset = GenerativeTaskDataset(requests=prepared_requests, num_dataset_splits=self.DATASET_SPLITS) + results = [] + batch_size = override_bs or 32 + + for split_start, split_end in tqdm( + dataset.splits_start_end_iterator(), + total=dataset.num_dataset_splits, + desc="Splits", + position=0, + disable=False, + ): + # Get all requests for this split directly from sorted_data + current_requests = dataset.sorted_data[split_start:split_end] + + # Process in batches + for batch_idx in range(0, len(current_requests), batch_size): + batch = current_requests[batch_idx : batch_idx + batch_size] + + # Batch tokenize all inputs together instead of concatenating pre-tokenized inputs + batch_texts = [r.context for r in batch] + src_lang = get_langs(batch[0].task_name)[0] # All source languages should be the same in a batch + + # Unpack the tokenizer output into input_ids and attention_mask + input_ids, attention_mask = self._tokenizer( + text=batch_texts, src_lang=src_lang, return_tensors="pt", padding=True + ).values() + + tgt_langs = [get_langs(r.task_name)[1] for r in batch] + assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same" + + # Use unpacked values directly + output_ids = self._model.generate( + input_ids=input_ids, attention_mask=attention_mask, tgt_lang=tgt_langs[0] + ) + translations = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + # Create responses for the batch + for input_tokens, output_tokens, translation in zip(input_ids, output_ids, translations): + results.append( + GenerativeResponse( + input_tokens=input_tokens, + generated_tokens=output_tokens, + result=translation, + logits=None, + ) + ) + + return dataset.get_original_order(results) + + @property + def tokenizer(self): + return self._tokenizer + + def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence: + return self._tokenizer( + text=str_to_encode, return_tensors="pt", padding=True, add_special_tokens=add_special_tokens or False + ) + + @property + def add_special_tokens(self) -> bool: + return False + + @property + def max_length(self) -> int: + """Return the maximum sequence length of the model.""" + return 4096 + + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. 
+ """ + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + raise NotImplementedError + + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodSingleTokenResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + raise NotImplementedError From 2f27645b0ae79ce0719e2d396091b7a1c5c6a175 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Tue, 17 Dec 2024 11:36:17 +0100 Subject: [PATCH 16/21] Made local mt model example more general to support madlad400 as well. --- ...seamless4mt_model.py => local_mt_model.py} | 95 +++++++++++++++---- 1 file changed, 75 insertions(+), 20 deletions(-) rename examples/custom_models/{seamless4mt_model.py => local_mt_model.py} (65%) diff --git a/examples/custom_models/seamless4mt_model.py b/examples/custom_models/local_mt_model.py similarity index 65% rename from examples/custom_models/seamless4mt_model.py rename to examples/custom_models/local_mt_model.py index 0dc858d27..0d816a7a6 100644 --- a/examples/custom_models/seamless4mt_model.py +++ b/examples/custom_models/local_mt_model.py @@ -25,7 +25,12 @@ import pycountry from tqdm import tqdm -from transformers import AutoProcessor, SeamlessM4Tv2ForTextToText +from transformers import ( + AutoModelForSeq2SeqLM, + AutoProcessor, + AutoTokenizer, + SeamlessM4Tv2ForTextToText, +) from lighteval.data import GenerativeTaskDataset from lighteval.models.abstract_model import LightevalModel, ModelInfo, TokenSequence @@ -45,10 +50,42 @@ logger = logging.getLogger(__name__) -class Seamless4MTClient(LightevalModel): +class LocalMTClient(LightevalModel): + """ + A custom model implementation for local machine translation models, specifically supporting: + - SeamlessM4T v2 models from Meta + - MADLAD-400 models from Google + + This class provides a unified interface for both model families while handling their different + tokenization and generation approaches transparently. + + Args: + config (CustomModelConfig): Configuration containing: + - model (str): Model identifier/path (e.g. "facebook/seamless-m4t-v2-large" or "google/madlad400-7b-mt") + - model_definition_file_path (str): Path to this model definition file + env_config: Environment configuration (unused) + + The model automatically detects whether to load SeamlessM4T or MADLAD based on the model identifier + and initializes the appropriate tokenizer and model. + + Translation tasks should specify the source and target languages in the format: + "{task_name}|{...}:{src}-{tgt}" + where src and tgt are ISO language codes (2 or 3 letter codes supported). 
+ + Example: + ```lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10 + ``` + + Note: + - SeamlessM4T models use the AutoProcessor for tokenization + - MADLAD models use the standard AutoTokenizer + - Language codes are automatically converted to 3-letter ISO codes for SeamlessM4T + """ + def __init__(self, config, env_config) -> None: self.model = config.model self.model_definition_file_path = config.model_definition_file_path + self.batch_size = 32 self.model_info = ModelInfo( model_name=config.model, @@ -56,8 +93,18 @@ def __init__(self, config, env_config) -> None: model_dtype=None, model_size="", ) - self._tokenizer = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large") - self._model = SeamlessM4Tv2ForTextToText.from_pretrained("facebook/seamless-m4t-v2-large") + + # Update model initialization to handle both models + if "seamless-m4t" in config.model: + self._tokenizer = AutoProcessor.from_pretrained(config.model) + self._model = SeamlessM4Tv2ForTextToText.from_pretrained(config.model) + self.model_type = "seamless-4mt" + elif "madlad400" in config.model: + self._tokenizer = AutoTokenizer.from_pretrained(config.model) + self._model = AutoModelForSeq2SeqLM.from_pretrained(config.model) + self.model_type = "madlad400" + else: + raise ValueError(f"Unsupported model: {config.model}") def _convert_to_iso3(self, lang_code: str) -> str: """Convert 2-letter ISO code to 3-letter ISO code.""" @@ -86,24 +133,27 @@ def greedy_until( def get_langs(task_name: str) -> tuple[str, str]: src, tgt = task_name.split("|")[1].split(":")[-1].split("-") - return self._convert_to_iso3(src), self._convert_to_iso3(tgt) + if self.model_type == "seamless-4mt": + return self._convert_to_iso3(src), self._convert_to_iso3(tgt) + return src, tgt - # Prepare all inputs first + # Prepare all inputs first for creating the GenerativeTaskDataset prepared_requests = [] for request in requests: src_lang, tgt_lang = get_langs(request.task_name) request.context = request.context.replace(f"{src_lang.upper()}: ", "").replace( f"\n{tgt_lang.upper()}: ", "" ) - request.tokenized_context = self._tokenizer( - text=request.context, src_lang=src_lang, return_tensors="pt", padding=True - ) + if self.model_type == "madlad400": + request.context = f"<2{tgt_lang}> {request.context}" + + request.tokenized_context = self.tok_encode(request.context) prepared_requests.append(request) # Create dataset after preparation dataset = GenerativeTaskDataset(requests=prepared_requests, num_dataset_splits=self.DATASET_SPLITS) results = [] - batch_size = override_bs or 32 + batch_size = override_bs or self.batch_size for split_start, split_end in tqdm( dataset.splits_start_end_iterator(), @@ -123,18 +173,25 @@ def get_langs(task_name: str) -> tuple[str, str]: batch_texts = [r.context for r in batch] src_lang = get_langs(batch[0].task_name)[0] # All source languages should be the same in a batch - # Unpack the tokenizer output into input_ids and attention_mask - input_ids, attention_mask = self._tokenizer( - text=batch_texts, src_lang=src_lang, return_tensors="pt", padding=True - ).values() + # This is the tokenization step that really counts, as it actually gets used + tokenizer_kwargs = {"text": batch_texts, "return_tensors": "pt", "padding": True} + if self.model_type == "seamless-4mt": + tokenizer_kwargs["src_lang"] = src_lang + + input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).values() tgt_langs = [get_langs(r.task_name)[1] for r in batch] 
assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same" # Use unpacked values directly - output_ids = self._model.generate( - input_ids=input_ids, attention_mask=attention_mask, tgt_lang=tgt_langs[0] - ) + generate_kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + if self.model_type == "seamless-4mt": + generate_kwargs["tgt_lang"] = tgt_langs[0] + + output_ids = self._model.generate(**generate_kwargs) translations = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True) # Create responses for the batch @@ -155,9 +212,7 @@ def tokenizer(self): return self._tokenizer def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence: - return self._tokenizer( - text=str_to_encode, return_tensors="pt", padding=True, add_special_tokens=add_special_tokens or False - ) + return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False) @property def add_special_tokens(self) -> bool: From b7106e4810092580eaf4421cd0e3ca5c2d00cdb5 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Wed, 18 Dec 2024 16:21:18 +0100 Subject: [PATCH 17/21] Make sure generation can happen on the GPU. --- examples/custom_models/local_mt_model.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/custom_models/local_mt_model.py b/examples/custom_models/local_mt_model.py index 0d816a7a6..b0b1328a3 100644 --- a/examples/custom_models/local_mt_model.py +++ b/examples/custom_models/local_mt_model.py @@ -24,6 +24,7 @@ from typing import Optional import pycountry +import torch from tqdm import tqdm from transformers import ( AutoModelForSeq2SeqLM, @@ -86,6 +87,7 @@ def __init__(self, config, env_config) -> None: self.model = config.model self.model_definition_file_path = config.model_definition_file_path self.batch_size = 32 + self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model_info = ModelInfo( model_name=config.model, @@ -106,6 +108,9 @@ def __init__(self, config, env_config) -> None: else: raise ValueError(f"Unsupported model: {config.model}") + self._model.to(self.device) + self._model.eval() + def _convert_to_iso3(self, lang_code: str) -> str: """Convert 2-letter ISO code to 3-letter ISO code.""" try: @@ -166,7 +171,9 @@ def get_langs(task_name: str) -> tuple[str, str]: current_requests = dataset.sorted_data[split_start:split_end] # Process in batches - for batch_idx in range(0, len(current_requests), batch_size): + for batch_idx in tqdm( + range(0, len(current_requests), batch_size), desc="Batches", position=1, disable=False + ): batch = current_requests[batch_idx : batch_idx + batch_size] # Batch tokenize all inputs together instead of concatenating pre-tokenized inputs @@ -178,15 +185,19 @@ def get_langs(task_name: str) -> tuple[str, str]: if self.model_type == "seamless-4mt": tokenizer_kwargs["src_lang"] = src_lang - input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).values() + input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).to(self.device).values() tgt_langs = [get_langs(r.task_name)[1] for r in batch] assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same" + generation_sizes = [r.generation_size for r in batch] + assert set(generation_sizes) == {generation_sizes[0]}, "All generation sizes must be the same" + # Use unpacked values directly generate_kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, + "max_new_tokens": generation_sizes[0], } if 
self.model_type == "seamless-4mt":
                 generate_kwargs["tgt_lang"] = tgt_langs[0]
@@ -212,7 +223,7 @@ def tokenizer(self):
         return self._tokenizer

     def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
-        return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False)
+        return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False).to(self.device)

     @property
     def add_special_tokens(self) -> bool:

From a7d176c01a320d86a11fac46883c7c585d5c7d61 Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Thu, 19 Dec 2024 14:28:21 +0100
Subject: [PATCH 18/21] Fixed issue with src and tgt lang for seamless model.

---
 examples/custom_models/local_mt_model.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/examples/custom_models/local_mt_model.py b/examples/custom_models/local_mt_model.py
index b0b1328a3..93860e416 100644
--- a/examples/custom_models/local_mt_model.py
+++ b/examples/custom_models/local_mt_model.py
@@ -74,7 +74,7 @@ class LocalMTClient(LightevalModel):
     where src and tgt are ISO language codes (2 or 3 letter codes supported).

     Example:
-    ```lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10
+    ```lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10 --save-details
     ```

     Note:
@@ -101,6 +101,10 @@ def __init__(self, config, env_config) -> None:
             self._tokenizer = AutoProcessor.from_pretrained(config.model)
             self._model = SeamlessM4Tv2ForTextToText.from_pretrained(config.model)
             self.model_type = "seamless-4mt"
+            self.batch_size = 1
+            logger.info(
+                "Using batch size of 1 for seamless-4mt model because the target language needs to be set for the entire batch."
+ ) elif "madlad400" in config.model: self._tokenizer = AutoTokenizer.from_pretrained(config.model) self._model = AutoModelForSeq2SeqLM.from_pretrained(config.model) @@ -176,20 +180,17 @@ def get_langs(task_name: str) -> tuple[str, str]: ): batch = current_requests[batch_idx : batch_idx + batch_size] - # Batch tokenize all inputs together instead of concatenating pre-tokenized inputs + # Batch tokenize all inputs together instead of concatenating pre-tokenized inputs because of the padding batch_texts = [r.context for r in batch] - src_lang = get_langs(batch[0].task_name)[0] # All source languages should be the same in a batch # This is the tokenization step that really counts, as it actually gets used tokenizer_kwargs = {"text": batch_texts, "return_tensors": "pt", "padding": True} if self.model_type == "seamless-4mt": + src_lang = get_langs(batch[0].task_name)[0] tokenizer_kwargs["src_lang"] = src_lang input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).to(self.device).values() - tgt_langs = [get_langs(r.task_name)[1] for r in batch] - assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same" - generation_sizes = [r.generation_size for r in batch] assert set(generation_sizes) == {generation_sizes[0]}, "All generation sizes must be the same" @@ -200,7 +201,8 @@ def get_langs(task_name: str) -> tuple[str, str]: "max_new_tokens": generation_sizes[0], } if self.model_type == "seamless-4mt": - generate_kwargs["tgt_lang"] = tgt_langs[0] + tgt_lang = get_langs(batch[0].task_name)[1] + generate_kwargs["tgt_lang"] = tgt_lang output_ids = self._model.generate(**generate_kwargs) translations = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True) From f1ba65c5020cbd38d2da7506645933aab80eebf7 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Thu, 19 Dec 2024 17:09:43 +0100 Subject: [PATCH 19/21] Added cleanup to free the GPU memory again. --- examples/custom_models/local_mt_model.py | 26 ++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/examples/custom_models/local_mt_model.py b/examples/custom_models/local_mt_model.py index 93860e416..8e3f49184 100644 --- a/examples/custom_models/local_mt_model.py +++ b/examples/custom_models/local_mt_model.py @@ -220,6 +220,32 @@ def get_langs(task_name: str) -> tuple[str, str]: return dataset.get_original_order(results) + def cleanup(self): + import gc + + logger.info("Cleaning up GPU memory for local MT client.") + + # Show GPU memory before cleanup + if torch.cuda.is_available(): + logger.info(f"GPU memory before cleanup: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + + # Delete model and move to CPU + if hasattr(self, "_model"): + self._model.cpu() + del self._model + self._model = None + + if hasattr(self, "_tokenizer"): + del self._tokenizer + self._tokenizer = None + + torch.cuda.empty_cache() + gc.collect() + + # Show GPU memory after cleanup + if torch.cuda.is_available(): + logger.info(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + @property def tokenizer(self): return self._tokenizer From ace6e599330c660e84fd32a796a2a152d249a47b Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Sun, 22 Dec 2024 11:40:22 +0100 Subject: [PATCH 20/21] Fix dependency issues by switching to deep-translator. 
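The switch away from googletrans removes the httpcore monkey-patch and the 114-line pinned requirements file that were only needed to keep that client working. For orientation, here is a minimal, hedged sketch of the deep-translator call pattern the updated example builds on (illustrative only, not part of the diff below):

```python
# Sketch of the deep-translator API adopted by this patch.
# GoogleTranslator takes a source/target language pair and exposes a plain
# translate() call; the example model instead keeps one instance around and
# updates its source/target attributes before each request.
from deep_translator import GoogleTranslator

translator = GoogleTranslator(source="fr", target="de")
print(translator.translate("Le chat dort sur le canapé."))  # prints the German translation
```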
--- .../google-translate-requirements-freeze.txt | 114 ------------------ .../custom_models/google_translate_model.py | 21 ++-- 2 files changed, 8 insertions(+), 127 deletions(-) delete mode 100644 examples/custom_models/google-translate-requirements-freeze.txt diff --git a/examples/custom_models/google-translate-requirements-freeze.txt b/examples/custom_models/google-translate-requirements-freeze.txt deleted file mode 100644 index 1e6389f8d..000000000 --- a/examples/custom_models/google-translate-requirements-freeze.txt +++ /dev/null @@ -1,114 +0,0 @@ -absl-py==2.1.0 -accelerate==1.2.0 -aenum==3.1.15 -aiohappyeyeballs==2.4.4 -aiohttp==3.11.10 -aiosignal==1.3.1 -annotated-types==0.7.0 -anyio==4.7.0 -attrs==24.2.0 -blis==0.7.11 -catalogue==2.0.10 -certifi==2024.8.30 -chardet==3.0.4 -charset-normalizer==3.4.0 -click==8.1.7 -cloudpathlib==0.16.0 -colorama==0.4.6 -colorlog==6.9.0 -confection==0.1.5 -cymem==2.0.10 -DataProperty==1.0.1 -datasets==3.2.0 -dill==0.3.8 -diskcache==5.6.3 -filelock==3.16.1 -frozenlist==1.5.0 -fsspec==2024.9.0 -gitdb==4.0.11 -GitPython==3.1.43 -googletrans==4.0.0rc1 -h11==0.9.0 -h2==3.2.0 -hpack==3.0.0 -hstspreload==2024.12.1 -httpcore==0.9.1 -httpx==0.13.3 -huggingface-hub==0.26.5 -hyperframe==5.2.0 -idna==2.10 -Jinja2==3.1.4 -joblib==1.4.2 -langcodes==3.5.0 -language_data==1.3.0 -lxml==5.3.0 -marisa-trie==1.2.1 -markdown-it-py==3.0.0 -MarkupSafe==3.0.2 -mbstrdecoder==1.1.3 -mdurl==0.1.2 -mpmath==1.3.0 -multidict==6.1.0 -multiprocess==0.70.16 -murmurhash==1.0.11 -networkx==3.4.2 -nltk==3.9.1 -numpy==1.26.4 -packaging==24.2 -pandas==2.2.3 -pathvalidate==3.2.1 -portalocker==3.0.0 -preshed==3.0.9 -propcache==0.2.1 -protobuf==3.20.3 -psutil==6.1.0 -pyarrow==18.1.0 -pycountry==24.6.1 -pydantic==2.10.3 -pydantic_core==2.27.1 -Pygments==2.18.0 -pytablewriter==1.2.0 -python-dateutil==2.9.0.post0 -pytz==2024.2 -PyYAML==6.0.2 -regex==2024.11.6 -requests==2.32.3 -rfc3986==1.5.0 -rich==13.9.4 -rouge_score==0.1.2 -sacrebleu==2.4.3 -safetensors==0.4.5 -scikit-learn==1.6.0 -scipy==1.14.1 -sentencepiece==0.2.0 -setuptools==75.1.0 -six==1.17.0 -smart-open==6.4.0 -smmap==5.0.1 -sniffio==1.3.1 -spacy==3.7.2 -spacy-legacy==3.0.12 -spacy-loggers==1.0.5 -srsly==2.5.0 -sympy==1.13.3 -tabledata==1.3.3 -tabulate==0.9.0 -tcolorpy==0.1.6 -tenacity==9.0.0 -termcolor==2.3.0 -thinc==8.2.5 -threadpoolctl==3.5.0 -tokenizers==0.21.0 -torch==2.4.1 -tqdm==4.67.1 -transformers==4.47.0 -typepy==1.3.2 -typer==0.9.4 -typing_extensions==4.12.2 -tzdata==2024.2 -urllib3==2.2.3 -wasabi==1.1.3 -weasel==0.3.4 -wheel==0.44.0 -xxhash==3.5.0 -yarl==1.18.3 diff --git a/examples/custom_models/google_translate_model.py b/examples/custom_models/google_translate_model.py index d22d97bae..b6ef145f5 100644 --- a/examples/custom_models/google_translate_model.py +++ b/examples/custom_models/google_translate_model.py @@ -27,9 +27,8 @@ from typing import Optional import diskcache -import httpcore import tenacity -from googletrans import Translator +from deep_translator import GoogleTranslator from tqdm import tqdm from transformers import AutoTokenizer @@ -65,11 +64,8 @@ def __init__(self, config, env_config) -> None: self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility - # Needed to fix some googletrans bug - # https://stackoverflow.com/questions/72796594/attributeerror-module-httpcore-has-no-attribute-synchttptransport#comment136664963_77334618 - setattr(httpcore, "SyncHTTPTransport", "AsyncHTTPProxy") - - self.translator = Translator() + # Deep-translator also supports other 
translators + self.translator = GoogleTranslator() # Initialize disk cache cache_dir = os.path.join(os.getcwd(), ".translation_cache") @@ -98,15 +94,14 @@ def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> s return self.cache[cache_key] try: - # If not in cache, translate and store - translation = self.translator.translate(context, src=src_lang, dest=tgt_lang) - result = translation.text + # Updated translation call for deep-translator + self.translator.source = src_lang + self.translator.target = tgt_lang + result = self.translator.translate(context) self.cache[cache_key] = result return result except Exception as e: logger.warning(f"Translation error: {str(e)}. Retrying...") - # Re-initialize translator on error - self.translator = Translator() raise # Let tenacity handle the retry def greedy_until( @@ -161,7 +156,7 @@ def tokenizer(self): return self._tokenizer def tok_encode(self, text: str): - return self.tokenizer.encode(text) + return text @property def add_special_tokens(self) -> bool: From cfd7254d2f27b927dd548c6d046fbd433f41c80b Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Sun, 22 Dec 2024 13:09:56 +0100 Subject: [PATCH 21/21] Made inference code more robust against empty responses. --- examples/custom_models/google_translate_model.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/custom_models/google_translate_model.py b/examples/custom_models/google_translate_model.py index b6ef145f5..b60276cc6 100644 --- a/examples/custom_models/google_translate_model.py +++ b/examples/custom_models/google_translate_model.py @@ -76,6 +76,7 @@ def __init__(self, config, env_config) -> None: def _get_cache_key(self, context: str, src_lang: str, tgt_lang: str) -> str: """Generate a unique cache key for the translation request.""" + # IMPORTANT: In case we want to support other translators, we can add the translator name to the key key_string = f"{context}|{src_lang}|{tgt_lang}" return hashlib.md5(key_string.encode()).hexdigest() @@ -91,13 +92,20 @@ def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> s # Try to get from cache if cache_key in self.cache: - return self.cache[cache_key] + result = self.cache[cache_key] + if result is not None and result != "": + return result + logger.warning("Translation in cache is empty. Removing from cache and retrying...") + del self.cache[cache_key] try: # Updated translation call for deep-translator self.translator.source = src_lang self.translator.target = tgt_lang result = self.translator.translate(context) + if result is None or result == "": + result = "" + self.cache[cache_key] = result return result except Exception as e: @@ -140,6 +148,8 @@ def greedy_until( context = r.context.replace(f"{src_lang.upper()}: ", "").replace(f"\n{tgt_lang.upper()}: ", "") result = self._translate_with_cache(context, src_lang, tgt_lang) + if result is None: + result = "" # Set to empty string to prevent errors in metric computation cur_response = GenerativeResponse( result=result,
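Taken together, patches 20 and 21 leave the Google Translate example with a disk-cached, retried, empty-safe translation path. A compact sketch of that pattern in isolation (a hypothetical standalone function, not the repository's code; the cache directory name and retry settings are assumptions) looks roughly like this:

```python
# Rough sketch of the cache-then-retry pattern used by _translate_with_cache:
# results are memoised on disk, empty cache entries are treated as misses,
# and transient translator failures are retried.
import hashlib

import diskcache
import tenacity
from deep_translator import GoogleTranslator

cache = diskcache.Cache(".translation_cache")  # assumed cache location for this sketch


@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(min=1, max=10))
def translate_cached(text: str, src_lang: str, tgt_lang: str) -> str:
    key = hashlib.md5(f"{text}|{src_lang}|{tgt_lang}".encode()).hexdigest()
    cached = cache.get(key)
    if cached:  # None and "" fall through and get retranslated
        return cached
    result = GoogleTranslator(source=src_lang, target=tgt_lang).translate(text) or ""
    cache[key] = result
    return result
```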