[FIX] Fixes vllm backend (#317)
* fix vllm
* fix long loglikelihood context in vllm backend
* removes the need for pytest hook function
* fix model max length
NathanHB authored Sep 24, 2024
1 parent 7295c78 commit 4b06b94
Showing 3 changed files with 48 additions and 36 deletions.
5 changes: 2 additions & 3 deletions src/lighteval/models/base_model.py
@@ -126,7 +126,7 @@ def add_special_tokens(self):
def max_length(self) -> int:
return self._max_length

def init_model_parallel(self, model_parallel: bool = None) -> Tuple[bool, Optional[dict], Optional[str]]:
def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool, Optional[dict], Optional[str]]:
"""Compute all the parameters related to model_parallel"""
if not is_accelerate_available():
return False, None, None
@@ -147,7 +147,7 @@ def init_model_parallel(self, model_parallel: bool = None) -> Tuple[bool, Optional[dict], Optional[str]]:
f"the number of local processes is {self.num_local_processes} "
f"and the number of GPUs is {len(max_memory_all_gpus)}"
)
if model_parallel:
if model_parallel is True:
max_memory_all_gpus = get_max_memory() # A dict of the max memory for all the gpus
if "cpu" in max_memory_all_gpus:
del max_memory_all_gpus["cpu"]
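The new signature spells out that `model_parallel` is tri-state (`None` means auto-detect), and the `is True` comparison takes the model-parallel branch only when the flag was explicitly enabled, not merely truthy. A minimal sketch of that resolution pattern, using illustrative names rather than lighteval's exact code:

```python
from typing import Optional, Tuple


def resolve_model_parallel(
    model_parallel: Optional[bool],
    num_local_processes: int,
    num_gpus: int,
) -> Tuple[bool, Optional[dict]]:
    """Tri-state resolution: None = auto-detect, True/False = explicit user choice."""
    if model_parallel is None:
        # Auto mode: enable model parallelism only when there are spare GPUs
        # beyond one per local process.
        model_parallel = num_gpus > num_local_processes

    max_memory = None
    if model_parallel is True:  # explicit check, as in the patched condition
        # The real code queries accelerate for per-GPU memory; a fake budget
        # stands in for it here.
        max_memory = {gpu_id: "70GiB" for gpu_id in range(num_gpus)}
    return model_parallel, max_memory


# 8 GPUs shared by 2 processes -> auto-enables model parallelism.
print(resolve_model_parallel(None, num_local_processes=2, num_gpus=8))
```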
@@ -569,7 +569,6 @@ def greedy_until(
if max_new_tokens is None: # If generation size is not set, we go all the way
max_new_tokens = self.max_length - context_size
else:
print(self.max_length, context_size, max_new_tokens)
max_new_tokens = min(self.max_length - context_size, max_new_tokens)
if max_new_tokens < 1:
max_new_tokens = 1
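The deleted `print` was debug output; the surrounding logic caps the requested generation length by whatever room remains in the context window, never dropping below one token. A standalone sketch of that clamping rule (hypothetical helper name, not part of the repository):

```python
def clamp_max_new_tokens(max_length: int, context_size: int, max_new_tokens: int | None) -> int:
    """Cap generation length to the space left in the model's context window."""
    if max_new_tokens is None:
        # No explicit budget: generate up to whatever fits after the prompt.
        max_new_tokens = max_length - context_size
    else:
        max_new_tokens = min(max_length - context_size, max_new_tokens)
    # Never ask for fewer than one new token, even if the prompt fills the window.
    return max(max_new_tokens, 1)


assert clamp_max_new_tokens(max_length=4096, context_size=4000, max_new_tokens=256) == 96
assert clamp_max_new_tokens(max_length=2048, context_size=2048, max_new_tokens=None) == 1
```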
18 changes: 11 additions & 7 deletions src/lighteval/models/model_config.py
@@ -118,6 +118,8 @@ class BaseModelConfig:
def __post_init__(self):
# Making sure this parameter is a boolean
self.multichoice_continuations_start_space = boolstring_to_bool(self.multichoice_continuations_start_space)
self.model_parallel = boolstring_to_bool(self.model_parallel)
self.compile = boolstring_to_bool(self.compile)

if self.quantization_config is not None and not is_bnb_available():
raise ImportError(NO_BNB_ERROR_MSG)
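`boolstring_to_bool` already guarded `multichoice_continuations_start_space`; the two new lines apply the same coercion to `model_parallel` and `compile`, which matters when these flags arrive as strings from a config file. An illustrative reimplementation of such a coercion (the real helper lives in lighteval's utils and may differ):

```python
def boolstring_to_bool(value: bool | int | str | None) -> bool | None:
    """Coerce common string spellings of booleans ("true", "False", "1", ...) to bool."""
    if value is None or isinstance(value, bool):
        return value
    if isinstance(value, int):
        return bool(value)
    lowered = str(value).strip().lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no", "none", ""}:
        return False
    raise ValueError(f"Cannot interpret {value!r} as a boolean")


print(boolstring_to_bool("True"), boolstring_to_bool("false"), boolstring_to_bool(None))
```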
@@ -209,19 +211,21 @@ def init_configs(self, env_config: EnvConfig):
@dataclass
class VLLMModelConfig:
pretrained: str
gpu_memory_utilisation: float = 0.8
batch_size: int = -1
revision: str = "main"
gpu_memory_utilisation: float = 0.9 # lower this if you are running out of memory
revision: str = "main" # revision of the model
dtype: str | None = None
tensor_parallel_size: int = 1
data_parallel_size: int = 1
max_model_length: int = 1024
tensor_parallel_size: int = 1 # how many GPUs to use for tensor parallelism
pipeline_parallel_size: int = 1 # how many GPUs to use for pipeline parallelism
data_parallel_size: int = 1 # how many GPUs to use for data parallelism
max_model_length: int | None = None # maximum length of the model, usually inferred automatically. Reduce this if you encounter OOM issues; 4096 is usually enough
swap_space: int = 4 # CPU swap space size (GiB) per GPU.
seed: int = 1234
trust_remote_code: bool = False
use_chat_template: bool = False
add_special_tokens: bool = True
multichoice_continuations_start_space: bool = True
multichoice_continuations_start_space: bool = (
True # whether to add a space at the start of each continuation in multichoice generation
)
subfolder: Optional[str] = None


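With the reworked fields, a config for a model that does not fit on one GPU might look like this (values and model id are purely illustrative):

```python
from lighteval.models.model_config import VLLMModelConfig

config = VLLMModelConfig(
    pretrained="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
    gpu_memory_utilisation=0.85,   # lower this if vLLM runs out of memory at startup
    tensor_parallel_size=2,        # shard the weights across 2 GPUs
    pipeline_parallel_size=1,
    data_parallel_size=1,
    max_model_length=4096,         # only needed if automatic inference hits OOM
    dtype="bfloat16",
    use_chat_template=True,
)
```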
61 changes: 35 additions & 26 deletions src/lighteval/models/vllm_model.py
@@ -68,14 +68,17 @@ def __init__(
):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
self._config = config
self._batch_size = config.batch_size
self._max_length = self._init_max_length(config.max_model_length)
self.use_chat_template = config.use_chat_template
self.data_parallel_size = int(config.data_parallel_size)

self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
self._tokenizer = self._create_auto_tokenizer(config, env_config)

if config.max_model_length is not None:
self._max_length = int(config.max_model_length)
else:
self._max_length = self.tokenizer.model_max_length or self.tokenizer.max_position_embeddings

# If model_parallel is not set we compare the number of processes with the number of GPUs
self.model = self._create_auto_model(config, env_config)

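The maximum length is now taken from the user-supplied `max_model_length` when present and otherwise from what the tokenizer reports. A reduced sketch of that fallback, assuming a Hugging Face style tokenizer and adding a guard for the sentinel value (~1e30) that `model_max_length` holds when the limit is unknown:

```python
def resolve_max_length(max_model_length: int | None, tokenizer) -> int | None:
    """Prefer the explicit override, otherwise trust the tokenizer's reported limit."""
    if max_model_length is not None:
        return int(max_model_length)
    # HF tokenizers expose model_max_length, which can be a huge sentinel when
    # unknown; in that case fall back to a max_position_embeddings attribute
    # if the object carries one.
    reported = getattr(tokenizer, "model_max_length", None)
    if reported is not None and reported < 1_000_000:
        return int(reported)
    return getattr(tokenizer, "max_position_embeddings", None)
```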
@@ -120,12 +123,13 @@ def _create_auto_model(self, config: VLLMModelConfig, env_config: EnvConfig) ->
"""
self.model_args = {
"model": config.pretrained,
"gpu_memory_utilization": float(0.8),
"gpu_memory_utilization": float(config.gpu_memory_utilisation),
"revision": config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
"dtype": config.dtype,
"trust_remote_code": config.trust_remote_code,
"tensor_parallel_size": int(1),
"max_model_len": int(self._max_length) if self._max_length else None,
"tensor_parallel_size": int(config.tensor_parallel_size),
"pipeline_parallel_size": int(config.pipeline_parallel_size),
"max_model_len": self._max_length,
"swap_space": 4,
"seed": 1234,
}
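Previously `gpu_memory_utilization` and `tensor_parallel_size` were hard-coded here, silently overriding the config; the hunk threads the config values through instead. These keys map onto `vllm.LLM`'s constructor arguments, roughly as in the sketch below (assuming vLLM is installed; values are examples):

```python
from vllm import LLM

model_args = {
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
    "gpu_memory_utilization": 0.9,
    "revision": "main",
    "dtype": "bfloat16",
    "trust_remote_code": False,
    "tensor_parallel_size": 2,
    "pipeline_parallel_size": 1,
    "max_model_len": 4096,
    "swap_space": 4,
    "seed": 1234,
}
llm = LLM(**model_args)  # with max_model_len=None, vLLM infers the limit itself
```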
@@ -227,30 +231,33 @@ def greedy_until(
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context_size = len(tokenized["input_ids"][0])
if context_size > self.max_length:
hlog_warn(
f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({dataset[0].task_name})
+ ". This is likely to lead to some errors." # noqa C401
)
# There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
if max_new_tokens is None: # If generation size is not set, we go all the way
max_new_tokens = self.max_length - context_size
else:
max_new_tokens = min(self.max_length - context_size, max_new_tokens)
inputs = tokenized["input_ids"]
context_size = len(inputs[0])

# left truncate the inputs to the maximum length
if max_new_tokens is not None:
if context_size + max_new_tokens > self.max_length:
hlog_warn(
f"{context_size + max_new_tokens=} which is greather than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
)
context_size = self.max_length - max_new_tokens
inputs = [input[-context_size:] for input in inputs]
else:
if context_size > self.max_length:
hlog_warn(
f"{context_size=} which is greather than {self.max_length=}. Truncating context to {self.max_length} tokens."
)
context_size = self.max_length
inputs = [input[-context_size:] for input in inputs]

vllm_outputs = self._generate(
inputs=tokenized["input_ids"],
inputs=inputs,
max_new_tokens=max_new_tokens,
stop_tokens=stop_tokens,
returns_logits=returns_logits,
num_samples=num_samples,
)

print(f"{len(vllm_outputs)} vllm_outputs")
for vllm_output in vllm_outputs:
output_token_ids = [outputs.token_ids for outputs in vllm_output.outputs]
logprobs = [output.logprobs for output in vllm_output.outputs] or []
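Rather than collapsing the generation budget to a single token when a prompt is too long, the new branch left-truncates the prompt so that prompt plus generation always fits in the model window. A self-contained sketch of that arithmetic (hypothetical helper, not the repository's API):

```python
def left_truncate(inputs: list[list[int]], max_length: int, max_new_tokens: int | None) -> list[list[int]]:
    """Keep only the rightmost tokens so context + generation fits within max_length."""
    allowed_context = max_length - max_new_tokens if max_new_tokens is not None else max_length
    allowed_context = max(allowed_context, 1)  # never truncate the prompt away entirely
    # Keep the end of the prompt (the most recent tokens), drop the beginning.
    return [input_ids[-allowed_context:] for input_ids in inputs]


prompts = [list(range(5000))]
assert len(left_truncate(prompts, max_length=4096, max_new_tokens=256)[0]) == 3840
```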
@@ -345,19 +352,21 @@ def _loglikelihood_tokens(

for _ in tqdm(dataset.splits_start_end_iterator()):
# the last token is an eos token, so we don't need to add it
inputs = [
dataset[i].tokenized_context + dataset[i].tokenized_continuation[:-1] for i in range(len(dataset))
]
inputs = [dataset[i].tokenized_context + dataset[i].tokenized_continuation for i in range(len(dataset))]
# Left truncate the inputs to the maximum length
inputs = [input[-self.max_length :] for input in inputs]
outputs = self._generate(inputs, generate=False)

for output, input in zip(outputs, dataset):
continuation_logprobs = []
for token, logprobs in zip(input.tokenized_continuation[-2::-1], output.prompt_logprobs[::-1]):
for token, logprobs in zip(input.tokenized_continuation[::-1], output.prompt_logprobs[::-1]):
continuation_logprobs.append(logprobs[token])
bool_score = all(logprob.rank == 1 for logprob in continuation_logprobs)
continuation_logprobs = [logprob.logprob for logprob in continuation_logprobs]
answer = LoglikelihoodResponse(
result=(sum(continuation_logprobs), bool_score if return_bool_score else None)
input_tokens=input.tokenized_context + input.tokenized_continuation,
generated_tokens=input.tokenized_continuation,
result=(sum(continuation_logprobs), bool_score if return_bool_score else None),
)
res.append(answer)

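The loglikelihood path now feeds the full context plus continuation to vLLM and reads the continuation scores back out of `prompt_logprobs`, walking both sequences from the right so the alignment survives left truncation. A simplified sketch of that alignment, using a plain dataclass as a stand-in for vLLM's logprob entries:

```python
from dataclasses import dataclass


@dataclass
class TokenLogprob:  # stand-in for the logprob objects vLLM returns
    logprob: float
    rank: int


def score_continuation(continuation: list[int], prompt_logprobs: list[dict[int, TokenLogprob]]):
    """Sum the logprob of each continuation token, walking both sequences from the right."""
    picked = []
    for token, logprobs in zip(continuation[::-1], prompt_logprobs[::-1]):
        picked.append(logprobs[token])
    total = sum(entry.logprob for entry in picked)
    is_greedy = all(entry.rank == 1 for entry in picked)  # every token was the top choice
    return total, is_greedy


# Two continuation tokens scored against the last two prompt-logprob positions.
logprobs = [{7: TokenLogprob(-0.1, 1)}, {9: TokenLogprob(-0.5, 2)}]
print(score_continuation([7, 9], logprobs))
```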
