From 7c4f5301afc7489bfcf8e5ee1a8592892e374cb3 Mon Sep 17 00:00:00 2001
From: xansar
Date: Thu, 25 Jul 2024 11:59:27 +0800
Subject: [PATCH 1/2] fix idx out of range and different subset performance

---
 utilization/model/model_utils/batch_sampler.py | 2 +-
 utilization/model/model_utils/conversation.py  | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/utilization/model/model_utils/batch_sampler.py b/utilization/model/model_utils/batch_sampler.py
index 8bb19bde..f82f29e1 100644
--- a/utilization/model/model_utils/batch_sampler.py
+++ b/utilization/model/model_utils/batch_sampler.py
@@ -74,7 +74,7 @@ def check_new_batch(self, queries: List[int], next_data: int) -> bool:
         current_batch = len(queries)
         if not self.auto_batch_size:
             return current_batch > self.batch_size
-        max_len = max(len(self.data[q]) for q in queries)
+        max_len = max(len(self.data[q - self.start_from]) for q in queries)
         if next_data < len(self.data):
             max_len = max(len(self.data[next_data]), max_len)

diff --git a/utilization/model/model_utils/conversation.py b/utilization/model/model_utils/conversation.py
index 2047d32a..5d5c88e5 100644
--- a/utilization/model/model_utils/conversation.py
+++ b/utilization/model/model_utils/conversation.py
@@ -7,6 +7,8 @@
 from jinja2.exceptions import TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment

+import copy
+
 from ...chat_templates import DEFAULT_CHAT_CONFIGS, DEFAULT_CHAT_TEMPLATE, add_space, smart_space

 # legacy types
@@ -77,6 +79,7 @@ def from_chat_template(
             chat_template = "base"

         if chat_template in DEFAULT_CHAT_CONFIGS:
+            chat_config = copy.deepcopy(DEFAULT_CHAT_CONFIGS[chat_template])
             chat_config = DEFAULT_CHAT_CONFIGS[chat_template]
             chat_template = DEFAULT_CHAT_TEMPLATE
         else:
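Illustration (not part of either patch, toy values only): what the two changes above guard against. When a dataset group does not start at global index 0, `check_new_batch` receives offset-based indices, so indexing `self.data[q]` directly overruns the local list; subtracting `self.start_from` restores 0-based access. And `DEFAULT_CHAT_CONFIGS[chat_template]` hands back a shared dict that the `__init__` in conversation.py strips with `.pop(...)` (visible in the context of patch 2 below), so later subsets would see a depleted chat config (presumably the "different subset performance" named in the subject); the `copy.deepcopy` keeps the defaults intact. A minimal sketch of both failure modes, with made-up data rather than the LLMBox API:

    # Offset-based indices vs. a 0-based local list (toy values).
    data = ["short", "a much longer prompt"]   # local data of the current subset
    start_from = 100                           # this subset begins at global index 100
    queries = [100, 101]                       # offset-based indices, as yielded

    # Before the fix this raised IndexError, because data[100] does not exist:
    # max_len = max(len(data[q]) for q in queries)
    max_len = max(len(data[q - start_from]) for q in queries)   # -> 20

    # A shared default config mutated in place (toy config, not the real DEFAULT_CHAT_CONFIGS).
    DEFAULT_CHAT_CONFIGS = {"base": {"default_stop": ["\n"], "auto_leading_space": True}}
    chat_config = DEFAULT_CHAT_CONFIGS["base"]       # without deepcopy: the very same dict object
    chat_config.pop("default_stop", [])              # __init__ pops several keys like this
    assert "default_stop" not in DEFAULT_CHAT_CONFIGS["base"]   # the next subset gets a stripped config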
From b5d481241fdc1c0974c4465abb78d5bd64cecb3a Mon Sep 17 00:00:00 2001
From: huyiwen <1020030101@qq.com>
Date: Fri, 26 Jul 2024 00:16:47 +0800
Subject: [PATCH 2/2] add comments

---
 .../model/model_utils/batch_sampler.py         | 50 ++++++++++++-------
 utilization/model/model_utils/conversation.py  | 34 ++++++-------
 .../model/model_utils/prefix_caching.py        | 17 +++++--
 utilization/utils/arguments.py                 |  1 +
 utilization/utils/log_results.py               |  3 +-
 5 files changed, 65 insertions(+), 40 deletions(-)

diff --git a/utilization/model/model_utils/batch_sampler.py b/utilization/model/model_utils/batch_sampler.py
index f82f29e1..c95990d5 100644
--- a/utilization/model/model_utils/batch_sampler.py
+++ b/utilization/model/model_utils/batch_sampler.py
@@ -43,38 +43,53 @@ def info_dataset_group(

 class AutoBatchSizeSampler(Sampler[List[int]]):

-    def __init__(self, data, batch_size: int, auto_batch_size: bool, start_from: int = 0):
+    def __init__(self, data, batch_size: int, auto_batch_size: bool, index_offset: int = 0):
+        """Sampler that automatically adjusts the batch size based on the maximum length of the data.
+
+        Args:
+            data: The data to sample from.
+            batch_size: The maximum batch size.
+            auto_batch_size: Whether to automatically adjust the batch size based on the maximum length of the data.
+            index_offset: The offset of indices to yield.
+        """
         self.data = [src.to_model_prompt() if hasattr(src, "to_model_prompt") else src for src in data]
         total = len(self.data)
         self.batch_size = batch_size
         self.auto_batch_size = auto_batch_size
         self.first_max_len = None
+        self.index_offset = index_offset
         self.data_order = [[]]
-        self.start_from = start_from
+        """The data indices to yield (batches of indices). For the convenience of the `__iter__` method, the indices are offset-based: `range(index_offset, index_offset + total)`."""

         if not self.auto_batch_size:
             for i in range(0, total, self.batch_size):
-                st = i + self.start_from
-                ed = min(i + self.batch_size, total) + self.start_from
+                st = i + self.index_offset
+                ed = min(i + self.batch_size, total) + self.index_offset
                 self.data_order[-1].extend(range(st, ed))
                 if len(self.data_order[-1]) == self.batch_size:
                     self.data_order.append([])
         else:
             for i in range(total):
-                self.data_order[-1].append(i + self.start_from)
+                self.data_order[-1].append(i + self.index_offset)
                 if self.check_new_batch(self.data_order[-1], i + 1):
                     self.data_order.append([])
+
+        # remove the last empty batches
         while self.data_order[-1] == []:
             self.data_order.pop()

-        logger.debug(f"AutoBatchSizeSampler: {len(self.data_order)} batches starting from {self.start_from}")
+        logger.debug(f"AutoBatchSizeSampler: {len(self.data_order)} batches starting from {self.index_offset}")

-    def check_new_batch(self, queries: List[int], next_data: int) -> bool:
+    def check_new_batch(self, offset_query_indices: List[int], next_data: int) -> bool:
         """Check the condition to start a new batch."""
-        current_batch = len(queries)
+        current_batch = len(offset_query_indices)
         if not self.auto_batch_size:
             return current_batch > self.batch_size
-        max_len = max(len(self.data[q - self.start_from]) for q in queries)
+
+        # data: 0-based
+        # offset_query_indices: offset-based
+        # next_data: 0-based
+        max_len = max(len(self.data[q - self.index_offset]) for q in offset_query_indices)
         if next_data < len(self.data):
             max_len = max(len(self.data[next_data]), max_len)

@@ -85,7 +100,6 @@ def check_new_batch(self, queries: List[int], next_data: int) -> bool:

             batch_size = available_space // max_len
             batch_size = round_down(batch_size)

-        # print("!!!", queries, current_batch, batch_size, available_space, max_len, self.first_max_len)
         return current_batch >= batch_size

     def __iter__(self) -> Iterator[List[int]]:
@@ -162,17 +176,19 @@ def wrapper():

     def __iter__(self) -> Iterator[List[int]]:
         model = self.dataset_collection._datasets[0].model
-        accumulative = 0
-        for total, init_model, self._forward_call in zip(*self._splitted):
+        accumulative_offset = 0
+
+        # iterate over the dataset groups
+        for group_total, init_model, self._forward_call in zip(*self._splitted):
             iterator, total_prefix_num = init_model()
             if total_prefix_num > 1 and model.support_cache:
                 sampler = CachePrefixSampler(
                     data=iterator,
-                    total=total,
+                    total=group_total,
                     total_prefix_num=total_prefix_num,
                     batch_size=self.batch_size,
                     auto_batch_size=self.auto_batch_size,
-                    start_from=accumulative,
+                    index_offset=accumulative_offset,
                 )
                 model.set_cacher(sampler)
                 yield from sampler
@@ -182,11 +198,11 @@ def __iter__(self) -> Iterator[List[int]]:
                 # dynamic batch size for vLLM
                 yield from AutoBatchSizeSampler(
                     iterator,
-                    self.batch_size if not self.vllm else total,
+                    self.batch_size if not self.vllm else group_total,
                     self.auto_batch_size and not self.vllm,
-                    start_from=accumulative
+                    index_offset=accumulative_offset
                 )
-            accumulative += total
+            accumulative_offset += group_total

     def call_model(self, *args, **kwargs) -> List[Any]:
         """Route the model to call the corresponding `model_evaluation_method`"""
diff --git a/utilization/model/model_utils/conversation.py b/utilization/model/model_utils/conversation.py
index 5d5c88e5..6fcfd2a6 100644
--- a/utilization/model/model_utils/conversation.py
+++ b/utilization/model/model_utils/conversation.py
@@ -7,8 +7,6 @@
 from jinja2.exceptions import TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment

-import copy
-
 from ...chat_templates import DEFAULT_CHAT_CONFIGS, DEFAULT_CHAT_TEMPLATE, add_space, smart_space

 # legacy types
@@ -54,20 +52,26 @@ def __init__(
         chat_template: str,
         special_tokens_map: Optional[Dict[str, Union[str, List[str]]]] = None,
     ):
-        self.default_stop = chat_config.pop("default_stop", [])
-        self.auto_leading_space = chat_config.pop("auto_leading_space", True)
-        self.final_lstrip = chat_config.pop("final_lstrip", True)
-        self.final_rstrip = chat_config.pop("final_rstrip", True)
-        self.merge_system_to_user = chat_config.pop("merge_system_to_user", False)
-        self.system_user_sep = chat_config.pop("system_user_sep", "\n")
+        sequences = deepcopy(chat_config)
+        self.default_stop = sequences.pop("default_stop", [])
+        self.auto_leading_space = sequences.pop("auto_leading_space", True)
+        self.final_lstrip = sequences.pop("final_lstrip", True)
+        self.final_rstrip = sequences.pop("final_rstrip", True)
+        self.merge_system_to_user = sequences.pop("merge_system_to_user", False)
+        self.system_user_sep = sequences.pop("system_user_sep", "\n")

         # api model does not need bos_token
-        if "bos_token" not in chat_config:
-            chat_config["bos_token"] = ""
+        if "bos_token" not in sequences:
+            sequences["bos_token"] = ""
+
+        if special_tokens_map is not None:
+            for key, value in special_tokens_map.items():
+                if key not in sequences:
+                    sequences[key] = value

-        self.sequences = chat_config
+        self.sequences = sequences
         self.chat_template = chat_template
-        self.special_tokens_map = special_tokens_map or {}
+        self.special_tokens_map = deepcopy(special_tokens_map or {})

     @classmethod
     def from_chat_template(
@@ -79,7 +83,6 @@ def from_chat_template(
             chat_template = "base"

         if chat_template in DEFAULT_CHAT_CONFIGS:
-            chat_config = copy.deepcopy(DEFAULT_CHAT_CONFIGS[chat_template])
             chat_config = DEFAULT_CHAT_CONFIGS[chat_template]
             chat_template = DEFAULT_CHAT_TEMPLATE
         else:
@@ -89,11 +92,6 @@ def from_chat_template(
             chat_config = {}
             chat_template = chat_config

-        if special_tokens_map is not None:
-            for key, value in special_tokens_map.items():
-                if key not in chat_config:
-                    chat_config[key] = value
-
         return cls(chat_config=chat_config, chat_template=chat_template, special_tokens_map=special_tokens_map)

     @staticmethod
diff --git a/utilization/model/model_utils/prefix_caching.py b/utilization/model/model_utils/prefix_caching.py
index b4bbcbc3..76b97a8e 100644
--- a/utilization/model/model_utils/prefix_caching.py
+++ b/utilization/model/model_utils/prefix_caching.py
@@ -296,7 +296,16 @@ class CachePrefixSampler(Sampler[List[int]], Cacher):

     Consider a batch of data indexed from 0 to 7 with cache level 2. Assume data 0~3 have the same prefix and 4~7 have
     the same prefix. We need to yield the data 0 and 4 to cache the prefix, and then yield 0~7 to generate with the cache.
-    Notes that the data 0 and 4 will be yielded twice in total."""
+    Note that the data 0 and 4 will be yielded twice in total.
+
+    Args:
+        data: The data to sample from.
+        total: The total length of data.
+        total_prefix_num: The number of prefixes to cache.
+        batch_size: The maximum batch size.
+        auto_batch_size: Whether to automatically adjust the batch size based on the maximum length of the data.
+        index_offset: The offset of indices to yield.
+    """

     def __init__(
         self,
@@ -305,14 +314,14 @@ def __init__(
         total_prefix_num: int,
         batch_size: int,
         auto_batch_size: bool = False,
-        start_from: int = 0,
+        index_offset: int = 0,
     ):

         # split data into (src,) and (src, tgt)
         self.total_prefix_num = total_prefix_num
         self.joined_data = [[] for _ in range(self.total_prefix_num)]
         self.cache_levels = [0] * total
-        self.start_from = start_from
+        self.index_offset = index_offset

         # the batch_size for the kvcache is smaller than the batch_size to avoid OOM
         cache_batch_size = (batch_size + 1) // 2
@@ -394,7 +403,7 @@ def _get_data_order(self, total):
                 order_idx_by_cache[i] = -1

         for o in data_order_with_cache:
-            o = [i + self.start_from for i in o]
+            o = [i + self.index_offset for i in o]

         return data_order_with_cache

diff --git a/utilization/utils/arguments.py b/utilization/utils/arguments.py
index 12cc0c85..fcd6e70d 100644
--- a/utilization/utils/arguments.py
+++ b/utilization/utils/arguments.py
@@ -488,6 +488,7 @@ class DatasetArguments:
     )

     continue_from: ClassVar[int] = 0
+    """The number of instances (lines) in the .json file to resume from. This is set in `PredictionWriter.write_metainfo`."""

     # set in `set_logging` with format "{evaluation_results_dir}/{log_filename}.json"
     evaluation_results_path: ClassVar[Optional[str]] = None
diff --git a/utilization/utils/log_results.py b/utilization/utils/log_results.py
index fc800c36..f1a2aa58 100644
--- a/utilization/utils/log_results.py
+++ b/utilization/utils/log_results.py
@@ -82,11 +82,12 @@ def write_metainfo(
             self.continue_from_path = None
             return

+        # load num instances
         self.continue_from_path = continue_from or self.evaluation_args.continue_from
         if self.continue_from_path:
             self.continue_from_instance = self.check_continue()

-        # load num instances
+        # set num instances in dataset_args
         if self.continue_from_instance is not None and continue_from is None:
             self.dataset_args.continue_from = self.continue_from_instance
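Illustration (not part of the patches): how the offset-based `data_order` documented above is built. The loop mirrors the fixed-batch-size branch of `AutoBatchSizeSampler.__init__`; the concrete numbers are made up.

    # A dataset group of 10 items whose global indices start at 100,
    # batched with batch_size=4 and auto_batch_size disabled.
    index_offset, total, batch_size = 100, 10, 4

    data_order = [[]]
    for i in range(0, total, batch_size):
        st = i + index_offset
        ed = min(i + batch_size, total) + index_offset
        data_order[-1].extend(range(st, ed))
        if len(data_order[-1]) == batch_size:
            data_order.append([])
    while data_order[-1] == []:        # drop trailing empty batches
        data_order.pop()

    print(data_order)   # [[100, 101, 102, 103], [104, 105, 106, 107], [108, 109]]
    # The sampler yields these offset-based batches; self.data itself stays 0-based,
    # which is why check_new_batch subtracts index_offset before indexing it.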