From 6a9bb65fcd085967afd7d7eb039f0b72c0dc36be Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Mon, 30 Dec 2024 09:29:20 +0100
Subject: [PATCH 1/3] Added diskcache to base model.

---
 .../models/transformers/base_model.py | 48 ++++++++++++++++++-
 1 file changed, 47 insertions(+), 1 deletion(-)

diff --git a/src/lighteval/models/transformers/base_model.py b/src/lighteval/models/transformers/base_model.py
index 9b815d2b0..4b624829d 100644
--- a/src/lighteval/models/transformers/base_model.py
+++ b/src/lighteval/models/transformers/base_model.py
@@ -20,6 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import hashlib
 import logging
 import os
 from dataclasses import dataclass
@@ -28,6 +29,7 @@
 import torch
 import torch.nn.functional as F
 import transformers
+from diskcache import Cache
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -271,6 +273,8 @@ def __init__(
 
         self.pairwise_tokenization = config.pairwise_tokenization
 
+        self._generation_cache = Cache(directory=os.path.join(env_config.cache_dir, "generation_cache"))
+
     @classmethod
     def from_model(
         cls,
@@ -833,7 +837,7 @@ def greedy_until(
                 padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
             )
 
-            cur_reponses = self._generate(
+            cur_reponses = self._generate_cached(
                 batch=prepared_batch,
                 max_new_tokens=max_new_tokens,
                 stop_tokens=stop_tokens,
@@ -841,10 +845,52 @@ def greedy_until(
                 num_samples=num_samples,
                 do_sample=do_sample,
             )
+
             results.extend(cur_reponses)
 
         return dataset.get_original_order(results)
 
+    def _generate_cached(
+        self,
+        batch: Batch,
+        max_new_tokens: int,
+        stop_tokens: list[str],
+        returns_logits: Optional[bool] = False,
+        num_samples: Optional[int] = 1,
+        do_sample: Optional[bool] = False,
+    ) -> list[GenerativeResponse]:
+        # Create a cache key from the inputs
+        cache_key = (
+            hashlib.sha256(str(batch.input_ids.tolist()).encode()).hexdigest(),
+            max_new_tokens,
+            tuple(stop_tokens),
+            returns_logits,
+            num_samples,
+            do_sample,
+        )
+
+        # Try to get from cache first
+        cached_responses = self._generation_cache.get(cache_key)
+        if cached_responses is not None:
+            logger.info(f"Cache hit for batch {cache_key}")
+            return cached_responses
+
+        # Generate if not in cache
+        cur_reponses = self._generate(
+            batch=batch,
+            max_new_tokens=max_new_tokens,
+            stop_tokens=stop_tokens,
+            returns_logits=returns_logits,
+            num_samples=num_samples,
+            do_sample=do_sample,
+        )
+
+        # Cache the results
+        logger.info(f"Caching results for batch {cache_key}")
+        self._generation_cache.set(cache_key, cur_reponses)
+
+        return cur_reponses
+
     def _generate(
         self,
         batch: Batch,

From cb16cc506372956bf9ac9745ab9644abb14ee2d5 Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Mon, 30 Dec 2024 14:27:30 +0100
Subject: [PATCH 2/3] Changed info to debug logging for cache.
---
 src/lighteval/models/transformers/base_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lighteval/models/transformers/base_model.py b/src/lighteval/models/transformers/base_model.py
index 4b624829d..202ba3134 100644
--- a/src/lighteval/models/transformers/base_model.py
+++ b/src/lighteval/models/transformers/base_model.py
@@ -872,7 +872,7 @@ def _generate_cached(
         # Try to get from cache first
         cached_responses = self._generation_cache.get(cache_key)
         if cached_responses is not None:
-            logger.info(f"Cache hit for batch {cache_key}")
+            logger.debug(f"Cache hit for batch {cache_key}")
             return cached_responses
 
         # Generate if not in cache
@@ -886,7 +886,7 @@ def _generate_cached(
         )
 
         # Cache the results
-        logger.info(f"Caching results for batch {cache_key}")
+        logger.debug(f"Caching results for batch {cache_key}")
         self._generation_cache.set(cache_key, cur_reponses)
 
         return cur_reponses

From b79d6869b9308cb90ef1fab8ef0c9a925c542ad9 Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Tue, 31 Dec 2024 12:24:25 +0100
Subject: [PATCH 3/3] Moved cache location and increased size.

---
 src/lighteval/models/transformers/base_model.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/lighteval/models/transformers/base_model.py b/src/lighteval/models/transformers/base_model.py
index 202ba3134..8213188c8 100644
--- a/src/lighteval/models/transformers/base_model.py
+++ b/src/lighteval/models/transformers/base_model.py
@@ -273,7 +273,10 @@ def __init__(
 
         self.pairwise_tokenization = config.pairwise_tokenization
 
-        self._generation_cache = Cache(directory=os.path.join(env_config.cache_dir, "generation_cache"))
+        self._generation_cache = Cache(
+            directory=os.path.join(env_config.cache_dir, ".generation_cache"),
+            size_limit=10 * 1024**3,  # 10GB
+        )
 
     @classmethod
     def from_model(
@@ -860,6 +863,7 @@ def _generate_cached(
         do_sample: Optional[bool] = False,
     ) -> list[GenerativeResponse]:
         # Create a cache key from the inputs
+        # TODO: add model name to the cache key
        cache_key = (
             hashlib.sha256(str(batch.input_ids.tolist()).encode()).hexdigest(),
             max_new_tokens,
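
For context, here is a minimal standalone sketch of the caching pattern these patches add: a `diskcache.Cache` consulted with a key built from a SHA-256 digest of the inputs plus the generation parameters. The names `run_model`, `cached_generate`, and the cache path are illustrative placeholders, not lighteval APIs; the real change lives in `_generate_cached` in the diffs above.

```python
import hashlib
import os

from diskcache import Cache

# Illustrative cache location; the patch uses env_config.cache_dir/.generation_cache instead.
CACHE_DIR = os.path.expanduser("~/.cache/generation_cache")

# size_limit mirrors the 10 GB cap introduced in PATCH 3/3.
cache = Cache(directory=CACHE_DIR, size_limit=10 * 1024**3)


def run_model(prompt: str, max_new_tokens: int) -> str:
    """Placeholder for the expensive generation call (self._generate in the patch)."""
    return f"<{max_new_tokens} tokens generated for {prompt!r}>"


def cached_generate(prompt: str, max_new_tokens: int, stop_tokens: tuple[str, ...] = ()) -> str:
    # Key on a digest of the inputs plus every parameter that affects the output,
    # mirroring the tuple built in _generate_cached.
    cache_key = (
        hashlib.sha256(prompt.encode()).hexdigest(),
        max_new_tokens,
        stop_tokens,
    )

    # Serve a previously computed response if this exact request was seen before.
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    # Otherwise generate, persist to disk, and return.
    response = run_model(prompt, max_new_tokens)
    cache.set(cache_key, response)
    return response


if __name__ == "__main__":
    print(cached_generate("2 + 2 =", max_new_tokens=8))  # computed and written to the cache
    print(cached_generate("2 + 2 =", max_new_tokens=8))  # served from the disk cache
```

As the TODO added in PATCH 3/3 points out, the key does not yet include the model identity, so two models sharing the same cache directory could return each other's responses; folding the model name (and revision) into the key tuple would avoid that.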