
Commit d045d92

use system prompt from the request and use litellm encode function as tokenizer

NathanHB committed Dec 18, 2024
1 parent be77b15 commit d045d92
Showing 2 changed files with 16 additions and 24 deletions.
34 changes: 12 additions & 22 deletions src/lighteval/models/litellm_model.py
@@ -29,7 +29,6 @@

 from litellm.utils import ModelResponse
 from tqdm import tqdm
-from transformers import AutoTokenizer

 from lighteval.data import GenerativeTaskDataset
 from lighteval.models.abstract_model import LightevalModel
@@ -52,6 +51,7 @@

 if is_litellm_available():
     import litellm
+    from litellm import encode
     from litellm.caching.caching import Cache

     logging.getLogger("LiteLLM").setLevel(logging.WARNING)
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(self, config, env_config) -> None:
         self.TEMPERATURE = 0.7
         self.TOP_P = 0.95
         self.model = config.model
-        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility
+        self._tokenizer = encode
         self.pairwise_tokenization = False
         litellm.drop_params = True
         litellm.verbose = True
@@ -113,17 +113,14 @@ def _prepare_max_new_tokens(self, max_new_tokens):
             max_new_tokens = min(max_new_tokens * 10, 32000)
         return max_new_tokens

-    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt):
+    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
         """Make API call with retries."""
         response = ModelResponse()
         for attempt in range(self.API_MAX_RETRY):
             try:
                 stop_sequence = self._prepare_stop_sequence(stop_sequence)
                 max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)

-                # Remove system prompt from the main prompt if it's duplicated
-                prompt = prompt.replace(system_prompt, "")
-
                 response = litellm.completion(
                     model=self.model,
                     messages=prompt,
@@ -162,23 +159,16 @@ def __call_api_parallel(
         max_new_tokens: int | list[int],
         num_samples: int | list[int],
         stop_sequence: list[str] | None = None,
-        system_prompt: str | list[str] = None,
     ):
         results = []

         return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
         max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
         num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
         stop_sequencess = [stop_sequence for _ in prompts]
-        system_prompts = [system_prompt for _ in prompts] if not isinstance(system_prompt, list) else system_prompt
         assert (
-            len(prompts)
-            == len(return_logitss)
-            == len(max_new_tokenss)
-            == len(num_sampless)
-            == len(stop_sequencess)
-            == len(system_prompts)
-        ), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences, system_prompts should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}, {len(system_prompts)}"
+            len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(stop_sequencess)
+        ), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequencess should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}"

         with ThreadPoolExecutor(self.CONCURENT_CALLS) as executor:
             for entry in tqdm(
@@ -189,7 +179,6 @@ def __call_api_parallel(
                     max_new_tokenss,
                     num_sampless,
                     stop_sequencess,
-                    system_prompts,
                 ),
                 total=len(prompts),
             ):
@@ -233,11 +222,8 @@ def greedy_until(
             return_logits = dataset[0].use_logits
             num_samples = dataset[0].num_samples
             stop_sequence = requests[0].stop_sequence
-            system_prompt = requests[0].system_prompt

-            responses = self.__call_api_parallel(
-                contexts, return_logits, max_new_tokens, num_samples, stop_sequence, system_prompt
-            )
+            responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples, stop_sequence)

             for response in responses:
                 result: list[str] = [choice.message.content for choice in response.choices]
@@ -257,8 +243,12 @@ def greedy_until(
     def tokenizer(self):
         return self._tokenizer

-    def tok_encode(self, text: str):
-        return text
+    def tok_encode(self, text: str | list[str]):
+        if isinstance(text, list):
+            toks = [encode(model=self.model, text=t["content"]) for t in text]
+            toks = [tok for tok in toks if tok]
+            return toks
+        return encode(model=self.model, text=text)

     @property
     def add_special_tokens(self) -> bool:
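A note on the tokenizer swap above: litellm.encode takes a model name and a string and returns a list of token ids, using a tokenizer matched to the model where litellm knows one. Below is a minimal sketch of how the new tok_encode behaves; it is not the committed code, and "gpt-3.5-turbo" is only a stand-in for the configured model name.

# Sketch: token counting via litellm.encode (assumes litellm is installed).
from litellm import encode

MODEL = "gpt-3.5-turbo"  # stand-in for config.model

# Plain-string case: encode returns one list of token ids.
tokens = encode(model=MODEL, text="What is the capital of France?")
print(len(tokens))

# Chat-template case: tok_encode receives a list of role/content messages,
# encodes each message's content separately, and drops empty encodings.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
per_msg = [encode(model=MODEL, text=m["content"]) for m in messages]
per_msg = [toks for toks in per_msg if toks]
print(sum(len(toks) for toks in per_msg))  # total tokens across messages

This replaces the previous hard-coded gpt2 AutoTokenizer, so token counts now track the API model actually being evaluated.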
6 changes: 4 additions & 2 deletions src/lighteval/tasks/prompt_manager.py
@@ -209,7 +209,8 @@ def _single_turn_context(
         if not use_chat_template:
             toks = self.model.tok_encode(output)
         else:
-            toks = "".join([msg["content"] for msg in output])
+            toks = [self.model.tok_encode(msg["content"]) for msg in output]
+            toks = [t for ts in toks for t in ts]

         # If we need to truncate few-shots to fit in the context
         if truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None:
@@ -230,7 +231,8 @@ def _single_turn_context(
         if not use_chat_template:
             toks = self.model.tok_encode(output)
         else:
-            toks = "".join([msg["content"] for msg in output])
+            toks = [self.model.tok_encode(msg["content"]) for msg in output]
+            toks = [t for ts in toks for t in ts]

         if isinstance(self.model, LiteLLMClient):
             return output, num_effective_fewshots
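The prompt_manager change mirrors the tokenizer swap: the old chat-template branch joined message contents into a single string, so the few-shot truncation check that follows effectively measured characters, while the new branch encodes each message and flattens the per-message token lists into one list of token ids. A self-contained sketch of that flattening, with a hypothetical tok_encode standing in for self.model.tok_encode on a single string:

# Sketch of the chat-template token counting (not the committed code).
from litellm import encode

MODEL = "gpt-3.5-turbo"  # stand-in for the evaluated model name

def tok_encode(text: str) -> list[int]:
    # hypothetical stand-in for self.model.tok_encode on a plain string
    return encode(model=MODEL, text=text)

output = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Q: What is 2 + 2?\nA:"},
]

toks = [tok_encode(msg["content"]) for msg in output]  # list of lists
toks = [t for ts in toks for t in ts]  # flattened into one token-id list
print(len(toks))  # a real token count for the truncation check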
