add starcoder2
reymondzzzz committed Jul 11, 2024
1 parent c8d6bdc commit cfbbdd2
Showing 6 changed files with 660 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers_neuronx/__init__.py
@@ -28,6 +28,7 @@
from transformers_neuronx.mistral.model import MistralForSampling
from transformers_neuronx.mixtral.model import MixtralForSampling
from transformers_neuronx.opt.model import OPTForSampling
from transformers_neuronx.starcoder2.model import Starcoder2ForSampling

from transformers_neuronx.modeling_auto import NeuronAutoModelForCausalLM

2 changes: 2 additions & 0 deletions src/transformers_neuronx/modeling_auto.py
@@ -12,6 +12,7 @@
"mistral": transformers_neuronx.MistralForSampling,
"mixtral": transformers_neuronx.MixtralForSampling,
"opt": transformers_neuronx.OPTForSampling,
"starcoder2": transformers_neuronx.Starcoder2ForSampling,
}


@@ -24,6 +25,7 @@
transformers.MistralConfig: "mistral",
transformers.MixtralConfig: "mixtral",
transformers.OPTConfig: "opt",
transformers.Starcoder2Config: "starcoder2",
}


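With "starcoder2" registered in both maps above, the model can be dispatched through NeuronAutoModelForCausalLM. A minimal usage sketch follows, assuming the standard *ForSampling interface of transformers_neuronx; the checkpoint name, tp_degree, and sampling arguments are illustrative and not part of this commit.

import torch
from transformers import AutoTokenizer
from transformers_neuronx import NeuronAutoModelForCausalLM

# Illustrative checkpoint and parallelism settings (not taken from this change).
model = NeuronAutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder2-3b", batch_size=1, amp="f16", tp_degree=2)
model.to_neuron()  # compile the graph and load weights onto NeuronCores

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-3b")
input_ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids
with torch.inference_mode():
    generated = model.sample(input_ids, sequence_length=128, top_k=50)
print(tokenizer.decode(generated[0]))
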
39 changes: 39 additions & 0 deletions src/transformers_neuronx/starcoder2/config.py
@@ -0,0 +1,39 @@
from transformers_neuronx import utils


class Starcoder2Config:
def __init__(
self,
config,
n_positions,
batch_size,
amp,
tp_degree,
**kwargs
):
# Extract configs used for building HLO
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size

self.attention_head_size = config.hidden_size // config.num_attention_heads
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
self.num_hidden_layers = config.num_hidden_layers
self.vocab_size = config.vocab_size
self.hidden_act = config.hidden_act
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.max_position_embeddings = config.max_position_embeddings
self.rms_norm_eps = config.norm_epsilon
self.rotary_percentage = getattr(config, "rotary_percentage", 1)
self.rope_theta = getattr(config, "rope_theta", 10000)
self.position_interpolation_factor = getattr(config, "position_interpolation_factor", None)
self.use_bias = getattr(config, "use_bias", True)
utils.maybe_override_attributes(self, kwargs)

# Add required Neuron configs
self.n_positions = n_positions
self.batch_size = batch_size
self.amp = amp
self.tp_degree = tp_degree
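
For context, a rough sketch of how this wrapper might be instantiated from a Hugging Face Starcoder2Config; the constructor arguments shown are illustrative and not part of this commit.

from transformers import Starcoder2Config as HfStarcoder2Config
from transformers_neuronx.starcoder2.config import Starcoder2Config

hf_config = HfStarcoder2Config()  # library default hyperparameters, for illustration only
config = Starcoder2Config(hf_config, n_positions=2048, batch_size=1, amp="f16", tp_degree=2)
# Derived fields and the Neuron-specific settings are now available on the wrapper.
print(config.attention_head_size, config.num_key_value_heads, config.tp_degree)
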
261 changes: 261 additions & 0 deletions src/transformers_neuronx/starcoder2/hlo.py
@@ -0,0 +1,261 @@
from typing import Optional

from transformers_neuronx import constants
from transformers_neuronx import hlo
from transformers_neuronx import utils
from transformers_neuronx.config import NeuronConfig
from transformers_neuronx.constants import LAYOUT_HSB
from transformers_neuronx.hlo import mlp
from transformers_neuronx.layers import transformer, rotary, attention, attention_utils, flash_decoding
from transformers_neuronx.starcoder2.config import Starcoder2Config


class Starcoder2ForSamplingNoEmbeddingHlo:

def __init__(self,
config: Starcoder2Config,
neuron_config: Optional[NeuronConfig] = None
):
self.config = config
self.neuron_config = neuron_config
self.n_positions = None

@property
def shard_over_batch(self):
# Property access allows fallback configuration to be enabled after construction
return (
self.neuron_config is not None
and self.neuron_config.group_query_attention == constants.GQA.SHARD_OVER_BATCH
)

def inputs(self, scribe, dtype, n_active_tokens, batch_size):
tensors, dims = transformer.inputs(
scribe, dtype, batch_size, n_active_tokens, self.config.hidden_size, self.neuron_config)

return tensors, dims

def embedding(self, input_ids, cache_ids, start_ids, last_token_id, embed_weight):
dtype = getattr(input_ids.scribe, self.config.amp)
hidden = hlo.embedding(embed_weight, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
if self.config.hidden_size % self.config.tp_degree != 0:
hidden = hlo.slice_along(hidden, dim=-1, limit=self.config.hidden_size, start=0)
if self.neuron_config.attention_layout == LAYOUT_HSB:
hidden = hlo.transpose210(hidden)
return hidden

def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *weights):
head_dim = self.config.attention_head_size
pos_embed = rotary.hlo_rotary_embedding(
hidden.dtype, int(head_dim * self.config.rotary_percentage), cache_ids,
base=self.config.rope_theta,
interpolation_factor=self.config.position_interpolation_factor
)
mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions)
core_id = None
if self.neuron_config.shard_over_sequence:
core_id, *rst = weights
n_kv_heads = self.config.num_key_value_heads if self.config.num_key_value_heads else self.config.num_attention_heads
cores_per_kv_head = self.config.tp_degree // n_kv_heads
self.cores_per_kv_head = cores_per_kv_head if cores_per_kv_head > 1 else self.config.tp_degree
cache_ids, mask, active_mask = flash_decoding.convert_attn_mask_and_cache_id(cache_ids,
core_id, self.n_positions,
cores_per_kv_head=self.cores_per_kv_head)

return hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id

def layer(
self, hidden, last_token_id, pos_embed, cache_ids, start_ids, mask, active_mask, core_id,
attn_k_cache, attn_v_cache,
pre_attn_ln_weight, pre_attn_ln_bias,
attn_q_weight, attn_q_scales, attn_q_bias,
attn_k_weight, attn_k_scales, attn_k_bias,
attn_v_weight, attn_v_scales, attn_v_bias,
attn_out_weight, attn_out_scales, attn_out_bias,
post_attn_ln_weight, post_attn_ln_bias,
pre_mlp_ln_weight, pre_mlp_ln_bias,
mlp_in_weight, mlp_in_scales, mlp_in_bias,
mlp_out_weight, mlp_out_scales, mlp_out_bias,
post_mlp_ln_weight, post_mlp_ln_bias,
):
# eps = self.config.rms_norm_eps
# is_bsh = self.neuron_config and self.neuron_config.attention_layout == LAYOUT_BSH
ln_hidden = hlo.layer_norm(hidden, pre_attn_ln_weight, pre_attn_ln_bias)

attn_output, out_attn_k_cache, out_attn_v_cache = self.attention(
ln_hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
attn_k_cache, attn_v_cache,
attn_q_weight, attn_q_scales, attn_q_bias,
attn_k_weight, attn_k_scales, attn_k_bias,
attn_v_weight, attn_v_scales, attn_v_bias,
attn_out_weight, attn_out_scales, attn_out_bias
)
hidden = hlo.add(attn_output, hidden)

norm_hidden = hlo.layer_norm(hidden, pre_mlp_ln_weight, pre_mlp_ln_bias)
mlp_hidden = mlp(
norm_hidden,
mlp_in_weight, mlp_in_bias, mlp_out_weight, mlp_out_bias,
activation_function='gelu_new',  # HF config specifies 'gelu_pytorch_tanh'; both are the tanh GELU approximation
tp_degree=self.config.tp_degree,
neuron_config=self.neuron_config
)
res_hidden = hlo.add(mlp_hidden, hidden)
return res_hidden, out_attn_k_cache, out_attn_v_cache

def ln_lm_head(self, hidden, last_token_id, rms_weight, unused_bias, lm_head_weight, lm_head_bias,
return_all_outputs=True):
logits = transformer.rms_lm_head(self.config.tp_degree, hidden, last_token_id, rms_weight, lm_head_weight,
lm_head_bias, return_all_outputs, eps=self.config.rms_norm_eps,
neuron_config=self.neuron_config)
return logits

def attention(
self,
hidden, cache_ids, start_ids, pos_embed, mask, active_mask, core_id,
cached_keys, cached_values,
q_weight, q_scales, q_bias,
k_weight, k_scales, k_bias,
v_weight, v_scales, v_bias,
out_weight, out_scales, out_bias,
):
d_head = self.config.attention_head_size
tp_degree = self.config.tp_degree

# Compute the expected number of KV heads (used when fused QKV is enabled)
n_kv_heads_tp = None
if self.config.num_key_value_heads is not None:
n_head = self.config.num_attention_heads
n_kv_head = self.config.num_key_value_heads
_, n_kv_head_padded = utils.get_qkv_padding(n_head, n_kv_head, tp_degree, self.neuron_config)
n_kv_heads_tp = n_kv_head_padded // tp_degree

# Q = (hidden @ wQ) + bQ
# K = (hidden @ wK) + bK
# V = (hidden @ wV) + bV
query, key, value = attention.query_key_value(
hidden,
q_weight, q_scales, q_bias,
k_weight, k_scales, k_bias,
v_weight, v_scales, v_bias,
d_head,
neuron_config=self.neuron_config,
tp_degree=tp_degree, # TODO: include tp_degree into neuron_config
shard_over_batch=self.shard_over_batch,
n_kv_heads_tp=n_kv_heads_tp,
)

# Q = Rotate(Q)
# K = Rotate(K)
query, key = rotary.rotate_half(query, key, pos_embed, self.config.rotary_percentage,
tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)

# Q = Q / sqrt(d_head)
query = attention.scale(query, d_head)

# In BSH cache layout, the output of QKV linear projection is still kept as SBH for all QKV.
bsh_cache_layout = False
batch_dim = 1
if self.neuron_config is not None:
bsh_cache_layout = self.neuron_config.cache_layout == constants.LAYOUT_BSH
if bsh_cache_layout:
query, key, value = attention_utils.transpose_qkv(query, key, value)
batch_dim = 0

# Single Token Generation ("Prefetch"-style) and speculative forward
if active_mask is not None:

n_active_tokens = key.sizes[1] if bsh_cache_layout else key.sizes[0]
if n_active_tokens > 1 and self.neuron_config and self.neuron_config.continuous_batching:
# For speculative forward + continuous batching, slice out the batch entries
# corresponding to the batch size of the speculative head
slice_sizes = [1] * len(cached_keys.sizes)
if cached_keys.sizes[batch_dim] == 1:
# Use hlo.select for batch size 1, as index_select is prohibitively slow
# TODO: revert to hlo.index_select once it's faster P126527643
cached_keys_s = hlo.select(cached_keys, batch_dim, hlo.reshape(start_ids, slice_sizes),
keepdim=True)
cached_values_s = hlo.select(cached_values, batch_dim, hlo.reshape(start_ids, slice_sizes),
keepdim=True)
else:
cached_keys_s = hlo.index_select(cached_keys, batch_dim, start_ids)
cached_values_s = hlo.index_select(cached_values, batch_dim, start_ids)
else:
cached_keys_s = cached_keys
cached_values_s = cached_values
# Communication 1: all-gather query from cores
if (n_active_tokens != self.n_positions) and self.neuron_config.shard_over_sequence:
query = flash_decoding.gather_query_group(query, self.cores_per_kv_head,
self.config.num_attention_heads,
tp_degree)

# Sp = Q @ Kp
prior_scores = attention.score(query, cached_keys_s, n_kv_heads=self.config.num_key_value_heads,
tp_degree=tp_degree, neuron_config=self.neuron_config)
prior_scores = attention.mask(prior_scores, mask, tp_degree=tp_degree,
shard_over_batch=self.shard_over_batch)

# Sa = Q @ Ka
active_score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
tp_degree=tp_degree, neuron_config=self.neuron_config)
active_score = attention.mask(active_score, active_mask, tp_degree=tp_degree,
shard_over_batch=self.shard_over_batch)

# C = softmax(Sa, Sp) @ (Va, Vp)
if self.neuron_config.shard_over_sequence:
dtype = query.dtype
context = flash_decoding.context(prior_scores, active_score, cached_values_s, value, core_id, mask,
active_mask,
n_kv_heads=self.config.num_key_value_heads,
n_heads=self.config.num_attention_heads, dtype=dtype,
tp_degree=tp_degree, neuron_config=self.neuron_config,
shard_over_batch=self.shard_over_batch)
cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids, value, key,
self.cores_per_kv_head, core_id,
dim=0)

else:
context = attention.context(prior_scores, active_score, cached_values_s, value,
n_kv_heads=self.config.num_key_value_heads, tp_degree=tp_degree,
neuron_config=self.neuron_config)

# KCache[I], VCache[I] = K, V
updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
key, value, start_ids,
neuron_config=self.neuron_config)

# Multi-Token Context Encoding
else:
_, batch_size, _, _ = query.sizes
if self.neuron_config.lhs_aligned or batch_size == 1:
context = attention.flash_attention(query, key, value)
else:
# Do not use flash attention for the LHS-padded (right-aligned) batch > 1 case,
# because it does not correctly take the mask into account
context = None

if context is None:
# S = Q @ K

score = attention.score(query, key, n_kv_heads=self.config.num_key_value_heads,
tp_degree=tp_degree, neuron_config=self.neuron_config)
score = attention.mask(score, mask, tp_degree=tp_degree, shard_over_batch=self.shard_over_batch)
context = attention.context_combined(score, value, n_kv_heads=self.config.num_key_value_heads,
tp_degree=tp_degree, neuron_config=self.neuron_config)

if self.neuron_config.shard_over_sequence:
cache_ids, value, key = flash_decoding.select_values_within_bound(cache_ids,
value,
key,
self.cores_per_kv_head,
core_id, dim=0)
# KCache, VCache = K, V
if cached_keys.sizes == key.sizes:
updated_keys, updated_values = key, value
else:
updated_keys, updated_values = attention.fused_kv_update_cache(cached_keys, cached_values, cache_ids,
key, value, start_ids,
neuron_config=self.neuron_config)

# O = (C @ wO) + bO
output = attention.output(context, out_weight, out_scales, out_bias, tp_degree, self.neuron_config)
return output, updated_keys, updated_values
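
The comments in attention() spell out the intended math: Q = (hidden @ wQ) + bQ, S = Q @ K / sqrt(d_head), C = softmax(S) @ V, O = (C @ wO) + bO. As a conceptual cross-check only, here is a plain-PyTorch restatement of that path for grouped-query attention on a single device. It is a rough sketch: it omits rotary embeddings, KV caching, flash decoding, and the tensor-parallel sharding handled by the HLO version, and all names are illustrative.

import math
import torch

def reference_gqa_attention(hidden, wq, bq, wk, bk, wv, bv, wo, bo, n_heads, n_kv_heads, mask):
    # hidden: [batch, seq, hidden]; weight/bias tensors follow torch.nn.functional.linear layout.
    b, s, h = hidden.shape
    d_head = h // n_heads
    # Q = (hidden @ wQ) + bQ, and likewise for K and V.
    q = torch.nn.functional.linear(hidden, wq, bq).view(b, s, n_heads, d_head)
    k = torch.nn.functional.linear(hidden, wk, bk).view(b, s, n_kv_heads, d_head)
    v = torch.nn.functional.linear(hidden, wv, bv).view(b, s, n_kv_heads, d_head)
    # Grouped-query attention: replicate each KV head across its group of query heads.
    k = k.repeat_interleave(n_heads // n_kv_heads, dim=2)
    v = v.repeat_interleave(n_heads // n_kv_heads, dim=2)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # [batch, heads, seq, d_head]
    # S = (Q / sqrt(d_head)) @ K^T, then mask out disallowed positions.
    scores = (q / math.sqrt(d_head)) @ k.transpose(-1, -2)
    scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
    # C = softmax(S) @ V
    context = torch.softmax(scores, dim=-1) @ v
    context = context.transpose(1, 2).reshape(b, s, h)
    # O = (C @ wO) + bO
    return torch.nn.functional.linear(context, wo, bo)

A causal boolean mask of shape [1, 1, seq, seq] built with torch.tril is enough to exercise this for context encoding.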
