diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 9ad55466a..beedaff60 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -15,6 +15,8 @@ title: Add a custom task - local: adding-a-new-metric title: Add a custom metric + - local: evaluating-a-custom-model + title: Evaluate a custom model - local: use-vllm-as-backend title: Use VLLM as backend - local: evaluate-the-model-on-a-server-or-container diff --git a/docs/source/evaluating-a-custom-model.mdx b/docs/source/evaluating-a-custom-model.mdx new file mode 100644 index 000000000..d4ff33cfa --- /dev/null +++ b/docs/source/evaluating-a-custom-model.mdx @@ -0,0 +1,129 @@ +# Evaluating a Custom Model + +Lighteval allows you to evaluate custom model implementations by creating a class that inherits from `LightevalModel`. This is useful when you want to evaluate models that aren't directly supported by the standard backends (transformers, vllm, etc.). + +## Creating a Custom Model + +1. Create a Python file containing your custom model implementation. The model must inherit from `LightevalModel` and implement all required methods. + +Here's a basic example: + +```python +from lighteval.models.abstract_model import LightevalModel + +class MyCustomModel(LightevalModel): + def __init__(self, config, env_config): + super().__init__(config, env_config) + # Initialize your model here... + + def greedy_until(self, requests, override_bs=None): + # Implement generation logic + pass + + def loglikelihood(self, requests, override_bs=None): + # Implement loglikelihood computation + pass + + def loglikelihood_rolling(self, requests, override_bs=None): + # Implement rolling loglikelihood computation + pass + + def loglikelihood_single_token(self, requests, override_bs=None): + # Implement single-token loglikelihood computation + pass +``` + +2. The custom model file should contain exactly one class that inherits from `LightevalModel`. This class will be automatically detected and instantiated when loading the model. + +> [!TIP] +> You can find a complete example of a custom model implementation in `examples/custom_models/google_translate_model.py`. + +## Running the Evaluation + +You can evaluate your custom model using either the command line interface or the Python API.
+ +### Using the Command Line + +```bash +python -m lighteval custom \ + "google-translate" \ + "examples/custom_models/google_translate_model.py" \ + "lighteval|wmt20:fr-de|0|0" \ + --output-dir results \ + --max-samples 10 +``` + +The command takes three required arguments: +- The model name (used for tracking in results/logs) +- The path to your model implementation file +- The tasks to evaluate on (same format as other backends) + +### Using the Python API + +```python +from lighteval.logging.evaluation_tracker import EvaluationTracker +from lighteval.models.custom.custom_model import CustomModelConfig +from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters + +# Set up evaluation tracking +evaluation_tracker = EvaluationTracker( + output_dir="results", + save_details=True +) + +# Configure the pipeline +pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.CUSTOM, + env_config=EnvConfig(cache_dir="tmp/") +) + +# Configure your custom model +model_config = CustomModelConfig( + model="my-custom-model", + model_definition_file_path="path/to/my_model.py" +) + +# Create and run the pipeline +pipeline = Pipeline( + tasks="leaderboard|truthfulqa:mc|0|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model_config=model_config +) + +pipeline.evaluate() +pipeline.save_and_push_results() +``` + +## Required Methods + +Your custom model must implement these core methods: + +- `greedy_until`: For generating text until a stop sequence or the maximum number of tokens is reached +- `loglikelihood`: For computing log probabilities of specific continuations +- `loglikelihood_rolling`: For computing rolling log probabilities of sequences +- `loglikelihood_single_token`: For computing log probabilities of single tokens + +See the `LightevalModel` base class documentation for detailed method signatures and requirements. + +## Best Practices + +1. **Error Handling**: Implement robust error handling in your model methods to gracefully handle edge cases. + +2. **Batching**: Consider implementing efficient batching in your model methods to improve performance (see the sketch at the end of this page). + +3. **Resource Management**: Properly manage any resources (e.g., API connections, model weights) in your model's `__init__` and `cleanup()` methods. + +4. **Documentation**: Add clear docstrings to your model class and methods explaining any specific requirements or limitations. + +## Example Use Cases + +Custom models are particularly useful for: + +- Evaluating models accessed through custom APIs +- Wrapping models with specialized preprocessing/postprocessing +- Testing novel model architectures +- Evaluating ensemble models +- Integrating with external services or tools + +For a complete example of a custom model that wraps the Google Translate API, see `examples/custom_models/google_translate_model.py`.
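+
+To illustrate the batching and error-handling practices above, here is a rough sketch of a `greedy_until` implementation for an API-backed model. It is only a sketch: `MyApiModel` and `call_my_api` are hypothetical placeholders for your own code, and the remaining `LightevalModel` methods still need to be implemented as described earlier.
+
+```python
+import logging
+from typing import Optional
+
+from lighteval.models.abstract_model import LightevalModel
+from lighteval.models.model_output import GenerativeResponse
+from lighteval.tasks.requests import GreedyUntilRequest
+
+logger = logging.getLogger(__name__)
+
+
+def call_my_api(prompts: list[str], max_new_tokens: Optional[int]) -> list[str]:
+    """Hypothetical helper: send a batch of prompts to your backend and return one string per prompt."""
+    raise NotImplementedError
+
+
+class MyApiModel(LightevalModel):
+    # __init__, tokenizer, max_length and the loglikelihood methods are omitted here;
+    # see the skeleton above and examples/custom_models/ for complete implementations.
+
+    def greedy_until(
+        self,
+        requests: list[GreedyUntilRequest],
+        override_bs: Optional[int] = None,
+    ) -> list[GenerativeResponse]:
+        batch_size = override_bs or 8
+        results: list[GenerativeResponse] = []
+
+        for start in range(0, len(requests), batch_size):
+            batch = requests[start : start + batch_size]
+            try:
+                generations = call_my_api(
+                    [r.context for r in batch],
+                    max_new_tokens=batch[0].generation_size,
+                )
+            except Exception as e:
+                # Fail soft so one bad batch does not abort the whole evaluation run.
+                logger.warning(f"Generation failed for batch starting at {start}: {e}")
+                generations = ["" for _ in batch]
+
+            results.extend(
+                GenerativeResponse(result=text, logits=None, generated_tokens=[], input_tokens=[])
+                for text in generations
+            )
+
+        # Responses are returned in the same order as the incoming requests.
+        return results
+```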
diff --git a/docs/source/package_reference/models.mdx b/docs/source/package_reference/models.mdx index 9feed4652..3e07a81ce 100644 --- a/docs/source/package_reference/models.mdx +++ b/docs/source/package_reference/models.mdx @@ -28,6 +28,10 @@ [[autodoc]] models.endpoints.tgi_model.TGIModelConfig [[autodoc]] models.endpoints.tgi_model.ModelClient +### Custom Model +[[autodoc]] models.custom.custom_model.CustomModelConfig +[[autodoc]] models.custom.custom_model.CustomModel + ### Open AI Models [[autodoc]] models.endpoints.openai_model.OpenAIClient diff --git a/examples/custom_models/google_translate_model.py b/examples/custom_models/google_translate_model.py new file mode 100644 index 000000000..b60276cc6 --- /dev/null +++ b/examples/custom_models/google_translate_model.py @@ -0,0 +1,200 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
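+
+# Example custom model for Lighteval: wraps the Google Translate API via the
+# `deep-translator` package, with on-disk caching (diskcache) and retry logic
+# (tenacity) around each translation call. Only `greedy_until` is implemented;
+# the loglikelihood methods raise NotImplementedError.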
+ +import hashlib +import logging +import os +import time +from typing import Optional + +import diskcache +import tenacity +from deep_translator import GoogleTranslator +from tqdm import tqdm +from transformers import AutoTokenizer + +from lighteval.data import GenerativeTaskDataset +from lighteval.models.abstract_model import LightevalModel, ModelInfo +from lighteval.models.model_output import ( + GenerativeResponse, + LoglikelihoodResponse, + LoglikelihoodSingleTokenResponse, +) +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, +) + + +logger = logging.getLogger(__name__) + + +class GoogleTranslateClient(LightevalModel): + def __init__(self, config, env_config) -> None: + self.model = config.model + self.model_definition_file_path = config.model_definition_file_path + + self.model_info = ModelInfo( + model_name=config.model, + model_sha="", + model_dtype=None, + model_size="", + ) + + self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility + + # Deep-translator also supports other translators + self.translator = GoogleTranslator() + + # Initialize disk cache + cache_dir = os.path.join(os.getcwd(), ".translation_cache") + self.cache = diskcache.Cache(cache_dir) + + self.max_retries = 3 + self.retry_delay = 1 + + def _get_cache_key(self, context: str, src_lang: str, tgt_lang: str) -> str: + """Generate a unique cache key for the translation request.""" + # IMPORTANT: In case we want to support other translators, we can add the translator name to the key + key_string = f"{context}|{src_lang}|{tgt_lang}" + return hashlib.md5(key_string.encode()).hexdigest() + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(multiplier=1, min=4, max=10), + retry=tenacity.retry_if_exception_type((Exception)), + before_sleep=lambda retry_state: time.sleep(1), + ) + def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> str: + """Translate text using cache if available, otherwise call Google Translate with retry logic.""" + cache_key = self._get_cache_key(context, src_lang, tgt_lang) + + # Try to get from cache + if cache_key in self.cache: + result = self.cache[cache_key] + if result is not None and result != "": + return result + logger.warning("Translation in cache is empty. Removing from cache and retrying...") + del self.cache[cache_key] + + try: + # Updated translation call for deep-translator + self.translator.source = src_lang + self.translator.target = tgt_lang + result = self.translator.translate(context) + if result is None or result == "": + result = "" + + self.cache[cache_key] = result + return result + except Exception as e: + logger.warning(f"Translation error: {str(e)}. Retrying...") + raise # Let tenacity handle the retry + + def greedy_until( + self, + requests: list[GreedyUntilRequest], + override_bs: Optional[int] = None, + ) -> list[GenerativeResponse]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + Results are cached to disk to avoid repeated translations. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerativeResponse]: list of generated responses. 
+ """ + for request in requests: + request.tokenized_context = self.tok_encode(request.context) + + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) + results = [] + + for _ in tqdm( + dataset.splits_start_end_iterator(), + total=dataset.num_dataset_splits, + desc="Splits", + position=0, + disable=False, # self.disable_tqdm, + ): + for r in tqdm(dataset, desc="Batch", position=1, disable=False): + # Extract source and target languages from task name + # Format is like "community|sdst-text_level:de-fr|0" + src_lang, tgt_lang = r.task_name.split("|")[1].split(":")[-1].split("-") + + context = r.context.replace(f"{src_lang.upper()}: ", "").replace(f"\n{tgt_lang.upper()}: ", "") + result = self._translate_with_cache(context, src_lang, tgt_lang) + if result is None: + result = "" # Set to empty string to prevent errors in metric computation + + cur_response = GenerativeResponse( + result=result, + logits=None, + generated_tokens=[], + input_tokens=[], + ) + results.append(cur_response) + + return dataset.get_original_order(results) + + @property + def tokenizer(self): + return self._tokenizer + + def tok_encode(self, text: str): + return text + + @property + def add_special_tokens(self) -> bool: + return False + + @property + def max_length(self) -> int: + """Return the maximum sequence length of the model.""" + return 4096 + + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + raise NotImplementedError + + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodSingleTokenResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + raise NotImplementedError diff --git a/examples/custom_models/local_mt_model.py b/examples/custom_models/local_mt_model.py new file mode 100644 index 000000000..8e3f49184 --- /dev/null +++ b/examples/custom_models/local_mt_model.py @@ -0,0 +1,285 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import logging +from typing import Optional + +import pycountry +import torch +from tqdm import tqdm +from transformers import ( + AutoModelForSeq2SeqLM, + AutoProcessor, + AutoTokenizer, + SeamlessM4Tv2ForTextToText, +) + +from lighteval.data import GenerativeTaskDataset +from lighteval.models.abstract_model import LightevalModel, ModelInfo, TokenSequence +from lighteval.models.model_output import ( + GenerativeResponse, + LoglikelihoodResponse, + LoglikelihoodSingleTokenResponse, +) +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, +) + + +logger = logging.getLogger(__name__) + + +class LocalMTClient(LightevalModel): + """ + A custom model implementation for local machine translation models, specifically supporting: + - SeamlessM4T v2 models from Meta + - MADLAD-400 models from Google + + This class provides a unified interface for both model families while handling their different + tokenization and generation approaches transparently. + + Args: + config (CustomModelConfig): Configuration containing: + - model (str): Model identifier/path (e.g. "facebook/seamless-m4t-v2-large" or "google/madlad400-7b-mt") + - model_definition_file_path (str): Path to this model definition file + env_config: Environment configuration (unused) + + The model automatically detects whether to load SeamlessM4T or MADLAD based on the model identifier + and initializes the appropriate tokenizer and model. + + Translation tasks should specify the source and target languages in the format: + "{task_name}|{...}:{src}-{tgt}" + where src and tgt are ISO language codes (2 or 3 letter codes supported). + + Example: + ```lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10 --save-details + ``` + + Note: + - SeamlessM4T models use the AutoProcessor for tokenization + - MADLAD models use the standard AutoTokenizer + - Language codes are automatically converted to 3-letter ISO codes for SeamlessM4T + """ + + def __init__(self, config, env_config) -> None: + self.model = config.model + self.model_definition_file_path = config.model_definition_file_path + self.batch_size = 32 + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + self.model_info = ModelInfo( + model_name=config.model, + model_sha="", + model_dtype=None, + model_size="", + ) + + # Initialize the appropriate tokenizer and model for the detected model family + if "seamless-m4t" in config.model: + self._tokenizer = AutoProcessor.from_pretrained(config.model) + self._model = SeamlessM4Tv2ForTextToText.from_pretrained(config.model) + self.model_type = "seamless-4mt" + self.batch_size = 1 + logger.info( + "Using batch size of 1 for seamless-4mt model because the target language needs to be set for the entire batch."
+ ) + elif "madlad400" in config.model: + self._tokenizer = AutoTokenizer.from_pretrained(config.model) + self._model = AutoModelForSeq2SeqLM.from_pretrained(config.model) + self.model_type = "madlad400" + else: + raise ValueError(f"Unsupported model: {config.model}") + + self._model.to(self.device) + self._model.eval() + + def _convert_to_iso3(self, lang_code: str) -> str: + """Convert 2-letter ISO code to 3-letter ISO code.""" + try: + return pycountry.languages.get(alpha_2=lang_code.lower()).alpha_3 + except AttributeError: + # If conversion fails, return the original code + return lang_code + + def greedy_until( + self, + requests: list[GreedyUntilRequest], + override_bs: Optional[int] = None, + ) -> list[GenerativeResponse]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + Results are cached to disk to avoid repeated translations. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerativeResponse]: list of generated responses. + """ + + def get_langs(task_name: str) -> tuple[str, str]: + src, tgt = task_name.split("|")[1].split(":")[-1].split("-") + if self.model_type == "seamless-4mt": + return self._convert_to_iso3(src), self._convert_to_iso3(tgt) + return src, tgt + + # Prepare all inputs first for creating the GenerativeTaskDataset + prepared_requests = [] + for request in requests: + src_lang, tgt_lang = get_langs(request.task_name) + request.context = request.context.replace(f"{src_lang.upper()}: ", "").replace( + f"\n{tgt_lang.upper()}: ", "" + ) + if self.model_type == "madlad400": + request.context = f"<2{tgt_lang}> {request.context}" + + request.tokenized_context = self.tok_encode(request.context) + prepared_requests.append(request) + + # Create dataset after preparation + dataset = GenerativeTaskDataset(requests=prepared_requests, num_dataset_splits=self.DATASET_SPLITS) + results = [] + batch_size = override_bs or self.batch_size + + for split_start, split_end in tqdm( + dataset.splits_start_end_iterator(), + total=dataset.num_dataset_splits, + desc="Splits", + position=0, + disable=False, + ): + # Get all requests for this split directly from sorted_data + current_requests = dataset.sorted_data[split_start:split_end] + + # Process in batches + for batch_idx in tqdm( + range(0, len(current_requests), batch_size), desc="Batches", position=1, disable=False + ): + batch = current_requests[batch_idx : batch_idx + batch_size] + + # Batch tokenize all inputs together instead of concatenating pre-tokenized inputs because of the padding + batch_texts = [r.context for r in batch] + + # This is the tokenization step that really counts, as it actually gets used + tokenizer_kwargs = {"text": batch_texts, "return_tensors": "pt", "padding": True} + if self.model_type == "seamless-4mt": + src_lang = get_langs(batch[0].task_name)[0] + tokenizer_kwargs["src_lang"] = src_lang + + input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).to(self.device).values() + + generation_sizes = [r.generation_size for r in batch] + assert set(generation_sizes) == {generation_sizes[0]}, "All generation sizes must be the same" + + # Use unpacked values directly + generate_kwargs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "max_new_tokens": generation_sizes[0], + } + if self.model_type == "seamless-4mt": + tgt_lang = get_langs(batch[0].task_name)[1] + 
generate_kwargs["tgt_lang"] = tgt_lang + + output_ids = self._model.generate(**generate_kwargs) + translations = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + # Create responses for the batch + for input_tokens, output_tokens, translation in zip(input_ids, output_ids, translations): + results.append( + GenerativeResponse( + input_tokens=input_tokens, + generated_tokens=output_tokens, + result=translation, + logits=None, + ) + ) + + return dataset.get_original_order(results) + + def cleanup(self): + import gc + + logger.info("Cleaning up GPU memory for local MT client.") + + # Show GPU memory before cleanup + if torch.cuda.is_available(): + logger.info(f"GPU memory before cleanup: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + + # Delete model and move to CPU + if hasattr(self, "_model"): + self._model.cpu() + del self._model + self._model = None + + if hasattr(self, "_tokenizer"): + del self._tokenizer + self._tokenizer = None + + torch.cuda.empty_cache() + gc.collect() + + # Show GPU memory after cleanup + if torch.cuda.is_available(): + logger.info(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + + @property + def tokenizer(self): + return self._tokenizer + + def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence: + return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False).to(self.device) + + @property + def add_special_tokens(self) -> bool: + return False + + @property + def max_length(self) -> int: + """Return the maximum sequence length of the model.""" + return 4096 + + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + raise NotImplementedError + + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodSingleTokenResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. 
+ """ + raise NotImplementedError diff --git a/src/lighteval/__main__.py b/src/lighteval/__main__.py index e4053813e..e0d12a6c8 100644 --- a/src/lighteval/__main__.py +++ b/src/lighteval/__main__.py @@ -27,6 +27,7 @@ import lighteval.main_accelerate import lighteval.main_baseline +import lighteval.main_custom import lighteval.main_endpoint import lighteval.main_nanotron import lighteval.main_tasks @@ -64,6 +65,7 @@ app.command(rich_help_panel="Evaluation Utils")(lighteval.main_baseline.baseline) app.command(rich_help_panel="Evaluation Backends")(lighteval.main_nanotron.nanotron) app.command(rich_help_panel="Evaluation Backends")(lighteval.main_vllm.vllm) +app.command(rich_help_panel="Evaluation Backends")(lighteval.main_custom.custom) app.add_typer( lighteval.main_endpoint.app, name="endpoint", diff --git a/src/lighteval/main_custom.py b/src/lighteval/main_custom.py new file mode 100644 index 000000000..84c953826 --- /dev/null +++ b/src/lighteval/main_custom.py @@ -0,0 +1,145 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+import os +from typing import Optional + +import typer +from typer import Argument, Option +from typing_extensions import Annotated + +from lighteval.models.custom.custom_model import CustomModelConfig + + +app = typer.Typer() + + +TOKEN = os.getenv("HF_TOKEN") +CACHE_DIR: str = os.getenv("HF_HOME", "/scratch") + +HELP_PANNEL_NAME_1 = "Common Parameters" +HELP_PANNEL_NAME_2 = "Logging Parameters" +HELP_PANNEL_NAME_3 = "Debug Parameters" +HELP_PANNEL_NAME_4 = "Modeling Parameters" + + +@app.command(rich_help_panel="Evaluation Backends") +def custom( + # === general === + model_name: Annotated[str, Argument(help="The model name (used for tracking in results and logs).")], + model_definition_file_path: Annotated[str, Argument(help="Path to the Python file containing the custom model definition.")], + tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")], + # === Common parameters === + use_chat_template: Annotated[ + bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4) + ] = False, + system_prompt: Annotated[ + Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4) + ] = None, + dataset_loading_processes: Annotated[ + int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = 1, + custom_tasks: Annotated[ + Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = None, + cache_dir: Annotated[ + str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = CACHE_DIR, + num_fewshot_seeds: Annotated[ + int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANNEL_NAME_1) + ] = 1, + # === saving === + output_dir: Annotated[ + str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = "results", + push_to_hub: Annotated[ + bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + push_to_tensorboard: Annotated[ + bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + public_run: Annotated[ + bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + results_org: Annotated[ + Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = None, + save_details: Annotated[ + bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANNEL_NAME_2) + ] = False, + # === debug === + max_samples: Annotated[ + Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANNEL_NAME_3) + ] = None, + override_batch_size: Annotated[ + Optional[int], Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANNEL_NAME_3) + ] = None, + job_id: Annotated[ + int, Option(help="Optional job id for future reference.", rich_help_panel=HELP_PANNEL_NAME_3) + ] = 0, +): + """ + Evaluate a custom model loaded from a user-provided Python file.
+ """ + from lighteval.logging.evaluation_tracker import EvaluationTracker + from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters + + env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir) + evaluation_tracker = EvaluationTracker( + output_dir=output_dir, + save_details=save_details, + push_to_hub=push_to_hub, + push_to_tensorboard=push_to_tensorboard, + public=public_run, + hub_results_org=results_org, + ) + + parallelism_manager = ParallelismManager.CUSTOM + model_config = CustomModelConfig(model=model_name, model_definition_file_path=model_definition_file_path) + + pipeline_params = PipelineParameters( + launcher_type=parallelism_manager, + env_config=env_config, + job_id=job_id, + dataset_loading_processes=dataset_loading_processes, + custom_tasks_directory=custom_tasks, + override_batch_size=override_batch_size, + num_fewshot_seeds=num_fewshot_seeds, + max_samples=max_samples, + use_chat_template=use_chat_template, + system_prompt=system_prompt, + ) + pipeline = Pipeline( + tasks=tasks, + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model_config=model_config, + ) + + pipeline.evaluate() + + pipeline.show_results() + + results = pipeline.get_results() + + pipeline.save_and_push_results() + + return results diff --git a/src/lighteval/models/custom/custom_model.py b/src/lighteval/models/custom/custom_model.py new file mode 100644 index 000000000..1b48a807b --- /dev/null +++ b/src/lighteval/models/custom/custom_model.py @@ -0,0 +1,78 @@ +# MIT License +# +# Copyright (c) 2024 The HuggingFace Team +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from dataclasses import dataclass + + +@dataclass +class CustomModelConfig: + """ + Configuration class for loading custom model implementations in Lighteval. + + This config allows users to define and load their own model implementations by specifying + a Python file containing a custom model class that inherits from LightevalModel. + + The custom model file should contain exactly one class that inherits from LightevalModel. + This class will be automatically detected and instantiated when loading the model. + + Args: + model (str): + An identifier for the model. This can be used to track which model was evaluated + in the results and logs. + + model_definition_file_path (str): + Path to a Python file containing the custom model implementation. This file must + define exactly one class that inherits from LightevalModel. 
The class should + implement all required methods from the LightevalModel interface. + + Example usage: + ```python + # Define config + config = CustomModelConfig( + model="my-custom-model", + model_definition_file_path="path/to/my_model.py" + ) + + # Example custom model file (my_model.py): + from lighteval.models.abstract_model import LightevalModel + + class MyCustomModel(LightevalModel): + def __init__(self, config, env_config): + super().__init__(config, env_config) + # Custom initialization... + + def greedy_until(self, *args, **kwargs): + # Custom generation logic... + pass + ``` + + An example of a custom model can be found in `examples/custom_models/google_translate_model.py`. + + Notes: + - The custom model class must inherit from LightevalModel and implement all required methods + - Only one class inheriting from LightevalModel should be defined in the file + - The model file is dynamically loaded at runtime, so ensure all dependencies are available + - Exercise caution when loading custom model files as they can execute arbitrary code + """ + + model: str + model_definition_file_path: str diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 68835fda7..46a3b82c1 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -23,6 +23,8 @@ import logging from typing import Union +from lighteval.models.abstract_model import LightevalModel +from lighteval.models.custom.custom_model import CustomModelConfig from lighteval.models.dummy.dummy_model import DummyModel, DummyModelConfig from lighteval.models.endpoints.endpoint_model import ( InferenceEndpointModel, @@ -60,6 +62,7 @@ def load_model( # noqa: C901 InferenceEndpointModelConfig, DummyModelConfig, VLLMModelConfig, + CustomModelConfig, OpenAIModelConfig, LiteLLMModelConfig, ], @@ -96,6 +99,9 @@ def load_model( # noqa: C901 if isinstance(config, VLLMModelConfig): return load_model_with_accelerate_or_default(config=config, env_config=env_config) + if isinstance(config, CustomModelConfig): + return load_custom_model(config=config, env_config=env_config) + if isinstance(config, OpenAIModelConfig): return load_openai_model(config=config, env_config=env_config) @@ -131,6 +137,35 @@ def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig): return model +def load_custom_model(config: CustomModelConfig, env_config: EnvConfig): + logger.warning(f"Executing custom model code loaded from {config.model_definition_file_path}.") + + import importlib.util + + # Load the Python file + spec = importlib.util.spec_from_file_location("custom_model_module", config.model_definition_file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load file: {config.model_definition_file_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find the first class that inherits from LightevalModel + model_class = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, LightevalModel) and attr != LightevalModel: + model_class = attr + break + + if model_class is None: + raise ValueError(f"No class inheriting from LightevalModel found in {config.model_definition_file_path}") + + model = model_class(config, env_config) + + return model + + def load_model_with_inference_endpoints( config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig ): diff --git a/src/lighteval/pipeline.py 
b/src/lighteval/pipeline.py index 6a40d2801..13ee5e13f 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -76,6 +76,7 @@ class ParallelismManager(Enum): TGI = auto() OPENAI = auto() VLLM = auto() + CUSTOM = auto() NONE = auto()