Added grok-1 support
Chillee committed May 5, 2024
1 parent 2c33914 commit a2aa7d6
Showing 4 changed files with 81 additions and 31 deletions.
10 changes: 10 additions & 0 deletions mixtral-moe/README.md
@@ -1,3 +1,10 @@
# Grok-1 Support
```
export MODEL_REPO=hpcai-tech/grok-1
python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
# then quantize and run generation exactly as in the Mixtral 8x7B section below
```
# Mixtral 8x7B
[Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) is a high-quality sparse mixture-of-experts (MoE) model that matches or beats GPT-3.5 on most benchmarks. This repo is a simple and efficient PyTorch-native implementation of Mixtral 8x7B.

@@ -7,6 +14,9 @@
export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8

TOKENIZERS_PARALLELISM=false ENABLE_INTRA_NODE_COMM=1 time torchrun --standalone --nproc_per_node=8 generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --compile --compile_prefill
```

## Benchmarks
10 changes: 6 additions & 4 deletions mixtral-moe/generate.py
@@ -131,7 +131,7 @@ def generate(
def encode_tokens(tokenizer, string, bos=True, device='cuda'):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
tokens = [tokenizer.bos_token_id] + tokens
return torch.tensor(tokens, dtype=torch.int, device=device)

def _load_model(checkpoint_path, device, precision, use_tp):
@@ -174,7 +174,7 @@ def main(
"""
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
tokenizer_path = checkpoint_path.parent / "tokenizer.json"
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
@@ -196,7 +196,9 @@
device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
# tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
prompt_length = encoded.size(0)

@@ -235,7 +237,7 @@ def callback(x):
if done_generating:
return
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
if x.item() == tokenizer.eos_id():
if x.item() == tokenizer.eos_token_id:
done_generating = True
if len(buffer) == 4 or done_generating:
print(''.join(buffer), end='', flush=True)
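
Editor's note: the tokenizer changes in this file swap SentencePiece-specific calls (`bos_id()`, `eos_id()`, a `tokenizer.model` file) for the Hugging Face `AutoTokenizer` equivalents (`bos_token_id`, `eos_token_id`, `tokenizer.json`). A minimal sketch of the two APIs, not code from the repo; it assumes the `hpcai-tech/grok-1` tokenizer is downloadable, and any HF tokenizer that sets `bos_token_id`/`eos_token_id` behaves the same way:

```
# Sketch only: the HF AutoTokenizer attributes now used in generate.py,
# with the old SentencePieceProcessor calls noted in the comments.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)

ids = tok.encode("Hello, world")         # sp.encode(string) -> list[int], same idea
ids = [tok.bos_token_id] + ids           # was: [sp.bos_id()] + ids
text = tok.decode(ids)                   # sp.decode(ids) works the same way
finished = ids[-1] == tok.eos_token_id   # was: x.item() == sp.eos_id()
print(text, finished)
```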
28 changes: 22 additions & 6 deletions mixtral-moe/model.py
@@ -50,9 +50,14 @@ def from_name(cls, name: str):
assert len(config) == 1, name
return cls(**transformer_configs[config[0]])

attn_output_multiplier = 0.08838834764831845
embedding_multiplier_scale = 78.38367176906169
output_multiplier_scale = 0.5773502691896257
max_attn_val = 30.0

transformer_configs = {
"Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
"grok-1": dict(vocab_size=131072, block_size=8192, n_layer=64, n_head=48, n_local_heads=8, dim=6144, intermediate_size=32768, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
}
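
Editor's note: the hard-coded constants above are grok-1 specific. Three of them appear to be simple functions of the dimensions in the new `"grok-1"` config entry (sqrt(dim), 1/sqrt(head_dim), 1/sqrt(3)), and `max_attn_val` is the tanh soft-cap used in attention below. A quick numeric check of that reading, an observation of ours rather than something stated in the commit:

```
# Sanity check (ours, not from the repo): the grok-1 constants above
# match simple functions of the config dimensions.
import math

dim, n_head = 6144, 48          # from the "grok-1" entry in transformer_configs
head_dim = dim // n_head        # 128

assert abs(math.sqrt(dim) - 78.38367176906169) < 1e-9              # embedding_multiplier_scale
assert abs(1 / math.sqrt(head_dim) - 0.08838834764831845) < 1e-9   # attn_output_multiplier
assert abs(1 / math.sqrt(3) - 0.5773502691896257) < 1e-9           # output_multiplier_scale
print("constants match sqrt(dim), 1/sqrt(head_dim), 1/sqrt(3)")
```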

class KVCache(nn.Module):
@@ -106,11 +111,13 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)
x *= embedding_multiplier_scale

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
x = self.norm(x)
logits = self.output(x)
logits *= output_multiplier_scale
return logits

@classmethod
@@ -123,12 +130,14 @@ def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.block_sparse_moe = MOEFeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
self.pre_moe_norm = RMSNorm(config.dim, config.norm_eps)
self.post_moe_norm = RMSNorm(config.dim, config.norm_eps)
self.post_attn_norm = RMSNorm(config.dim, config.norm_eps)
self.pre_attn_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.block_sparse_moe(self.ffn_norm(h))
h = x + self.post_attn_norm(self.attention(self.pre_attn_norm(x), freqs_cis, mask, input_pos))
out = h + self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(h)))
return out
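
Editor's note: unlike the old Mixtral block, which only normalized sublayer inputs, the rewritten forward above also normalizes each sublayer's output (hence the four `pre_*`/`post_*` norms). A runnable toy comparison of the two residual layouts; all names here are stand-ins, with `nn.Identity`/`nn.LayerNorm` in place of the real Attention, MOEFeedForward, and RMSNorm modules:

```
# Toy comparison (ours) of the old pre-norm layout vs. grok-1's
# norm-before-and-after layout.
import torch
import torch.nn as nn

dim = 16
attn = moe = nn.Identity()
attention_norm = ffn_norm = nn.LayerNorm(dim)        # old block's two norms
pre_attn_norm = post_attn_norm = nn.LayerNorm(dim)   # new block's four norms
pre_moe_norm = post_moe_norm = nn.LayerNorm(dim)

def mixtral_block(x):   # old: pre-norm only
    h = x + attn(attention_norm(x))
    return h + moe(ffn_norm(h))

def grok1_block(x):     # new: normalize each sublayer's input and output
    h = x + post_attn_norm(attn(pre_attn_norm(x)))
    return h + post_moe_norm(moe(pre_moe_norm(h)))

x = torch.randn(2, 4, dim)
print(mixtral_block(x).shape, grok1_block(x).shape)  # both (2, 4, 16)
```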


@@ -160,7 +169,8 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
qkv = self.wqkv(x)
q, k, v = qkv.split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -176,7 +186,13 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
attn_weights = torch.matmul(q, k.transpose(2, 3)).to(torch.float32)
attn_weights = attn_weights * attn_output_multiplier
attn_weights = max_attn_val * F.tanh(attn_weights / max_attn_val)
attn_weights += torch.where(mask, 0, -float("inf"))
attn_weights = F.softmax(attn_weights, dim=-1).to(q.dtype)
y = torch.matmul(attn_weights, v)
# y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
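
Editor's note: the explicit matmul/softmax path above replaces the commented-out `F.scaled_dot_product_attention` call because grok-1 rescales the logits by `attn_output_multiplier` and soft-caps them with `tanh` at `max_attn_val`, which SDPA's fixed `1/sqrt(head_dim)` scaling cannot express. A self-contained sketch of that computation; the helper name and toy shapes are ours:

```
# Minimal sketch of grok-1-style tanh-capped attention, mirroring the code above.
# `capped_attention` is an illustrative helper, not part of the repo.
import torch
import torch.nn.functional as F

def capped_attention(q, k, v, mask,
                     attn_output_multiplier=0.08838834764831845,
                     max_attn_val=30.0):
    # q, k, v: (bsz, n_head, seqlen, head_dim); mask: bool, True = attend
    w = torch.matmul(q, k.transpose(2, 3)).to(torch.float32)
    w = w * attn_output_multiplier                   # replaces SDPA's 1/sqrt(head_dim) scale
    w = max_attn_val * torch.tanh(w / max_attn_val)  # soft-cap logits at +/- max_attn_val
    w = w + torch.where(mask, 0.0, float("-inf"))    # causal masking
    w = F.softmax(w, dim=-1).to(q.dtype)
    return torch.matmul(w, v)

bsz, n_head, seqlen, head_dim = 1, 2, 5, 128
q, k, v = (torch.randn(bsz, n_head, seqlen, head_dim) for _ in range(3))
mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
print(capped_attention(q, k, v, mask).shape)  # torch.Size([1, 2, 5, 128])
```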

Expand Down
64 changes: 43 additions & 21 deletions mixtral-moe/scripts/convert_hf_checkpoint.py
@@ -32,42 +32,52 @@ def convert_hf_checkpoint(
print(f"Model config {config.__dict__}")

weight_map = {
"tok_embeddings.weight": "tok_embeddings.weight",
"layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight",
"layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
"layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
"layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
"layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
"layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
"layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
"layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
"layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight",
"layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
"norm.weight": "norm.weight",
"output.weight": "output.weight",
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.attn.o_proj.weight": "layers.{}.attention.wo.weight",
# "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
# "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
# "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.moe_block.experts.{}.linear.weight": "layers.{}.block_sparse_moe.cond_ffn.w1.{}",
"model.layers.{}.moe_block.experts.{}.linear_1.weight": "layers.{}.block_sparse_moe.cond_ffn.w2.{}",
"model.layers.{}.moe_block.experts.{}.linear_v.weight": "layers.{}.block_sparse_moe.cond_ffn.w3.{}",
"model.layers.{}.moe_block.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
"model.layers.{}.pre_attn_norm.scale": "layers.{}.pre_attn_norm.weight",
"model.layers.{}.post_attn_norm.scale": "layers.{}.post_attn_norm.weight",
"model.layers.{}.pre_moe_norm.scale": "layers.{}.pre_moe_norm.weight",
"model.layers.{}.post_moe_norm.scale": "layers.{}.post_moe_norm.weight",
"model.norm.scale": "norm.weight",
"lm_head.weight": "output.weight",
}
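
Editor's note: each new entry maps a template over the Hugging Face checkpoint key, where `{}` stands for the layer index and, for the expert weights, the expert index; the remapping loop below recovers those numbers with `re.findall` and formats them back in. A small illustration with a hypothetical key, not code from the repo:

```
# Illustration (ours) of how the weight_map templates are filled in the loop below.
import re

hf_key = "model.layers.3.moe_block.experts.5.linear.weight"
abstract_key = re.sub(r'\.(\d+)\.', '.{}.', hf_key)
print(abstract_key)  # model.layers.{}.moe_block.experts.{}.linear.weight
nums = re.findall(r'\d+', hf_key)
print(nums)          # ['3', '5']  -> layer index, expert index
new_key = "layers.{}.block_sparse_moe.cond_ffn.w1.{}".format(*nums)
print(new_key)       # layers.3.block_sparse_moe.cond_ffn.w1.5
```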

pt_files = glob.glob(str(checkpoint_dir / "*.pt"))
pt_files = glob.glob(str(checkpoint_dir / "*.bin"))

merged_result = {}
for file in sorted(pt_files):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
for key, value in list(merged_result.items()):
if "layers" in key:
abstract_key = re.sub(r'.(\d+).', '.{}.', key)
layer_num = re.search(r'\d+', key).group(0)
abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key)
nums = re.findall(r'\d+', key)
if abstract_key not in weight_map:
continue
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
new_key = new_key.format(*nums)
else:
if key not in weight_map:
continue
new_key = weight_map[key]

final_result[new_key] = value
del merged_result[key]

for key in tuple(final_result.keys()):
print(key)
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
@@ -77,9 +87,21 @@ def convert_hf_checkpoint(
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
elif "w1" in key or "w3" in key:
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
if not key.endswith('0'):
continue
full_keys = [key[:-1] + str(i) for i in range(8)]
results = [final_result[k] for k in full_keys]
final_result[key[:-2]] = torch.stack(results, dim=0)
for k in full_keys:
del final_result[k]
elif "w2" in key:
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
if not key.endswith('0'):
continue
full_keys = [key[:-1] + str(i) for i in range(8)]
results = [final_result[k] for k in full_keys]
final_result[key[:-2]] = torch.stack(results, dim=0)
for k in full_keys:
del final_result[k]
elif "gate" in key:
final_result[key] = final_result[key].contiguous()
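
Editor's note: the HF grok-1 checkpoint stores one 2-D tensor per expert (`linear`, `linear_1`, `linear_v`), which the map lands on keys ending in the expert index; the `w1`/`w3` and `w2` branches above then stack the eight per-expert tensors into a single `cond_ffn.wX` tensor. A toy sketch of that stacking step, with tiny stand-in shapes:

```
# Toy sketch (ours) of collapsing ...cond_ffn.w1.0 .. w1.7 into one stacked tensor.
import torch

num_experts, intermediate_size, dim = 8, 4, 3   # tiny stand-in shapes
final_result = {
    f"layers.0.block_sparse_moe.cond_ffn.w1.{i}": torch.randn(intermediate_size, dim)
    for i in range(num_experts)
}

key = "layers.0.block_sparse_moe.cond_ffn.w1.0"           # the loop only acts on expert 0
full_keys = [key[:-1] + str(i) for i in range(num_experts)]
final_result[key[:-2]] = torch.stack([final_result[k] for k in full_keys], dim=0)
for k in full_keys:
    del final_result[k]

print(final_result["layers.0.block_sparse_moe.cond_ffn.w1"].shape)
# torch.Size([8, 4, 3]) == (num_experts, intermediate_size, dim)
```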

