diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index 19eb2a0e9..0b291f599 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -314,7 +314,6 @@ def tgi( """ Evaluate models using TGI as backend. """ - import yaml from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.models.endpoints.tgi_model import TGIModelConfig @@ -332,14 +331,8 @@ def tgi( # TODO (nathan): better handling of model_args parallelism_manager = ParallelismManager.TGI - with open(model_config_path, "r") as f: - config = yaml.safe_load(f)["model"] - model_config = TGIModelConfig( - inference_server_address=config["instance"]["inference_server_address"], - inference_server_auth=config["instance"]["inference_server_auth"], - model_id=config["instance"]["model_id"], - ) + model_config = TGIModelConfig.from_path(model_config_path) pipeline_params = PipelineParameters( launcher_type=parallelism_manager, diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index 0bd6cbbc3..1344e2485 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -111,6 +111,14 @@ def __post_init__(self): @classmethod def from_path(cls, path: str) -> "InferenceEndpointModelConfig": + """Load configuration for inference endpoint model from YAML file path. + + Args: + path (`str`): Path of the model configuration YAML file. + + Returns: + [`InferenceEndpointModelConfig`]: Configuration for inference endpoint model. + """ import yaml with open(path, "r") as f: diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py index d95609a50..3f20e4a57 100644 --- a/src/lighteval/models/endpoints/tgi_model.py +++ b/src/lighteval/models/endpoints/tgi_model.py @@ -51,6 +51,22 @@ class TGIModelConfig: inference_server_auth: str model_id: str + @classmethod + def from_path(cls, path: str) -> "TGIModelConfig": + """Load configuration for TGI endpoint model from YAML file path. + + Args: + path (`str`): Path of the model configuration YAML file. + + Returns: + [`TGIModelConfig`]: Configuration for TGI endpoint model. + """ + import yaml + + with open(path, "r") as f: + config = yaml.safe_load(f)["model"] + return cls(**config["instance"]) + # inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite # the client functions, since they use a different client. diff --git a/tests/models/test_endpoint_model.py b/tests/models/endpoints/test_endpoint_model.py similarity index 100% rename from tests/models/test_endpoint_model.py rename to tests/models/endpoints/test_endpoint_model.py diff --git a/tests/models/endpoints/test_tgi_model.py b/tests/models/endpoints/test_tgi_model.py new file mode 100644 index 000000000..305034278 --- /dev/null +++ b/tests/models/endpoints/test_tgi_model.py @@ -0,0 +1,42 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from dataclasses import asdict + +import pytest + +from lighteval.models.endpoints.tgi_model import TGIModelConfig + + +class TestTGIModelConfig: + @pytest.mark.parametrize( + "config_path, expected_config", + [ + ( + "examples/model_configs/tgi_model.yaml", + {"inference_server_address": "", "inference_server_auth": None, "model_id": None}, + ), + ], + ) + def test_from_path(self, config_path, expected_config): + config = TGIModelConfig.from_path(config_path) + assert asdict(config) == expected_config