
Commit

[SLM] Small correction on Stablelm and Qwen2. (#1958)
* small fix

* small fix

* Update stablelm_model.py
tlopex authored Mar 16, 2024
1 parent 994f928 commit 73f2b27
Showing 2 changed files with 3 additions and 4 deletions.
python/mlc_llm/model/qwen2/qwen2_model.py: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ def create_paged_kv_cache(
             page_size=page_size,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
             head_dim=self.head_dim,
             rope_mode=RopeMode.NORMAL,
             rope_scale=1,
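Both corrections fix the same slip: create_paged_kv_cache was passing the query-head count where the key-value-head count belongs. For checkpoints without grouped-query attention the two are equal, so the old code happened to work; a GQA checkpoint has fewer KV heads than query heads, and sizing the cache with num_attention_heads over-allocates it. A rough sketch of the per-token KV-cache footprint, with illustrative shapes rather than values read from any real config:

# Rough per-token KV-cache footprint of a GQA model, in bytes.
# The shapes below are illustrative, not taken from an actual Qwen2 config.
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    # Each layer stores one K vector and one V vector per KV head per token.
    return num_layers * 2 * num_kv_heads * head_dim * bytes_per_elem

num_layers, num_attention_heads, num_key_value_heads, head_dim = 28, 28, 4, 128

print(kv_bytes_per_token(num_layers, num_key_value_heads, head_dim))   # sized with KV heads, as in the fix
print(kv_bytes_per_token(num_layers, num_attention_heads, head_dim))   # 7x larger if query heads are used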
python/mlc_llm/model/stable_lm/stablelm_model.py: 2 additions & 3 deletions
@@ -74,7 +74,6 @@ def __post_init__(self):
                 bold("context_window_size"),
             )
             self.prefill_chunk_size = self.context_window_size
-        assert self.tensor_parallel_shards == 1, "StableLM currently does not support sharding."


 # pylint: disable=invalid-name,missing-docstring
@@ -168,11 +167,11 @@ def __init__(self, config: StableLmConfig):
         self.num_hidden_layers = config.num_hidden_layers
         self.hidden_size = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
         self.head_dim = self.hidden_size // self.num_attention_heads
         self.vocab_size = config.vocab_size
         self.rope_theta = config.rope_theta
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.dtype = "float32"
         self.partial_rotary_factor = config.partial_rotary_factor

     def to(self, dtype: Optional[str] = None):
@@ -253,7 +252,7 @@ def create_paged_kv_cache(
             page_size=page_size,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
-            num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
             head_dim=self.head_dim,
             rope_mode=RopeMode.NORMAL,
             rope_scale=1,
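On the StableLM side, the config no longer asserts tensor_parallel_shards == 1, and the model now records num_key_value_heads so that both head counts can be divided across shards when the paged KV cache is created. A minimal sketch of that per-shard split, using a stand-in dataclass (ToyStableLmConfig is hypothetical, not the actual StableLmConfig):

from dataclasses import dataclass

# Stand-in for the model config; field values are illustrative only.
@dataclass
class ToyStableLmConfig:
    num_attention_heads: int = 32
    num_key_value_heads: int = 32  # a GQA checkpoint would set this lower than num_attention_heads
    head_dim: int = 64
    tensor_parallel_shards: int = 2

cfg = ToyStableLmConfig()

# Head counts must divide evenly across shards, otherwise the split is invalid.
assert cfg.num_attention_heads % cfg.tensor_parallel_shards == 0
assert cfg.num_key_value_heads % cfg.tensor_parallel_shards == 0

# Per-shard values handed to the paged KV cache, mirroring the diff above.
attention_heads_per_shard = cfg.num_attention_heads // cfg.tensor_parallel_shards
key_value_heads_per_shard = cfg.num_key_value_heads // cfg.tensor_parallel_shards
print(attention_heads_per_shard, key_value_heads_per_shard)  # 16 16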
