diff --git a/src/lighteval/models/transformers/base_model.py b/src/lighteval/models/transformers/base_model.py index 9b815d2b0..4b624829d 100644 --- a/src/lighteval/models/transformers/base_model.py +++ b/src/lighteval/models/transformers/base_model.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import hashlib import logging import os from dataclasses import dataclass @@ -28,6 +29,7 @@ import torch import torch.nn.functional as F import transformers +from diskcache import Cache from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from tqdm import tqdm @@ -271,6 +273,8 @@ def __init__( self.pairwise_tokenization = config.pairwise_tokenization + self._generation_cache = Cache(directory=os.path.join(env_config.cache_dir, "generation_cache")) + @classmethod def from_model( cls, @@ -833,7 +837,7 @@ def greedy_until( padded=[sum(mask == 0) for mask in tokenized["attention_mask"]], ) - cur_reponses = self._generate( + cur_reponses = self._generate_cached( batch=prepared_batch, max_new_tokens=max_new_tokens, stop_tokens=stop_tokens, @@ -841,10 +845,52 @@ def greedy_until( num_samples=num_samples, do_sample=do_sample, ) + results.extend(cur_reponses) return dataset.get_original_order(results) + def _generate_cached( + self, + batch: Batch, + max_new_tokens: int, + stop_tokens: list[str], + returns_logits: Optional[bool] = False, + num_samples: Optional[int] = 1, + do_sample: Optional[bool] = False, + ) -> list[GenerativeResponse]: + # Create a cache key from the inputs + cache_key = ( + hashlib.sha256(str(batch.input_ids.tolist()).encode()).hexdigest(), + max_new_tokens, + tuple(stop_tokens), + returns_logits, + num_samples, + do_sample, + ) + + # Try to get from cache first + cached_responses = self._generation_cache.get(cache_key) + if cached_responses is not None: + logger.info(f"Cache hit for batch {cache_key}") + return cached_responses + + # Generate if not in cache + cur_reponses = self._generate( + batch=batch, + max_new_tokens=max_new_tokens, + stop_tokens=stop_tokens, + returns_logits=returns_logits, + num_samples=num_samples, + do_sample=do_sample, + ) + + # Cache the results + logger.info(f"Caching results for batch {cache_key}") + self._generation_cache.set(cache_key, cur_reponses) + + return cur_reponses + def _generate( self, batch: Batch,