Adding inference endpoints models #12

Merged
merged 18 commits on Feb 7, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -50,7 +50,7 @@ keywords = ["evaluation", "nlp", "llm"]
dependencies = [
# Base dependencies
"transformers>=4.36.0",
"huggingface_hub==0.19.4",
"huggingface_hub==0.20.3",
"torch>=2.0",
"GitPython==3.1.31", # for logging
"datasets>=2.14.0",
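If the environment is reinstalled after this bump, the new pin can be sanity-checked from Python; a tiny illustrative snippet (not part of the PR):

```python
# Illustrative check that the upgraded huggingface_hub pin is the version actually installed.
import huggingface_hub

assert huggingface_hub.__version__ == "0.20.3", huggingface_hub.__version__
```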
35 changes: 20 additions & 15 deletions src/lighteval/data.py
@@ -5,7 +5,14 @@
from torch.utils.data.distributed import DistributedSampler, T_co

from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.tasks.requests import Request
from lighteval.tasks.requests import (
GreedyUntilRequest,
GreedyUntilWithLogitsRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
Request,
)


class DynamicBatchDataset(Dataset):
@@ -28,6 +35,9 @@ def __init__(
requests (List): A list of requests.
dataset_splits (int): The number of dataset splits.
"""
# We make sure the requests contain the tokenized versions of their values
if any(r.tokenized_context is None for r in requests):
raise ValueError("You passed a request for which tokenization had not happened yet.")

# sort the requests using the collate function and save the original order
enumerated_requests = list(enumerate(requests))
@@ -124,12 +134,12 @@ def __len__(self) -> int:
"""
return self.split_end - self.split_start

def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request) -> int:
raise NotImplementedError()


class LoglikelihoodDataset(DynamicBatchDataset):
def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request: LoglikelihoodRequest | LoglikelihoodRollingRequest) -> int:
"""
Collates the input data for batching.

@@ -149,13 +159,12 @@ def _sorting_criteria(self, x) -> int:
Returns:
tuple: A tuple containing the sorted input data.
"""

toks = x[1] + x[2]
toks = request.tokenized_context + request.tokenized_continuation
return -len(toks)


class LoglikelihoodSingleTokenDataset(DynamicBatchDataset):
def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:
"""
Collates the input data for batching.

@@ -167,19 +176,14 @@ def _sorting_criteria(self, x) -> int:
is useful to simplify the batching logic and more importantly to make
automatic adaptive batches much much easier to implement
- any OOMs will happen right away rather than near the end

Args:
x (tuple): A tuple containing the input data.

Returns:
tuple: A tuple containing the collated data.
"""
toks = x[1] # We take only the prompt, no need for the continuation (since it's a list of single tokens)
# We take only the prompt, no need for the continuation (since it's a list of single tokens)
toks = request.tokenized_context
return -len(toks)


class GenerativeTaskDataset(DynamicBatchDataset):
def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int:
"""
Collate function for generating batches.

@@ -189,7 +193,8 @@ def _sorting_criteria(self, x) -> int:
Returns:
Any: The collated data.
"""
toks, (stop_tokens, gen_length) = x
toks = request.tokenized_context
gen_length = request.generation_size
return -(len(toks) + gen_length)


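Since the sorting criteria now read directly from the typed request objects, here is a small standalone sketch of the ordering they produce (plain dict stand-ins, not the actual Request dataclasses from this PR):

```python
# Why the datasets sort by descending length: the heaviest batch runs first, so any
# OOM surfaces immediately, and samples in a batch have similar padded lengths.
requests = [
    {"tokenized_context": [0] * 40, "generation_size": 10},
    {"tokenized_context": [0] * 512, "generation_size": 256},
    {"tokenized_context": [0] * 120, "generation_size": 32},
]

def sorting_criteria(request: dict) -> int:
    # Mirrors GenerativeTaskDataset._sorting_criteria above.
    return -(len(request["tokenized_context"]) + request["generation_size"])

for r in sorted(requests, key=sorting_criteria):
    print(len(r["tokenized_context"]), r["generation_size"])
# 512 256  ->  120 32  ->  40 10
```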
2 changes: 1 addition & 1 deletion src/lighteval/evaluator.py
@@ -8,7 +8,7 @@
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
from lighteval.models.inference_client import ModelClient
from lighteval.models.tgi_model import ModelClient
from lighteval.tasks.lighteval_task import LightevalTask
from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId

4 changes: 3 additions & 1 deletion src/lighteval/metrics/__init__.py
@@ -94,7 +94,9 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met
raise ValueError(
"You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
)
choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]

# TODO: make a better system with return_bool_score instead of taking the first element
choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))] # sum(
gold_ixs = as_list(formatted_doc.gold_index)

for metric in metrics:
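For orientation, a self-contained sketch (invented numbers, not lighteval code) of how one log-probability per choice typically feeds an accuracy check against the gold indices:

```python
# One score per choice (e.g. options A-D) and the index/indices of the gold answer(s).
choices_logprob = [-4.2, -1.3, -7.9, -2.6]
gold_ixs = [1]  # gold_index may be a single int or a list of ints

# The prediction is the choice with the highest log-probability.
prediction = max(range(len(choices_logprob)), key=lambda i: choices_logprob[i])
accuracy = float(prediction in gold_ixs)
print(prediction, accuracy)  # 1 1.0
```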
155 changes: 155 additions & 0 deletions src/lighteval/models/abstract_model.py
@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from transformers import BatchEncoding

from lighteval.models.model_config import EnvConfig
from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
from lighteval.tasks.requests import (
GreedyUntilRequest,
GreedyUntilWithLogitsRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)


TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding]


class LightevalModel(ABC):
DATASET_SPLITS = 4

"""Abstract model class defining the API that every model to plug into lighteval must follow."""

@abstractmethod
def __init__(
self,
config,
env_config: EnvConfig,
):
return NotImplemented

def cleanup(self):
"""Clean up operations if needed, such as closing an endpoint."""
return

@property
@abstractmethod
def tokenizer(self):
raise NotImplementedError

@property
@abstractmethod
def add_special_tokens(self):
raise NotImplementedError

@property
@abstractmethod
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
raise NotImplementedError

@property
def disable_tqdm(self) -> bool:
raise NotImplementedError

def greedy_until_with_logits(
self,
requests: list[GreedyUntilWithLogitsRequest],
override_bs: Optional[int] = None,
) -> list[GenerateReturn]:
"""
Generates sequences greedily until a stopping condition is met,
returning both the generated sequences and the logits.

Args:
requests (list[tuple[str, dict]]): A list of input requests,
where each request is a tuple containing a prompt string and a dictionary of additional parameters.
disable_tqdm (bool, optional): Whether to disable the tqdm progress bar. Defaults to False.
override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None.

Returns:
list[GenerateReturn]: A list of GenerateReturn objects,
where each object contains the generated sequence and the corresponding logits.
"""
return self.greedy_until(
requests=requests,
override_bs=override_bs,
returns_logits=True,
)

@abstractmethod
def greedy_until(
self,
requests: list[GreedyUntilRequest],
returns_logits: bool = False,
override_bs: Optional[int] = None,
) -> list[GenerateReturn]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.

Args:
requests (list[Request]): list of requests containing the context and ending conditions.
returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False.
disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
override_bs (int, optional): Override the batch size for generation. Defaults to None.

Returns:
list[GenerateReturn]: list of generated responses.
"""
return NotImplemented

@abstractmethod
def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodReturn]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
return NotImplemented

@abstractmethod
def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs=None
) -> list[LoglikelihoodReturn]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
return NotImplemented

@abstractmethod
def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodSingleTokenReturn]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
return NotImplemented

# Tokenization utils
def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
if add_special_tokens is None:
add_special_tokens = self.add_special_tokens
if isinstance(str_to_encode, str):
return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens)
return self.tokenizer(
str_to_encode,
padding=True,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)

def tok_encode_pair(self, context, continuation):
"""Encodes a context, continuation pair by taking care of the spaces in between."""
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc

def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
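To make the contract above concrete, here is a minimal sketch of a custom backend subclassing LightevalModel. The DummyModel class, its placeholder return values, and the "gpt2" tokenizer are illustrative assumptions, not part of this PR:

```python
# Illustrative sketch only: DummyModel and its placeholder behaviour are not lighteval code.
from typing import Optional

from transformers import AutoTokenizer

from lighteval.models.abstract_model import LightevalModel


class DummyModel(LightevalModel):
    def __init__(self, config, env_config):
        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        return 1024

    def greedy_until(self, requests, returns_logits: bool = False, override_bs: Optional[int] = None):
        return []  # a real backend would decode here and build GenerateReturn objects

    def loglikelihood(self, requests, override_bs: Optional[int] = None):
        return []  # a real backend would score context + continuation here

    def loglikelihood_rolling(self, requests, override_bs=None):
        return []

    def loglikelihood_single_token(self, requests, override_bs: Optional[int] = None):
        return []


# The shared tokenization helpers from the base class then work out of the box:
model = DummyModel(config=None, env_config=None)
context_ids, continuation_ids = model.tok_encode_pair("Question: 2 + 2 = ", "4")
```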
8 changes: 4 additions & 4 deletions src/lighteval/models/adapter_model.py
@@ -1,7 +1,7 @@
from contextlib import nullcontext

import torch
from transformers import AutoModel, PreTrainedTokenizer
from transformers import AutoModelForCausalLM, PreTrainedTokenizer

from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
@@ -20,7 +20,7 @@ def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConf
# (= the parent model, not the model of interest)
return self._create_auto_tokenizer_with_name(config.base_model, config=config, env_config=env_config)

def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModel:
def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModelForCausalLM:
"""Returns a PeftModel from a base model and a version fined tuned using PEFT."""
torch_dtype = _get_dtype(config.dtype, self._config)
config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel)
@@ -31,7 +31,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

if self.accelerator.is_local_main_process if self.accelerator is not None else nullcontext():
hlog(f"Loading model from {adapter_weights} and applying adapter to {config.base_model}")
base = self.AUTO_MODEL_CLASS.from_pretrained(
base = AutoModelForCausalLM.from_pretrained(
config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token
)
# Should pass revision
@@ -43,7 +43,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

hlog(f"Loading model from {merged_path}")

model = self.AUTO_MODEL_CLASS.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
merged_path,
max_memory=max_memory,
device_map=device_map,
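The merge-then-reload flow this file implements can be reproduced standalone with the public peft API; a rough sketch with placeholder names (none of these identifiers come from the PR):

```python
# Rough standalone sketch of the adapter-merge flow (placeholder model names and paths).
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model_name = "meta-llama/Llama-2-7b-hf"  # stand-in for config.base_model
adapter_name = "my-org/my-lora-adapter"       # stand-in for the adapter weights
merged_path = "/tmp/merged-model"             # stand-in for merged_path

base = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
peft_model = PeftModel.from_pretrained(base, adapter_name)
merged = peft_model.merge_and_unload()  # folds the adapter weights into the base model
merged.save_pretrained(merged_path)

# The merged checkpoint can then be reloaded as a plain causal LM for evaluation:
model = AutoModelForCausalLM.from_pretrained(merged_path, torch_dtype=torch.float16)
```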