# main.py
# %%
# Import stuff
import transformer_lens as tl
import transformer_lens.utils as utils
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
from typing import List, Union, Optional
from functools import partial
import copy
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
# %%
# Install Neel's plotting helpers if they aren't already available:
# !pip install git+https://github.com/neelnanda-io/neel-plotly
from neel_plotly import line, scatter, imshow, histogram
# %%
torch.set_grad_enabled(False)  # analysis only; we never need gradients
# %%
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
# Set up Othello-GPT. Config stolen from Neel's code.
cfg = HookedTransformerConfig(
    n_layers=8,
    d_model=512,
    d_head=64,
    n_heads=8,
    d_mlp=2048,
    d_vocab=61,  # 60 playable squares (8x8 board minus the 4 centre squares) plus a special token
    n_ctx=59,  # a game has at most 60 moves; the model predicts each move from the preceding ones
    act_fn="gelu",
    normalization_type="LNPre",
)
model = HookedTransformer(cfg)
# %%
# Weights for the model trained on synthetic games (uniformly random legal moves).
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# Alternative: the model trained on human championship games.
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)
# %%
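# Quick sanity check: a minimal sketch running the loaded model on a dummy move
# sequence to confirm the output shape. The token ids are arbitrary values in the
# vocab range (0 is avoided in case it is a special token), not a legal game.
dummy_moves = torch.randint(1, cfg.d_vocab, (1, cfg.n_ctx))
logits = model(dummy_moves)
print(logits.shape)  # expected: torch.Size([1, 59, 61]) = (batch, position, d_vocab)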