From a2aa7d6d7b01ef55c024b5891197c80569f3be83 Mon Sep 17 00:00:00 2001
From: chilli
Date: Sun, 5 May 2024 14:36:07 -0700
Subject: [PATCH] Added grok-1 support

---
 mixtral-moe/README.md                        | 10 +++
 mixtral-moe/generate.py                      | 10 +--
 mixtral-moe/model.py                         | 28 +++++++--
 mixtral-moe/scripts/convert_hf_checkpoint.py | 64 +++++++++++++-------
 4 files changed, 81 insertions(+), 31 deletions(-)

diff --git a/mixtral-moe/README.md b/mixtral-moe/README.md
index cf5e9d9..a046450 100644
--- a/mixtral-moe/README.md
+++ b/mixtral-moe/README.md
@@ -1,3 +1,10 @@
+# Grok-1 Support
+```
+export MODEL_REPO=hpcai-tech/grok-1
+python scripts/download.py --repo_id $MODEL_REPO
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
+python scip
+```
 # Mixtral 8x7B
 [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) is a high-quality sparse mixture of experts (MoE) model that matches or beats GPT3.5 on most benchmarks. This repro is a simple and efficient PyTorch native implementation of Mixtral 8x7B.
 
@@ -7,6 +14,9 @@
 export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
 python scripts/download.py --repo_id $MODEL_REPO
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
+
+TOKENIZERS_PARALLELISM=false ENABLE_INTRA_NODE_COMM=1 time torchrun --standalone --nproc_per_node=8 generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --compile --compile_prefill
 ```
 
 ## Benchmarks
diff --git a/mixtral-moe/generate.py b/mixtral-moe/generate.py
index 9aa076b..58c6aba 100644
--- a/mixtral-moe/generate.py
+++ b/mixtral-moe/generate.py
@@ -131,7 +131,7 @@ def generate(
 def encode_tokens(tokenizer, string, bos=True, device='cuda'):
     tokens = tokenizer.encode(string)
     if bos:
-        tokens = [tokenizer.bos_id()] + tokens
+        tokens = [tokenizer.bos_token_id] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
 def _load_model(checkpoint_path, device, precision, use_tp):
@@ -174,7 +174,7 @@ def main(
     """
     assert checkpoint_path.is_file(), checkpoint_path
 
-    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    tokenizer_path = checkpoint_path.parent / "tokenizer.json"
     assert tokenizer_path.is_file(), str(tokenizer_path)
 
     global print
@@ -196,7 +196,9 @@ def main(
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    # tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
 
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
     prompt_length = encoded.size(0)
@@ -235,7 +237,7 @@ def callback(x):
         if done_generating:
             return
         buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
-        if x.item() == tokenizer.eos_id():
+        if x.item() == tokenizer.eos_token_id:
             done_generating = True
         if len(buffer) == 4 or done_generating:
             print(''.join(buffer), end='', flush=True)
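The generate.py hunks above swap the SentencePiece tokenizer for a Hugging Face one, so the id accessors change from the `bos_id()`/`eos_id()` methods to the `bos_token_id`/`eos_token_id` attributes, and the expected file becomes `tokenizer.json`. Below is a minimal sketch (not part of the patch) of the tokenizer interface the patched script relies on, assuming `transformers` is installed and the `hpcai-tech/grok-1` repo is reachable; the prompt string is just an example.

```
# Sketch only: the Hugging Face tokenizer calls used by the patched generate.py.
# The repo id and trust_remote_code flag are taken from the diff above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)

tokens = [tokenizer.bos_token_id] + tokenizer.encode("The capital of France is")
print(tokens)                    # list of ints, as encode_tokens() builds it
print(tokenizer.decode(tokens))  # round-trip back to text
print(tokenizer.eos_token_id)    # used by the generation callback to stop
```

Note that the patch hardcodes the `hpcai-tech/grok-1` repo id instead of loading the local `tokenizer.json` that the assertion just checked for; that is how the diff is written, not a requirement of the API.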
"Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), + "grok-1": dict(vocab_size=131072, block_size=8192, n_layer=64, n_head=48, n_local_heads=8, dim=6144, intermediate_size=32768, rope_base=1000000.0, num_experts=8, num_activated_experts=2), } class KVCache(nn.Module): @@ -106,11 +111,13 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] x = self.tok_embeddings(idx) + x *= embedding_multiplier_scale for i, layer in enumerate(self.layers): x = layer(x, input_pos, freqs_cis, mask) x = self.norm(x) logits = self.output(x) + logits *= output_multiplier_scale return logits @classmethod @@ -123,12 +130,14 @@ def __init__(self, config: ModelArgs) -> None: super().__init__() self.attention = Attention(config) self.block_sparse_moe = MOEFeedForward(config) - self.ffn_norm = RMSNorm(config.dim, config.norm_eps) - self.attention_norm = RMSNorm(config.dim, config.norm_eps) + self.pre_moe_norm = RMSNorm(config.dim, config.norm_eps) + self.post_moe_norm = RMSNorm(config.dim, config.norm_eps) + self.post_attn_norm = RMSNorm(config.dim, config.norm_eps) + self.pre_attn_norm = RMSNorm(config.dim, config.norm_eps) def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: - h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) - out = h + self.block_sparse_moe(self.ffn_norm(h)) + h = x + self.post_attn_norm(self.attention(self.pre_attn_norm(x), freqs_cis, mask, input_pos)) + out = h + self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(h))) return out @@ -160,7 +169,8 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona bsz, seqlen, _ = x.shape kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + qkv = self.wqkv(x) + q, k, v = qkv.split([self.dim, kv_size, kv_size], dim=-1) q = q.view(bsz, seqlen, self.n_head, self.head_dim) k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) @@ -176,7 +186,13 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + attn_weights = torch.matmul(q, k.transpose(2, 3)).to(torch.float32) + attn_weights = attn_weights * attn_output_multiplier + attn_weights = max_attn_val * F.tanh(attn_weights / max_attn_val) + attn_weights += torch.where(mask, 0, -float("inf")) + attn_weights = F.softmax(attn_weights, dim=-1).to(q.dtype) + y = torch.matmul(attn_weights, v) + # y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) diff --git a/mixtral-moe/scripts/convert_hf_checkpoint.py b/mixtral-moe/scripts/convert_hf_checkpoint.py index e659931..38a9a27 100644 --- a/mixtral-moe/scripts/convert_hf_checkpoint.py +++ b/mixtral-moe/scripts/convert_hf_checkpoint.py @@ -32,42 +32,52 @@ def convert_hf_checkpoint( print(f"Model config {config.__dict__}") weight_map = { - "tok_embeddings.weight": "tok_embeddings.weight", - "layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight", - "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight", - "layers.{}.attention.wv.weight": 
"layers.{}.attention.wv.weight", - "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight", - "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1", - "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2", - "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3", - "layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight", - "layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight", - "layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight", - "norm.weight": "norm.weight", - "output.weight": "output.weight", + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.attn.o_proj.weight": "layers.{}.attention.wo.weight", + # "layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight", + # "layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight", + # "layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.moe_block.experts.{}.linear.weight": "layers.{}.block_sparse_moe.cond_ffn.w1.{}", + "model.layers.{}.moe_block.experts.{}.linear_1.weight": "layers.{}.block_sparse_moe.cond_ffn.w2.{}", + "model.layers.{}.moe_block.experts.{}.linear_v.weight": "layers.{}.block_sparse_moe.cond_ffn.w3.{}", + "model.layers.{}.moe_block.gate.weight": "layers.{}.block_sparse_moe.gate.weight", + "model.layers.{}.pre_attn_norm.scale": "layers.{}.pre_attn_norm.weight", + "model.layers.{}.post_attn_norm.scale": "layers.{}.post_attn_norm.weight", + "model.layers.{}.pre_moe_norm.scale": "layers.{}.pre_moe_norm.weight", + "model.layers.{}.post_moe_norm.scale": "layers.{}.post_moe_norm.weight", + "model.norm.scale": "norm.weight", + "lm_head.weight": "output.weight", } - pt_files = glob.glob(str(checkpoint_dir / "*.pt")) + pt_files = glob.glob(str(checkpoint_dir / "*.bin")) merged_result = {} for file in sorted(pt_files): state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) merged_result.update(state_dict) final_result = {} - for key, value in merged_result.items(): + for key, value in list(merged_result.items()): if "layers" in key: - abstract_key = re.sub(r'.(\d+).', '.{}.', key) - layer_num = re.search(r'\d+', key).group(0) + abstract_key = re.sub(r'\.(\d+)\.', '.{}.', key) + nums = re.findall(r'\d+', key) + if abstract_key not in weight_map: + continue new_key = weight_map[abstract_key] if new_key is None: continue - new_key = new_key.format(layer_num) + new_key = new_key.format(*nums) else: + if key not in weight_map: + continue new_key = weight_map[key] - final_result[new_key] = value + del merged_result[key] for key in tuple(final_result.keys()): + print(key) if "wq" in key: q = final_result[key] k = final_result[key.replace("wq", "wk")] @@ -77,9 +87,21 @@ def convert_hf_checkpoint( del final_result[key.replace("wq", "wk")] del final_result[key.replace("wq", "wv")] elif "w1" in key or "w3" in key: - final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous() + if not key.endswith('0'): + continue + full_keys = [key[:-1] + str(i) for i in range(8)] + results = [final_result[k] for k in full_keys] + final_result[key[:-2]] = torch.stack(results, dim=0) + for k in full_keys: + del final_result[k] elif "w2" in key: - final_result[key] = 
-            final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
+            if not key.endswith('0'):
+                continue
+            full_keys = [key[:-1] + str(i) for i in range(8)]
+            results = [final_result[k] for k in full_keys]
+            final_result[key[:-2]] = torch.stack(results, dim=0)
+            for k in full_keys:
+                del final_result[k]
         elif "gate" in key:
             final_result[key] = final_result[key].contiguous()
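The tail of the conversion script gathers grok-1's per-expert `linear` / `linear_1` / `linear_v` weights (remapped above to `cond_ffn.w1.{i}` / `w2.{i}` / `w3.{i}`) and stacks the eight per-expert matrices into one tensor per projection, instead of reshaping a single fused tensor as the old Mixtral path did. A toy sketch of that stacking step follows; the key names mirror the diff, but the shapes are small stand-ins rather than the real dim=6144 / intermediate_size=32768 tensors.

```
import torch

num_experts, dim, intermediate = 8, 16, 32

# Stand-in for final_result after key remapping: one 2-D weight per expert.
final_result = {
    f"layers.0.block_sparse_moe.cond_ffn.w1.{i}": torch.randn(intermediate, dim)
    for i in range(num_experts)
}

# Mirrors the loop in the patch: starting from the ".0" key, collect ".0" .. ".7",
# stack along a new leading expert axis, and drop the per-expert entries.
key = "layers.0.block_sparse_moe.cond_ffn.w1.0"
full_keys = [key[:-1] + str(i) for i in range(num_experts)]
final_result[key[:-2]] = torch.stack([final_result[k] for k in full_keys], dim=0)
for k in full_keys:
    del final_result[k]

print(final_result["layers.0.block_sparse_moe.cond_ffn.w1"].shape)
# torch.Size([8, 32, 16]) -> one [num_experts, ...] tensor per projection
```

Unlike the deleted Mixtral branch, the new w2 path applies no `permute(0, 2, 1)`: the stacked per-expert weights are saved in the orientation they arrive in.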