fix idx out of range and different subset performance #276

Merged: 2 commits, Jul 25, 2024
50 changes: 33 additions & 17 deletions utilization/model/model_utils/batch_sampler.py
@@ -43,38 +43,53 @@ def info_dataset_group(

class AutoBatchSizeSampler(Sampler[List[int]]):

def __init__(self, data, batch_size: int, auto_batch_size: bool, start_from: int = 0):
def __init__(self, data, batch_size: int, auto_batch_size: bool, index_offset: int = 0):
"""Sampler that automatically adjusts the batch size based on the maximum length of the data.

Args:
data: The data to sample from.
batch_size: The maximum batch size.
auto_batch_size: Whether to automatically adjust the batch size based on the maximum length of the data.
index_offset: The offset of indices to yield.
"""
self.data = [src.to_model_prompt() if hasattr(src, "to_model_prompt") else src for src in data]
total = len(self.data)
self.batch_size = batch_size
self.auto_batch_size = auto_batch_size
self.first_max_len = None
self.index_offset = index_offset
self.data_order = [[]]
self.start_from = start_from
"""The data indices to yield (batches of indices). In convenience of the `__iter__` method, the indices are offset-based: `range(index_offset, index_offset + total)`."""

if not self.auto_batch_size:
for i in range(0, total, self.batch_size):
st = i + self.start_from
ed = min(i + self.batch_size, total) + self.start_from
st = i + self.index_offset
ed = min(i + self.batch_size, total) + self.index_offset
self.data_order[-1].extend(range(st, ed))
if len(self.data_order[-1]) == self.batch_size:
self.data_order.append([])
else:
for i in range(total):
self.data_order[-1].append(i + self.start_from)
self.data_order[-1].append(i + self.index_offset)
if self.check_new_batch(self.data_order[-1], i + 1):
self.data_order.append([])

# remove the last empty batches
while self.data_order[-1] == []:
self.data_order.pop()
logger.debug(f"AutoBatchSizeSampler: {len(self.data_order)} batches starting from {self.start_from}")
logger.debug(f"AutoBatchSizeSampler: {len(self.data_order)} batches starting from {self.index_offset}")

def check_new_batch(self, queries: List[int], next_data: int) -> bool:
def check_new_batch(self, offset_query_indices: List[int], next_data: int) -> bool:
"""Check the condition to start a new batch."""

current_batch = len(queries)
current_batch = len(offset_query_indices)
if not self.auto_batch_size:
return current_batch > self.batch_size
max_len = max(len(self.data[q]) for q in queries)

# data: 0-based
# offset_query_indices: offset-based
# next_data: 0-based
max_len = max(len(self.data[q - self.index_offset]) for q in offset_query_indices)
if next_data < len(self.data):
max_len = max(len(self.data[next_data]), max_len)

@@ -85,7 +100,6 @@ def check_new_batch(self, queries: List[int], next_data: int) -> bool:

batch_size = available_space // max_len
batch_size = round_down(batch_size)
# print("!!!", queries, current_batch, batch_size, available_space, max_len, self.first_max_len)
return current_batch >= batch_size

def __iter__(self) -> Iterator[List[int]]:
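The change above is the core of the "idx out of range" fix: the sampler now yields offset-based indices (so a sampler built over a subset still addresses positions in the full collection), and `check_new_batch` maps them back to local positions with `q - self.index_offset` before indexing `self.data`. A minimal standalone sketch of that indexing, not the project's class, with illustrative names:

```python
from typing import Iterator, List


def offset_batches(total: int, batch_size: int, index_offset: int = 0) -> Iterator[List[int]]:
    """Yield fixed-size batches of indices shifted by `index_offset` (mirrors the non-auto branch)."""
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        yield list(range(start + index_offset, end + index_offset))


# Resuming after 5 finished instances, 6 remaining, batch_size=4:
# [[5, 6, 7, 8], [9, 10]]
print(list(offset_batches(total=6, batch_size=4, index_offset=5)))

# To look up the local data again, subtract the offset (as check_new_batch now does):
local = [q - 5 for q in [5, 6, 7, 8]]  # -> [0, 1, 2, 3]
```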
@@ -162,17 +176,19 @@ def wrapper():

def __iter__(self) -> Iterator[List[int]]:
model = self.dataset_collection._datasets[0].model
accumulative = 0
for total, init_model, self._forward_call in zip(*self._splitted):
accumulative_offset = 0

# iterate over the dataset groups
for group_total, init_model, self._forward_call in zip(*self._splitted):
iterator, total_prefix_num = init_model()
if total_prefix_num > 1 and model.support_cache:
sampler = CachePrefixSampler(
data=iterator,
total=total,
total=group_total,
total_prefix_num=total_prefix_num,
batch_size=self.batch_size,
auto_batch_size=self.auto_batch_size,
start_from=accumulative,
index_offset=accumulative_offset,
)
model.set_cacher(sampler)
yield from sampler
Expand All @@ -182,11 +198,11 @@ def __iter__(self) -> Iterator[List[int]]:
# dynamic batch size for vLLM
yield from AutoBatchSizeSampler(
iterator,
self.batch_size if not self.vllm else total,
self.batch_size if not self.vllm else group_total,
self.auto_batch_size and not self.vllm,
start_from=accumulative
index_offset=accumulative_offset
)
accumulative += total
accumulative_offset += group_total

def call_model(self, *args, **kwargs) -> List[Any]:
"""Route the model to call the corresponding `model_evaluation_method`"""
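In the `__iter__` rewrite above, the rename to `accumulative_offset` and `group_total` makes the bookkeeping explicit: each dataset group's sampler starts numbering where the previous group ended, so indices stay global across subsets instead of restarting at zero. A rough sketch of that chaining, with a hypothetical `group_sizes` list standing in for the dataset groups:

```python
from typing import Iterable, Iterator, List


def collection_batches(group_sizes: Iterable[int], batch_size: int) -> Iterator[List[int]]:
    """Chain per-group batches while keeping indices global to the whole collection."""
    accumulative_offset = 0
    for group_total in group_sizes:
        for start in range(0, group_total, batch_size):
            end = min(start + batch_size, group_total)
            yield list(range(start + accumulative_offset, end + accumulative_offset))
        accumulative_offset += group_total


# Two subsets of sizes 3 and 4 with batch_size=2:
# [[0, 1], [2], [3, 4], [5, 6]]
print(list(collection_batches([3, 4], batch_size=2)))
```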
31 changes: 16 additions & 15 deletions utilization/model/model_utils/conversation.py
@@ -52,20 +52,26 @@ def __init__(
chat_template: str,
special_tokens_map: Optional[Dict[str, Union[str, List[str]]]] = None,
):
self.default_stop = chat_config.pop("default_stop", [])
self.auto_leading_space = chat_config.pop("auto_leading_space", True)
self.final_lstrip = chat_config.pop("final_lstrip", True)
self.final_rstrip = chat_config.pop("final_rstrip", True)
self.merge_system_to_user = chat_config.pop("merge_system_to_user", False)
self.system_user_sep = chat_config.pop("system_user_sep", "\n")
sequences = deepcopy(chat_config)
self.default_stop = sequences.pop("default_stop", [])
self.auto_leading_space = sequences.pop("auto_leading_space", True)
self.final_lstrip = sequences.pop("final_lstrip", True)
self.final_rstrip = sequences.pop("final_rstrip", True)
self.merge_system_to_user = sequences.pop("merge_system_to_user", False)
self.system_user_sep = sequences.pop("system_user_sep", "\n")

# api model does not need bos_token
if "bos_token" not in chat_config:
chat_config["bos_token"] = ""
if "bos_token" not in sequences:
sequences["bos_token"] = ""

self.sequences = chat_config
if special_tokens_map is not None:
for key, value in special_tokens_map.items():
if key not in sequences:
sequences[key] = value

self.sequences = sequences
self.chat_template = chat_template
self.special_tokens_map = special_tokens_map or {}
self.special_tokens_map = deepcopy(special_tokens_map or {})

@classmethod
def from_chat_template(
@@ -86,11 +92,6 @@ def from_chat_template(
chat_config = {}
chat_template = chat_config

if special_tokens_map is not None:
for key, value in special_tokens_map.items():
if key not in chat_config:
chat_config[key] = value

return cls(chat_config=chat_config, chat_template=chat_template, special_tokens_map=special_tokens_map)

@staticmethod
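The change in this file is about mutation: popping keys straight off `chat_config` stripped `default_stop`, `bos_token`, and the other entries from the dict that was passed in, so any later consumer of the same config saw an already-emptied version. The new code works on a deep copy (`sequences`) and merges `special_tokens_map` there instead of in `from_chat_template`. A reduced illustration of the hazard, using toy keys rather than the real chat config:

```python
from copy import deepcopy

chat_config = {"default_stop": ["</s>"], "assistant_start": "ASSISTANT: "}

# Old behaviour: pop() on the shared dict removes the key for every later caller.
broken = chat_config  # same object, not a copy
broken.pop("default_stop", [])
assert "default_stop" not in chat_config  # the shared config was mutated

# New behaviour: work on a deep copy so the original stays usable.
chat_config = {"default_stop": ["</s>"], "assistant_start": "ASSISTANT: "}
sequences = deepcopy(chat_config)
sequences.pop("default_stop", [])
assert "default_stop" in chat_config  # untouched
```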
17 changes: 13 additions & 4 deletions utilization/model/model_utils/prefix_caching.py
@@ -296,7 +296,16 @@ class CachePrefixSampler(Sampler[List[int]], Cacher):
Consider a batch of data indexed from 0 to 7 with cache level 2. Assume data
0~3 have the same prefix and 4~7 have the same prefix. We need to yield the
data 0 and 4 to cache the prefix, and then yield 0~7 to generate with the cache.
Notes that the data 0 and 4 will be yielded twice in total."""
Note that data 0 and 4 will be yielded twice in total.

Args:
data: The data to sample from.
total: The total length of data.
total_prefix_num: The number of prefixes to cache.
batch_size: The maximum batch size.
auto_batch_size: Whether to automatically adjust the batch size based on the maximum length of the data.
index_offset: The offset of indices to yield.
"""

def __init__(
self,
@@ -305,14 +314,14 @@ def __init__(
total_prefix_num: int,
batch_size: int,
auto_batch_size: bool = False,
start_from: int = 0,
index_offset: int = 0,
):

# split data into (src,) and (src, tgt)
self.total_prefix_num = total_prefix_num
self.joined_data = [[] for _ in range(self.total_prefix_num)]
self.cache_levels = [0] * total
self.start_from = start_from
self.index_offset = index_offset

# the batch_size for the kvcache is smaller than the batch_size to avoid OOM
cache_batch_size = (batch_size + 1) // 2
@@ -394,7 +403,7 @@ def _get_data_order(self, total):
order_idx_by_cache[i] = -1

for o in data_order_with_cache:
o = [i + self.start_from for i in o]
o[:] = [i + self.index_offset for i in o]  # in-place so the offset is actually applied to data_order_with_cache

return data_order_with_cache

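For the docstring's example (indices 0 to 7, prefixes shared by 0-3 and 4-7), the sampler first yields one representative per prefix to populate the cache and then yields everything to generate, with `index_offset` added to each yielded index. A toy reconstruction of that ordering under those assumptions; the real sampler additionally handles multiple cache levels and uses a smaller cache-pass batch size to avoid OOM:

```python
from typing import Iterator, List


def toy_prefix_order(groups: List[List[int]], batch_size: int, index_offset: int = 0) -> Iterator[List[int]]:
    """Yield one representative per prefix group to warm the cache, then every index to generate."""

    def shift(batch: List[int]) -> List[int]:
        return [i + index_offset for i in batch]

    # cache pass: the first index of each prefix group
    representatives = [group[0] for group in groups]
    for start in range(0, len(representatives), batch_size):
        yield shift(representatives[start:start + batch_size])

    # generation pass: all indices, now hitting the cached prefixes
    flat = [i for group in groups for i in group]
    for start in range(0, len(flat), batch_size):
        yield shift(flat[start:start + batch_size])


# [[0, 4], [0, 1, 2, 3], [4, 5, 6, 7]] -- 0 and 4 appear twice, as the docstring notes
print(list(toy_prefix_order([[0, 1, 2, 3], [4, 5, 6, 7]], batch_size=4)))
```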
1 change: 1 addition & 0 deletions utilization/utils/arguments.py
@@ -488,6 +488,7 @@ class DatasetArguments:
)

continue_from: ClassVar[int] = 0
"""The number of instances (lines) in .json file to resume from. This is set in `PredictionWriter.write_metainfo`."""

# set in `set_logging` with format "{evaluation_results_dir}/{log_filename}.json"
evaluation_results_path: ClassVar[Optional[str]] = None
3 changes: 2 additions & 1 deletion utilization/utils/log_results.py
@@ -82,11 +82,12 @@ def write_metainfo(
self.continue_from_path = None
return

# load num instances
self.continue_from_path = continue_from or self.evaluation_args.continue_from
if self.continue_from_path:
self.continue_from_instance = self.check_continue()

# load num instances
# set num instances in dataset_args
if self.continue_from_instance is not None and continue_from is None:
self.dataset_args.continue_from = self.continue_from_instance

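Taken together, the `continue_from` docstring and the reordered comments in `write_metainfo` describe the resume path: count how many instances are already in the results file, store that count in `dataset_args.continue_from`, and presumably that value is what ends up as the samplers' `index_offset`. A hedged sketch of the counting step, assuming one JSON object per line in the results file (the writer's actual format may differ):

```python
import json
from pathlib import Path


def count_finished_instances(results_path: str) -> int:
    """Count valid JSON lines already written, i.e. the offset to resume from."""
    path = Path(results_path)
    if not path.exists():
        return 0
    finished = 0
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                json.loads(line)
            except json.JSONDecodeError:
                break  # stop at the first truncated or corrupt line
            finished += 1
    return finished


# e.g. feed the count back in as the sampler's starting point:
# sampler = AutoBatchSizeSampler(data, batch_size, auto_batch_size,
#                                index_offset=count_finished_instances(results_path))
```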