From 1a10351d61fa9fce973a713818b3bdc200ef723a Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Mon, 16 Dec 2024 16:14:30 +0100
Subject: [PATCH] Enabled passing through system prompt to the models in the requests.

---
 .../models/endpoints/openai_model.py  |  1 -
 src/lighteval/models/litellm_model.py | 25 ++++++++++++-------
 src/lighteval/tasks/lighteval_task.py | 17 ++++++++++---
 src/lighteval/tasks/requests.py       |  2 ++
 4 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py
index b2ca25285..8733474d0 100644
--- a/src/lighteval/models/endpoints/openai_model.py
+++ b/src/lighteval/models/endpoints/openai_model.py
@@ -145,7 +145,6 @@ def greedy_until(
 
         Args:
             requests (list[Request]): list of requests containing the context and ending conditions.
-            disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
             override_bs (int, optional): Override the batch size for generation. Defaults to None.
 
         Returns:
diff --git a/src/lighteval/models/litellm_model.py b/src/lighteval/models/litellm_model.py
index d8bd5e86e..31040a9ba 100644
--- a/src/lighteval/models/litellm_model.py
+++ b/src/lighteval/models/litellm_model.py
@@ -89,12 +89,10 @@ def __init__(self, config, env_config) -> None:
         self.model = config.model
         self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility
         self.pairwise_tokenization = False
-        # TODO: Pass the system prompt from the pipeline through.
-        self.system_prompt = "You are a helpful assistant."
         litellm.drop_params = True
         litellm.verbose = True
 
-    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
+    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt):
         for attempt in range(self.API_MAX_RETRY):
             try:
                 if self.provider == "anthropic":
@@ -110,7 +108,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se
 
                 response = litellm.completion(
                     model=self.model,
-                    messages=[{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}],
+                    messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
                     max_completion_tokens=max_new_tokens if max_new_tokens > 0 else None,
                     logprobs=return_logits if self.provider == "openai" else None,
                     stop=stop_sequence,
@@ -137,6 +135,7 @@ def __call_api_parallel(
         max_new_tokens: int | list[int],
         num_samples: int | list[int],
         stop_sequence: list[str] | None = None,
+        system_prompt: str | list[str] = None,
     ):
         results = []
 
@@ -144,10 +143,15 @@ def __call_api_parallel(
         max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
         num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
         stop_sequencess = [stop_sequence for _ in prompts]
-
+        system_prompts = [system_prompt for _ in prompts] if not isinstance(system_prompt, list) else system_prompt
         assert (
-            len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(stop_sequencess)
-        ), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}"
+            len(prompts)
+            == len(return_logitss)
+            == len(max_new_tokenss)
+            == len(num_sampless)
+            == len(stop_sequencess)
+            == len(system_prompts)
+        ), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences, system_prompts should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}, {len(system_prompts)}"
 
         with ThreadPoolExecutor(self.CONCURENT_CALLS) as executor:
             for entry in tqdm(
@@ -158,6 +162,7 @@ def __call_api_parallel(
                     max_new_tokenss,
                     num_sampless,
                     stop_sequencess,
+                    system_prompts,
                 ),
                 total=len(prompts),
             ):
@@ -178,7 +183,6 @@ def greedy_until(
 
         Args:
             requests (list[Request]): list of requests containing the context and ending conditions.
-            disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
             override_bs (int, optional): Override the batch size for generation. Defaults to None.
 
         Returns:
@@ -202,8 +206,11 @@ def greedy_until(
             return_logits = dataset[0].use_logits
             num_samples = dataset[0].num_samples
             stop_sequence = requests[0].stop_sequence
+            system_prompt = requests[0].system_prompt
 
-            responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples, stop_sequence)
+            responses = self.__call_api_parallel(
+                contexts, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt
+            )
 
             for response in responses:
                 result: list[str] = [choice.message.content for choice in response.choices]
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index ea01f81e4..be14d7445 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -339,7 +339,7 @@ def eval_docs(self) -> list[Doc]:
         return self._docs
 
     def construct_requests(
-        self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
+        self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str, system_prompt: str
     ) -> Dict[RequestType, List[Request]]:
         """
         Constructs a list of requests from the task based on the given parameters.
@@ -349,7 +349,7 @@ def construct_requests(
             ctx (str): Context, which is the few shot examples + the query.
             document_id_seed (str): Index of the document in the task appended with the seed used for the few shot sampling.
             current_task_name (str): Name of the current task.
-
+            system_prompt (str): System prompt to use for the request.
         Returns:
             dict[RequestType, List[Request]]: List of requests.
""" @@ -365,6 +365,7 @@ def construct_requests( context=context, choice=gold, metric_categories=[MetricCategory.TARGET_PERPLEXITY], + system_prompt=system_prompt, ) for i, gold in enumerate(golds) ] @@ -376,6 +377,7 @@ def construct_requests( request_index=0, context=context, metric_categories=[MetricCategory.PERPLEXITY], + system_prompt=system_prompt, ) ] if self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]: @@ -395,6 +397,7 @@ def construct_requests( do_sample=True, use_logits=False, metric_categories=[MetricCategory.GENERATIVE_SAMPLING], + system_prompt=system_prompt, ) ] if ( @@ -421,6 +424,7 @@ def construct_requests( ] if self.has_metric_category[c] ], + system_prompt=system_prompt, ) ] if ( @@ -439,6 +443,7 @@ def construct_requests( for c in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI] if self.has_metric_category[c] ], + system_prompt=system_prompt, ) for i, choice in enumerate(formatted_doc.choices) ] @@ -455,6 +460,7 @@ def construct_requests( context=formatted_doc.unconditioned_query, choice=choice, metric_categories=[MetricCategory.MULTICHOICE_PMI], + system_prompt=system_prompt, ) for i, choice in enumerate(formatted_doc.choices) ] @@ -467,6 +473,7 @@ def construct_requests( context=context, choices=formatted_doc.choices, metric_categories=[MetricCategory.MULTICHOICE_ONE_TOKEN], + system_prompt=system_prompt, ) ] if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]: @@ -479,6 +486,7 @@ def construct_requests( stop_sequence=self.stop_sequence, generation_size=self.generation_size, metric_categories=[MetricCategory.LLM_AS_JUDGE_MULTI_TURN], + system_prompt=system_prompt, ) ] if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]: @@ -493,6 +501,7 @@ def construct_requests( generation_grammar=self.generation_grammar, num_samples=1, metric_categories=[MetricCategory.LLM_AS_JUDGE], + system_prompt=system_prompt, ) ] @@ -652,7 +661,9 @@ def create_requests_from_tasks( # noqa: C901 # Constructing the requests cur_task_name = f"{task_name}|{num_fewshot}" docs[SampleUid(cur_task_name, doc_id_seed)] = doc - req_type_reqs_dict = task.construct_requests(doc, doc.ctx, doc_id_seed, cur_task_name) + req_type_reqs_dict = task.construct_requests( + doc, doc.ctx, doc_id_seed, cur_task_name, system_prompt + ) for req_type, reqs in req_type_reqs_dict.items(): requests[req_type].extend(reqs) diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py index cd75ad402..e184f807f 100644 --- a/src/lighteval/tasks/requests.py +++ b/src/lighteval/tasks/requests.py @@ -52,6 +52,7 @@ class Request: request_index (int): The index of the request. context (str): The context for the request. metric_categories (list[MetricCategory]): All the metric categories which concern this request + system_prompt (str): System prompt to use for the request. """ task_name: str @@ -59,6 +60,7 @@ class Request: request_index: int context: str metric_categories: list["MetricCategory"] # noqa F821 + system_prompt: Optional[str] @dataclass