diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py
index aedc769fb..e66a2d70a 100644
--- a/src/lighteval/models/transformers/transformers_model.py
+++ b/src/lighteval/models/transformers/transformers_model.py
@@ -39,7 +39,7 @@
     GPTQConfig,
     PretrainedConfig,
 )
-from transformers.generation.utils import GenerateOutput
+from transformers.generation.utils import GenerateOutput, GenerationConfig
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
@@ -126,6 +126,8 @@ class TransformersModelConfig:
             model at a quantized precision. Needed for 4-bit and 8-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model
            loading.
+        generation_parameters (GenerationParameters): Range of parameters that will affect generation.
+        generation_config (GenerationConfig): GenerationConfig object (only passed during manual creation).
 
     Methods:
         __post_init__(): Performs post-initialization checks on the configuration.
@@ -154,6 +156,7 @@ class TransformersModelConfig:
     use_chat_template: bool = False
     compile: bool = False
     generation_parameters: GenerationParameters = None
+    generation_config: GenerationConfig = None
 
     def __post_init__(self):
         # Making sure this parameter is a boolean
@@ -180,7 +183,12 @@ def __post_init__(self):
         if not isinstance(self.device, str):
             raise ValueError("Current device must be passed as string.")
 
-        if not self.generation_parameters:
+        if self.generation_config and self.generation_parameters:
+            raise ValueError(
+                "Can't use both the generation_config and generation_parameters arguments. Pass the generation parameters to your generation config object instead."
+            )
+
+        if not self.generation_parameters and not self.generation_config:
             self.generation_parameters = GenerationParameters()
 
     def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
@@ -275,8 +283,11 @@ def __init__(
         self.model_sha = config.get_model_sha()
         self.precision = _get_dtype(config.dtype, config=self._config)
 
-        self.generation_parameters = config.generation_parameters
-        self.generation_config_dict = self.generation_parameters.to_transformers_dict()
+        if config.generation_config is None:
+            self.generation_parameters = config.generation_parameters
+            self.generation_config_dict = self.generation_parameters.to_transformers_dict()
+        else:
+            self.generation_config_dict = config.generation_config.to_dict()
 
         if is_accelerate_available():
             model_size, _ = calculate_maximum_sizes(self.model)
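
For context, a minimal usage sketch of the two configuration paths this patch makes mutually exclusive. This is an illustration under assumptions, not part of the patch: the `pretrained` field and the `lighteval.models.model_input` import path for `GenerationParameters` are assumed from the surrounding codebase; everything else mirrors the diff above.

```python
# Illustrative sketch only (not part of the patch). Assumes TransformersModelConfig
# exposes a `pretrained` field and that GenerationParameters lives in
# lighteval.models.model_input; both are assumptions about the surrounding codebase.
from transformers import GenerationConfig

from lighteval.models.model_input import GenerationParameters
from lighteval.models.transformers.transformers_model import TransformersModelConfig

# Existing path: lighteval's own GenerationParameters, converted internally
# with .to_transformers_dict().
config_a = TransformersModelConfig(
    pretrained="gpt2",
    generation_parameters=GenerationParameters(temperature=0.7, max_new_tokens=256),
)

# New path added by this diff: a ready-made transformers GenerationConfig,
# consumed via .to_dict().
config_b = TransformersModelConfig(
    pretrained="gpt2",
    generation_config=GenerationConfig(temperature=0.7, max_new_tokens=256, do_sample=True),
)

# Passing both now raises a ValueError in __post_init__; passing neither falls
# back to a default GenerationParameters().
```

Keeping the two inputs exclusive avoids silently merging overlapping sampling settings, so `__init__` only ever builds `generation_config_dict` from a single source.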