diff --git a/src/models.js b/src/models.js index b0a82cee0..9fabe8bdd 100644 --- a/src/models.js +++ b/src/models.js @@ -3035,9 +3035,9 @@ export class LlamaPreTrainedModel extends PreTrainedModel { // config doesn't contain pad_token_id, so we assume it is the eos_token_id this.config.pad_token_id = this.config.eos_token_id - this.num_heads = this.config.num_attention_heads + this.num_heads = this.config.num_key_value_heads ?? this.config.num_attention_heads this.num_layers = this.config.num_hidden_layers - this.dim_kv = this.config.hidden_size / this.num_heads; + this.dim_kv = this.config.hidden_size / this.config.num_attention_heads } } /**