Added diskcache to base model.
JoelNiklaus committed Dec 30, 2024
1 parent 2ef9740 commit 6a9bb65
Showing 1 changed file with 47 additions and 1 deletion.
src/lighteval/models/transformers/base_model.py
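The commit wraps model generation in a disk-backed cache built on the diskcache library. For readers unfamiliar with it, here is a minimal sketch of the get-or-compute pattern the diff applies below; the directory, key, and generate function are illustrative placeholders, not taken from the commit:

```python
import hashlib

from diskcache import Cache


def expensive_generate() -> list[str]:
    # Hypothetical stand-in for the real model call.
    return ["generated text"]


# Entries are pickled and persisted on disk, so they survive process restarts.
cache = Cache(directory="/tmp/generation_cache")

key = ("greedy", hashlib.sha256(b"tokenized inputs").hexdigest())
responses = cache.get(key)  # returns None on a cache miss
if responses is None:
    responses = expensive_generate()
    cache.set(key, responses)  # stored for subsequent runs
```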
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import hashlib
import logging
import os
from dataclasses import dataclass
@@ -28,6 +29,7 @@
import torch
import torch.nn.functional as F
import transformers
from diskcache import Cache
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -271,6 +273,8 @@ def __init__(

        self.pairwise_tokenization = config.pairwise_tokenization

        # Disk-backed generation cache, persisted under the lighteval cache directory
        self._generation_cache = Cache(directory=os.path.join(env_config.cache_dir, "generation_cache"))

    @classmethod
    def from_model(
        cls,
@@ -833,18 +837,60 @@ def greedy_until(
                padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
            )

-            cur_responses = self._generate(
+            cur_responses = self._generate_cached(
                batch=prepared_batch,
                max_new_tokens=max_new_tokens,
                stop_tokens=stop_tokens,
                returns_logits=returns_logits,
                num_samples=num_samples,
                do_sample=do_sample,
            )

            results.extend(cur_responses)

        return dataset.get_original_order(results)

    def _generate_cached(
        self,
        batch: Batch,
        max_new_tokens: int,
        stop_tokens: list[str],
        returns_logits: Optional[bool] = False,
        num_samples: Optional[int] = 1,
        do_sample: Optional[bool] = False,
    ) -> list[GenerativeResponse]:
        # Build a cache key from the inputs: a SHA-256 digest of the token IDs
        # plus every generation parameter that can change the output.
        cache_key = (
            hashlib.sha256(str(batch.input_ids.tolist()).encode()).hexdigest(),
            max_new_tokens,
            tuple(stop_tokens),
            returns_logits,
            num_samples,
            do_sample,
        )

        # Try to get from cache first
        cached_responses = self._generation_cache.get(cache_key)
        if cached_responses is not None:
            logger.info(f"Cache hit for batch {cache_key}")
            return cached_responses

        # Generate if not in cache
        cur_responses = self._generate(
            batch=batch,
            max_new_tokens=max_new_tokens,
            stop_tokens=stop_tokens,
            returns_logits=returns_logits,
            num_samples=num_samples,
            do_sample=do_sample,
        )

        # Cache the results
        logger.info(f"Caching results for batch {cache_key}")
        self._generation_cache.set(cache_key, cur_responses)

        return cur_responses

    def _generate(
        self,
        batch: Batch,
        ...
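A note on the keying scheme: the key is a plain tuple in which the potentially large batch of token IDs is reduced to a SHA-256 digest, while the generation parameters are kept verbatim, so a lookup only hits when both the inputs and the settings match. A small self-contained illustration of the same construction, using made-up token IDs and parameter values:

```python
import hashlib

import torch

# Hypothetical batch; the real code hashes batch.input_ids the same way.
input_ids = torch.tensor([[101, 2023, 102], [101, 2003, 102]])

cache_key = (
    hashlib.sha256(str(input_ids.tolist()).encode()).hexdigest(),
    50,          # max_new_tokens
    ("</s>",),   # stop_tokens, converted to a tuple for a stable key
    False,       # returns_logits
    1,           # num_samples
    False,       # do_sample
)
print(cache_key[0][:16])  # digest prefix identifying this batch
```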
