Enabled passing the system prompt through to the models in the requests.
JoelNiklaus committed Dec 16, 2024
1 parent cca1446 commit 1a10351
Showing 4 changed files with 32 additions and 13 deletions.
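The core of the change: `LiteLLMClient` no longer hard-codes a system prompt; each `Request` now carries a `system_prompt` field that is threaded from the task layer down to the model call. Below is a minimal sketch of the resulting message construction, assuming a litellm-compatible setup; the model name and example strings are placeholders, not values from this commit.

import litellm

def call_model(prompt: str, system_prompt: str, model: str = "gpt-4o-mini"):
    # Mirrors the messages built in LiteLLMClient.__call_api after this change:
    # the system prompt comes from the request instead of a hard-coded constant.
    return litellm.completion(
        model=model,  # placeholder model name, not taken from the commit
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    )

# Example usage: the system prompt now originates from the pipeline's Request objects.
# call_model("Translate 'cheese' to French.", "You are a terse translation assistant.")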
1 change: 0 additions & 1 deletion src/lighteval/models/endpoints/openai_model.py
@@ -145,7 +145,6 @@ def greedy_until(
Args:
requests (list[Request]): list of requests containing the context and ending conditions.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
Returns:
25 changes: 16 additions & 9 deletions src/lighteval/models/litellm_model.py
@@ -89,12 +89,10 @@ def __init__(self, config, env_config) -> None:
self.model = config.model
self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility
self.pairwise_tokenization = False
# TODO: Pass the system prompt from the pipeline through.
self.system_prompt = "You are a helpful assistant."
litellm.drop_params = True
litellm.verbose = True

def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt):
for attempt in range(self.API_MAX_RETRY):
try:
if self.provider == "anthropic":
@@ -110,7 +108,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_se

response = litellm.completion(
model=self.model,
messages=[{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}],
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
max_completion_tokens=max_new_tokens if max_new_tokens > 0 else None,
logprobs=return_logits if self.provider == "openai" else None,
stop=stop_sequence,
@@ -137,17 +135,23 @@ def __call_api_parallel(
max_new_tokens: int | list[int],
num_samples: int | list[int],
stop_sequence: list[str] | None = None,
system_prompt: str | list[str] = None,
):
results = []

return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
stop_sequencess = [stop_sequence for _ in prompts]

system_prompts = [system_prompt for _ in prompts] if not isinstance(system_prompt, list) else system_prompt
assert (
len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(stop_sequencess)
), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}"
len(prompts)
== len(return_logitss)
== len(max_new_tokenss)
== len(num_sampless)
== len(stop_sequencess)
== len(system_prompts)
), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences, system_prompts should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}, {len(system_prompts)}"

with ThreadPoolExecutor(self.CONCURENT_CALLS) as executor:
for entry in tqdm(
@@ -158,6 +162,7 @@ def __call_api_parallel(
max_new_tokenss,
num_sampless,
stop_sequencess,
system_prompts,
),
total=len(prompts),
):
@@ -178,7 +183,6 @@ def greedy_until(
Args:
requests (list[Request]): list of requests containing the context and ending conditions.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
Returns:
@@ -202,8 +206,11 @@ def greedy_until(
return_logits = dataset[0].use_logits
num_samples = dataset[0].num_samples
stop_sequence = requests[0].stop_sequence
system_prompt = requests[0].system_prompt

responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples, stop_sequence)
responses = self.__call_api_parallel(
contexts, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt
)

for response in responses:
result: list[str] = [choice.message.content for choice in response.choices]
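For the parallel path, `__call_api_parallel` broadcasts scalar arguments (including the new `system_prompt`) into per-prompt lists before zipping them into the worker calls. A standalone sketch of that broadcasting, with illustrative values in place of real prompts:

prompts = ["Q1", "Q2", "Q3"]
system_prompt = "You are a helpful assistant."  # may also be a list with one entry per prompt

# Same pattern as the diff: repeat the scalar once per prompt, otherwise use the list as-is.
system_prompts = (
    [system_prompt for _ in prompts] if not isinstance(system_prompt, list) else system_prompt
)
assert len(system_prompts) == len(prompts), "each prompt needs a matching system prompt"

for prompt, sys_prompt in zip(prompts, system_prompts):
    # Stand-in for the ThreadPoolExecutor mapping over self.__call_api in the real code.
    print(sys_prompt, "|", prompt)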
17 changes: 14 additions & 3 deletions src/lighteval/tasks/lighteval_task.py
@@ -339,7 +339,7 @@ def eval_docs(self) -> list[Doc]:
return self._docs

def construct_requests(
self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str, system_prompt: str
) -> Dict[RequestType, List[Request]]:
"""
Constructs a list of requests from the task based on the given parameters.
@@ -349,7 +349,7 @@ def construct_requests(
ctx (str): Context, which is the few shot examples + the query.
document_id_seed (str): Index of the document in the task appended with the seed used for the few shot sampling.
current_task_name (str): Name of the current task.
system_prompt (str): System prompt to use for the request.
Returns:
dict[RequestType, List[Request]]: List of requests.
"""
@@ -365,6 +365,7 @@ def construct_requests(
context=context,
choice=gold,
metric_categories=[MetricCategory.TARGET_PERPLEXITY],
system_prompt=system_prompt,
)
for i, gold in enumerate(golds)
]
@@ -376,6 +377,7 @@ def construct_requests(
request_index=0,
context=context,
metric_categories=[MetricCategory.PERPLEXITY],
system_prompt=system_prompt,
)
]
if self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]:
@@ -395,6 +397,7 @@ def construct_requests(
do_sample=True,
use_logits=False,
metric_categories=[MetricCategory.GENERATIVE_SAMPLING],
system_prompt=system_prompt,
)
]
if (
@@ -421,6 +424,7 @@ def construct_requests(
]
if self.has_metric_category[c]
],
system_prompt=system_prompt,
)
]
if (
@@ -439,6 +443,7 @@ def construct_requests(
for c in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]
if self.has_metric_category[c]
],
system_prompt=system_prompt,
)
for i, choice in enumerate(formatted_doc.choices)
]
@@ -455,6 +460,7 @@ def construct_requests(
context=formatted_doc.unconditioned_query,
choice=choice,
metric_categories=[MetricCategory.MULTICHOICE_PMI],
system_prompt=system_prompt,
)
for i, choice in enumerate(formatted_doc.choices)
]
@@ -467,6 +473,7 @@ def construct_requests(
context=context,
choices=formatted_doc.choices,
metric_categories=[MetricCategory.MULTICHOICE_ONE_TOKEN],
system_prompt=system_prompt,
)
]
if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]:
@@ -479,6 +486,7 @@ def construct_requests(
stop_sequence=self.stop_sequence,
generation_size=self.generation_size,
metric_categories=[MetricCategory.LLM_AS_JUDGE_MULTI_TURN],
system_prompt=system_prompt,
)
]
if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]:
@@ -493,6 +501,7 @@ def construct_requests(
generation_grammar=self.generation_grammar,
num_samples=1,
metric_categories=[MetricCategory.LLM_AS_JUDGE],
system_prompt=system_prompt,
)
]

@@ -652,7 +661,9 @@ def create_requests_from_tasks( # noqa: C901
# Constructing the requests
cur_task_name = f"{task_name}|{num_fewshot}"
docs[SampleUid(cur_task_name, doc_id_seed)] = doc
req_type_reqs_dict = task.construct_requests(doc, doc.ctx, doc_id_seed, cur_task_name)
req_type_reqs_dict = task.construct_requests(
doc, doc.ctx, doc_id_seed, cur_task_name, system_prompt
)
for req_type, reqs in req_type_reqs_dict.items():
requests[req_type].extend(reqs)

2 changes: 2 additions & 0 deletions src/lighteval/tasks/requests.py
@@ -52,13 +52,15 @@ class Request:
request_index (int): The index of the request.
context (str): The context for the request.
metric_categories (list[MetricCategory]): All the metric categories which concern this request
system_prompt (str): System prompt to use for the request.
"""

task_name: str
sample_index: int
request_index: int
context: str
metric_categories: list["MetricCategory"] # noqa F821
system_prompt: Optional[str]


@dataclass
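Taken together, the base `Request` dataclass now carries the system prompt alongside the existing bookkeeping fields. A condensed, self-contained sketch of its shape after this change; `MetricCategory` is simplified to a plain string here, and the instance values are illustrative only:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Request:
    task_name: str
    sample_index: int
    request_index: int
    context: str
    metric_categories: list[str]   # list["MetricCategory"] in the real code
    system_prompt: Optional[str]   # new field: forwarded to the model backends

req = Request("demo|0", 0, 0, "What is 2 + 2?", [], "Answer with a single number.")
print(req.system_prompt)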
