Skip to content

Commit

Permalink
Add JAISLMHeadModel
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Jan 12, 2024
1 parent aba81d9 commit 3853a47
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,31 @@ def num_layers(self, config: PretrainedConfig) -> int:
num_layers_key="n_layer",
)

JAIS_INFO = StaticTensorNames(
name="JAISLMHeadModel",
pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
embed_weight_names=["transformer.wte.weight"],
layer_prefix_format="transformer.h.{idx}",
layer_weight_suffixes=[
"attn.c_attn.weight",
"attn.c_attn.bias",
"attn.c_proj.weight",
"attn.c_proj.bias",
"ln_1.weight",
"ln_1.bias",
"ln_2.weight",
"ln_2.bias",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"mlp.c_fc2.weight",
"mlp.c_fc2.bias",
"mlp.c_proj.weight",
"mlp.c_proj.bias",
],
num_layers_key="n_layer",
)

GPT2_SEQCLASS_INFO = StaticTensorNames(
name="GPT2ForSequenceClassification",
pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
Expand Down Expand Up @@ -331,6 +356,7 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
CHATGLM_INFO,
STABLELM_INFO,
PHI2_INFO,
JAIS_INFO,
]
for arch in supported:
if arch.name == arch_name:
Expand Down

0 comments on commit 3853a47

Please sign in to comment.