"""
This script exports the AutoGPT-Q Llama 2 weights in llama2rs.bin format.
"""
import pathlib
import click
import struct
import torch
from typing import Tuple, Union
from torch import nn
from auto_gptq.modeling import BaseGPTQForCausalLM
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.nn_modules import qlinear
from transformers.models.llama import modeling_llama
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
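    """Precompute the RoPE cos/sin tables for head dimension `dim` and positions [0, end)."""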
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin

Serializable = Union[torch.Tensor, qlinear.GeneralQuantLinear, modeling_llama.LlamaRMSNorm, nn.modules.linear.Linear, nn.Embedding]
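
# Layout of the exported .bin, in the order written by export() below:
#   header (7 x int32), rms_norm_eps (f32), token embeddings, then the per-layer
#   tensors grouped by kind (input_layernorm for all layers, then q/k/v/o projections,
#   post_attention_layernorm, gate/down/up projections), the final norm,
#   the RoPE cos/sin tables, and lm_head.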
def export(model_wrapper: BaseGPTQForCausalLM, path: pathlib.Path, max_vocab_size: int = 32000):
"""export the model weights in fp32 into .bin file to be read from Rust"""
f = open(path, 'wb')
print(model_wrapper.model)
model = model_wrapper.model.model
def serialize(k: Serializable):
def write_buffer(w: torch.Tensor, transpose: bool = False, cast_to_float: bool = True):
assert isinstance(w, torch.Tensor)
print(w.shape)
if transpose:
w = w.T
t = w.contiguous().view(-1).detach().cpu()
if cast_to_float:
t = t.type(torch.float32)
t = t.numpy()
f.write(memoryview(t))
if type(k) is torch.Tensor:
write_buffer(k)
elif type(k) in (modeling_llama.LlamaRMSNorm, nn.Embedding, nn.modules.linear.Linear):
assert isinstance(k, (modeling_llama.LlamaRMSNorm, nn.Embedding, nn.modules.linear.Linear))
write_buffer(k.weight)
elif type(k) is qlinear.GeneralQuantLinear or hasattr(k, 'qweight'):
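            # GPTQ 4-bit packing: each int32 in qweight holds eight 4-bit values;
            # `offset` lists the bit positions used to shift the nibbles in and out.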
            offset = torch.tensor([0, 4, 8, 12, 16, 20, 24, 28], dtype=torch.int32)

            def rearrange(k: qlinear.GeneralQuantLinear):
                order = k.g_idx.cpu().argsort(stable=True)
                extract = (k.qweight.cpu()[:, None, :] >> offset[:, None]) & (2**4-1)
                extract = extract.view(k.g_idx.shape[0], -1)[order]
                store = extract << offset.repeat(1, extract.shape[0] // 8)[..., None]
                store = store.view(k.qweight.shape[0], 8, k.qweight.shape[1])
                final = torch.zeros(*k.qweight.shape, dtype=torch.int32)
                for i in range(8):
                    final = final | store[:, i]
                return final
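
            # Write out, per quantized layer: the repacked 4-bit weights, the packed
            # zero points, the scales, and the stable argsort of g_idx (the same row
            # permutation applied in rearrange above).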
            for w in [
                rearrange(k).type(torch.int32),
                k.qzeros.type(torch.int32),
                k.scales.type(torch.float32),
                k.g_idx.argsort(stable=True).type(torch.int32)
            ]:
                write_buffer(w, transpose=len(w.size()) == 2, cast_to_float=False)
        else:
            raise ValueError(f"Unable to export this type of weight: {k}")

    # first write out the header
    p = {}
    p['dim'] = model.layers[0].mlp.up_proj.g_idx.shape[0]
    p['n_layers'] = len(model.layers)
    p['n_heads'] = model.layers[0].self_attn.num_heads
    p['hidden_dim'] = model.layers[0].mlp.up_proj.qweight.shape[1]
    p['vocab_size'] = min(model.embed_tokens.num_embeddings, max_vocab_size)
    p['max_seq_len'] = 2048
    model.embed_tokens.weight.data = model.embed_tokens.weight[:p['vocab_size']]
    model_wrapper.model.lm_head.weight.data = model_wrapper.model.lm_head.weight[:p['vocab_size']]
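    # 'n_kv_heads' is never set in p, so this falls back to n_heads (no grouped-query attention).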
    n_kv_heads = p.get('n_kv_heads') or p['n_heads']
    header = struct.pack(
        'iiiiiii',
        p['dim'], p['hidden_dim'], p['n_layers'], p['n_heads'],
        n_kv_heads, -p['vocab_size'], p['max_seq_len']
    )
    # NOTE ABOVE: the negative vocab_size signals that the classifier (lm_head) weights
    # are present in the checkpoint and should be loaded.
    f.write(header)

    # next write out the embedding weights
    print("writing tok_embeddings...")
    f.write(memoryview(torch.tensor([model_wrapper.config.rms_norm_eps]).numpy()))
    serialize(model.embed_tokens)

    # now all the layers
    # attention weights
    for i in range(p['n_layers']): serialize(model.layers[i].input_layernorm)
    for i in range(p['n_layers']): serialize(model.layers[i].self_attn.q_proj)
    for i in range(p['n_layers']): serialize(model.layers[i].self_attn.k_proj)
    for i in range(p['n_layers']): serialize(model.layers[i].self_attn.v_proj)
    for i in range(p['n_layers']): serialize(model.layers[i].self_attn.o_proj)
    # ffn weights
    for i in range(p['n_layers']): serialize(model.layers[i].post_attention_layernorm)
    for i in range(p['n_layers']): serialize(model.layers[i].mlp.gate_proj)
    for i in range(p['n_layers']): serialize(model.layers[i].mlp.down_proj)
    for i in range(p['n_layers']): serialize(model.layers[i].mlp.up_proj)
    # final rmsnorm
    serialize(model.norm)
    # freqs_cis: computed for 2 * max_seq_len positions, but only the first max_seq_len rows are written
    freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
    serialize(freqs_cos[:p['max_seq_len']])
    serialize(freqs_sin[:p['max_seq_len']])
    # finally write the output weights
    serialize(model_wrapper.model.lm_head)
    f.close()
    print(f"wrote {path}")

@click.command()
@click.argument("output-path", type=click.Path(exists=False, path_type=pathlib.Path))
@click.argument("model-name", type=str)
@click.argument("revision", type=str)
@click.argument("max-vocab-size", type=int, default=32000)
def main(output_path: pathlib.Path, model_name: str, revision: str, max_vocab_size: int):
print(f"Loading model {model_name} / {revision} ...")
model = AutoGPTQForCausalLM.from_quantized(
model_name,
revision=revision,
use_safetensors=True,
trust_remote_code=True,
inject_fused_attention=False,
inject_fused_mlp=False,
use_triton=False,
quantize_config=None,
)
print("Exporting...")
export(model, output_path, max_vocab_size)
if __name__ == '__main__':
    main()