Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bumping brrr model #9

Merged
merged 2 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ optimum = ["optimum==1.12.0"]
quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
adapters = ["peft==0.3.0"]
nanotron = [
"nanotron@git+https://github.com/huggingface/nanotron@8c1a49588d0745a6404644a86547c2dd6a63640e",
"brrr@git+https://github.com/huggingface/brrr@e8a503e2ec08b34eed7522d331aec3bee8cdd29b",
"nanotron@git+https://github.com/huggingface/nanotron@main",
"brrr@git+https://github.com/huggingface/brrr@fix-lighteval",
"tensorboardX"
]

Expand Down
8 changes: 5 additions & 3 deletions src/lighteval/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_original_order(self, new_arr: list) -> list:

return original_order

def get_split_start_end(self, split_id: int) -> tuple[int, int]:
def get_set_split_start_end(self, split_id: int) -> tuple[int, int]:
"""
Get the start and end indices of a dataset split.

Expand All @@ -96,7 +96,7 @@ def splits_start_end_iterator(self) -> tuple[int, int]:
tuple: A tuple containing the start and end indices of a split.
"""
for split_id in range(self.dataset_splits):
yield self.get_split_start_end(split_id)
yield self.get_set_split_start_end(split_id)

def __getitem__(self, index) -> Request:
"""
Expand Down Expand Up @@ -189,7 +189,9 @@ def _sorting_criteria(self, x) -> int:
Returns:
Any: The collated data.
"""
toks, (stop_tokens, gen_length) = x
toks = x[0]
meta_data = x[1]
stop_tokens, gen_length = meta_data[0], meta_data[1]
return -(len(toks) + gen_length)


Expand Down
98 changes: 56 additions & 42 deletions src/lighteval/models/brrr_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# flake8: noqa: C901
# flake8: noqa: C901,E1120
import os
import time
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, Type

import torch
import torch.nn.functional as F
Expand All @@ -28,9 +29,22 @@
from tqdm import tqdm
from transformers import AutoTokenizer, BatchEncoding

from lighteval.data import GenDataset, GenDistributedSampler, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)
from lighteval.data import (
GenDistributedSampler,
GenerativeTaskDataset,
LoglikelihoodDataset,
LoglikelihoodSingleTokenDataset,
)
from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
from lighteval.utils import as_list, find_executable_batch_size
from lighteval.tasks.requests import GreedyUntilRequest
from lighteval.utils import as_list
from lighteval.utils_parallelism import find_executable_batch_size


# from .brrr_generation import GenerationConfig, GenerationInputs, SamplerType, greedy_search_tokenized
Expand All @@ -41,8 +55,7 @@

TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]

# _DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.device]])

STARTING_BATCH_SIZE = 512

class BRRRModel:
# Default max sequence length setting for when no `max_length` is provided
Expand All @@ -68,6 +81,7 @@ def __init__(
s5cmd_numworkers: int = 64,
s5cmd_concurrency: int = 10,
s5cmd_path: str = "/admin/home/thomwolf/miniconda/envs/b4r/bin/s5cmd",
model_class: Optional[Type] = None,
):
"""Initializes a brrr model for evaluation.
Args:
Expand Down Expand Up @@ -120,6 +134,9 @@ def __init__(
self.tokenizer.model_max_length = self.max_length

model_config_cls = self.model_config.__class__.__name__
if model_class is not None:
CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class

if model_config_cls not in CONFIG_TO_MODEL_CLASS:
raise ValueError(
f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
Expand Down Expand Up @@ -394,7 +411,7 @@ def _encode_pair(self, context, continuation):
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc

def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]:
def homogeneize_ending_conditions(self, ending_condition: Union[tuple, dict, list, str]) -> tuple[list, int]:
"""Ending conditions are submitted in several possible formats.
By default in lighteval we pass them as tuples (stop sequence, max number of items).
In the harness they sometimes are passed as dicts {"until": .., "max_length": ...} or
Expand Down Expand Up @@ -489,7 +506,7 @@ def loglikelihood_single_token(
disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
)

def loglikelihood(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]:
def loglikelihood(self, requests: List[LoglikelihoodRequest], override_bs=None) -> List[LoglikelihoodReturn]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.

Expand Down Expand Up @@ -518,7 +535,7 @@ def loglikelihood(self, requests: List[Tuple[str, str]], override_bs=None) -> Li
disable_tqdm=bool(dist.get_rank(self.parallel_context.world_pg) != 0),
)

def loglikelihood_rolling(self, requests: List[Tuple[str, str]], override_bs=None) -> List[LoglikelihoodReturn]:
def loglikelihood_rolling(self, requests: List[LoglikelihoodRollingRequest], override_bs=None) -> List[LoglikelihoodReturn]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
tokenized_reqs = []

Expand Down Expand Up @@ -608,7 +625,7 @@ def prepare_batch(

# when too long to fit in context, truncate from the left
inp = torch.tensor(
(tokens)[-max_context:], # [:-1],
tokens[-max_context:], # [:-1],
dtype=torch.long,
)

Expand Down Expand Up @@ -699,7 +716,7 @@ def _get_subsets(self, dataset, dataset_splits):

@torch.inference_mode()
def _loglikelihood_single_token(
self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
self, requests: List[LoglikelihoodSingleTokenRequest], disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
) -> List[LoglikelihoodSingleTokenReturn]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests)
res = []
Expand Down Expand Up @@ -921,7 +938,7 @@ def _loglikelihood_single_token(
# We are in a process which return no output (beginning/middle of the PP group)
return []

return dataset.ordered.get_original(res)
return dataset.get_original_order(res)

@torch.inference_mode()
def _loglikelihood_tokens(
Expand All @@ -932,26 +949,14 @@ def _loglikelihood_tokens(
dataset_splits: int = 1,
return_bool_score: bool = True,
) -> List[LoglikelihoodReturn]:
dataset = LoglikelihoodDataset(requests=requests)
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits)
res = []

# Dataset is sorted in descending size.
# every 20-25% of the dataset we try to double the batch size for speed up
starting_batch_size = 512

total_length, subset_length = self._get_subsets(dataset, dataset_splits)

for s, subset_start in enumerate(
tqdm(
range(0, total_length, subset_length),
disable=disable_tqdm,
position=0,
desc=f"loglikelihood -- Node {dist.get_rank(self.parallel_context.world_pg)}",
)
):
dataset.split_start = subset_start
dataset.split_end = min(subset_start + subset_length, total_length)
starting_batch_size = STARTING_BATCH_SIZE

for s, (split_start, split_end) in tqdm(enumerate(dataset.splits_start_end_iterator())):
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
_, context_enc, continuation_enc = dataset[0]
Expand Down Expand Up @@ -1155,18 +1160,18 @@ def _loglikelihood_tokens(
# print(f"i {i} padded: {r.padded}")

if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
assert len(res) == total_length, "we didn't cover all the data"
assert len(res) == (split_end-split_start), "we didn't cover all the data"

if len(res) == 0:
# We are in a process which return no output (beginning/middle of the PP group)
return []

return dataset.ordered.get_original(res)
return dataset.get_original_order(res)

@torch.inference_mode()
def greedy_until(
self,
requests: List[Tuple[str, dict]],
requests: List[GreedyUntilRequest],
task_names: Optional[List[str]] = None,
returns_logits=False,
disable_tqdm: bool = False,
Expand All @@ -1178,15 +1183,24 @@ def greedy_until(
# pull longest context sample from request
if task_names:
enc_inputs = [
(self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), task_name)
for req, task_name in zip(requests, task_names)
(index, (
self.tok_encode(req.context),
self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
task_name,
))
for index, (req, task_name) in enumerate(zip(requests, task_names))
]
else:
enc_inputs = [
(self.tok_encode(req[0]), self.homogeneize_ending_conditions(req[1]), None) for req in requests
(index, (
self.tok_encode(req.context),
self.homogeneize_ending_conditions((req.stop_sequence, req.generation_size)),
None,
))
for index, req in enumerate(requests)
]

dataset = GenDataset(requests=enc_inputs)
dataset = GenerativeTaskDataset(requests=enc_inputs, dataset_splits=dataset_splits)
res = []

# Dataset is sorted in descending size.
Expand All @@ -1195,20 +1209,20 @@ def greedy_until(

total_length, subset_length = self._get_subsets(dataset, dataset_splits)

for s, subset_start in enumerate(
for s, _ in enumerate(
tqdm(
range(0, total_length, subset_length),
disable=disable_tqdm,
position=0,
dataset.splits_start_end_iterator(),
total=dataset_splits,
desc=f"greedy -- Node {dist.get_rank(self.parallel_context.world_pg)}",
position=0,
disable=disable_tqdm,
)
):
dataset.split_start = subset_start
dataset.split_end = min(subset_start + subset_length, total_length)

# print(dataset[0])
_, (context_enc, _, _) = dataset[0]
max_gen = max(d[1][1][1] for d in dataset)
max_input_length = min(len(context_enc) + max_gen, self.max_length)
# max_input_length = len(context_enc)
batch_size = self._get_batch_size(
override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
)
Expand Down Expand Up @@ -1360,7 +1374,7 @@ def greedy_until(
# We are in a process which return no output (beginning/middle of the PP group)
return []

return dataset.ordered.get_original(res)
return dataset.get_original_order(res)


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class TGIModelConfig:
inference_server_auth: str


def create_model_config(args, accelerator: Accelerator): # noqa C901
def create_model_config(args, accelerator: "Accelerator"): # noqa C901
# Tests
if args.inference_server_address is not None and args.model_args is not None:
raise ValueError("You cannot both use an inference server and load a model from its checkpoint.")
Expand Down
1 change: 1 addition & 0 deletions src/main_brrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def main(args):
lm=model,
max_samples=lighteval_config.tasks.max_samples,
evaluation_tracker=evaluation_tracker,
use_chat_template=False,
)

with htrack_block("Setting seeds and waiting for all processes"):
Expand Down
Loading