add TransformerLens example (#2651)
* add TransformerLens example

Many people use TransformerLens to do interpretability work and interventions on models, and then need to evaluate the modified model.

Here is a simple script that lets one pass in the TransformerLens model and run lm-eval evaluations on it; a sketch of an intervene-then-evaluate workflow follows the diff below.

* Ran pre-commit checks
nickypro authored Jan 28, 2025
1 parent a0466f0 commit 42f7913
Showing 1 changed file with 59 additions and 0 deletions.
examples/transformer-lens.py (59 additions, 0 deletions)
@@ -0,0 +1,59 @@
import warnings

import torch
import torch.nn as nn
from transformer_lens import HookedTransformer
from transformers import AutoConfig

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM


def evaluate_lm_eval(lens_model: HookedTransformer, tasks: list[str], **kwargs):
    class HFLikeModelAdapter(nn.Module):
        """Adapts HookedTransformer to match the HuggingFace interface expected by lm-eval"""

        def __init__(self, model: HookedTransformer):
            super().__init__()
            self.model = model
            self.tokenizer = model.tokenizer
            # Reuse the HF config of the tokenizer's base model so lm-eval
            # can read attributes like vocab size and max sequence length
            self.config = AutoConfig.from_pretrained(model.cfg.tokenizer_name)
            self.device = model.cfg.device
            # lm-eval may call tie_weights(); make it a harmless no-op
            self.tie_weights = lambda: self

        def forward(self, input_ids=None, attention_mask=None, **kwargs):
            output = self.model(input_ids, attention_mask=attention_mask, **kwargs)
            # Make sure output has the expected .logits attribute
            if not hasattr(output, "logits"):
                if isinstance(output, torch.Tensor):
                    output.logits = output
            return output

        # Only delegate specific attributes we know we need
        def to(self, *args, **kwargs):
            return self.model.to(*args, **kwargs)

        def eval(self):
            self.model.eval()
            return self

        def train(self, mode=True):
            self.model.train(mode)
            return self

    model = HFLikeModelAdapter(lens_model)
    warnings.filterwarnings("ignore", message="Failed to get model SHA for")
    results = evaluator.simple_evaluate(
        model=HFLM(pretrained=model, tokenizer=model.tokenizer),
        tasks=tasks,
        verbosity="WARNING",
        **kwargs,
    )
    return results


if __name__ == "__main__":
    # Load base model
    model = HookedTransformer.from_pretrained("pythia-70m")
    res = evaluate_lm_eval(model, tasks=["arc_easy"])
    print(res["results"])
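
Since the motivation is evaluating intervened models, here is a minimal sketch of that workflow. It is illustrative only and not part of this commit: the hook point, layer/head choice, and zero-ablation function are assumptions, and it reuses the evaluate_lm_eval helper from the script above.

# Sketch: zero-ablate one attention head via a TransformerLens hook,
# then evaluate the hooked model. Layer 3, head 0 are arbitrary choices.
model = HookedTransformer.from_pretrained("pythia-70m")

def zero_ablate_head(z, hook, head_index=0):
    # z: [batch, pos, head_index, d_head] per-head attention outputs
    z[:, :, head_index, :] = 0.0
    return z

model.add_hook("blocks.3.attn.hook_z", zero_ablate_head)
res = evaluate_lm_eval(model, tasks=["arc_easy"])
model.reset_hooks()  # remove the intervention afterwards
print(res["results"])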
