Skip to content

Commit

Permalink
[Fix] chat templates (#260)
Browse files Browse the repository at this point in the history
* [fix] chat templates keep last newlines

* [ci] fix llama2
  • Loading branch information
huyiwen authored Jun 11, 2024
1 parent 3d426dd commit e932849
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/utilization/model/test_apply_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_llama2(conversation: Conversation):
"This is a system message.\n"
"<</SYS>>\n"
"\n"
"This is a user message. [/INST] This is an assistant message. </s><s>[INST] This is the second user message. [/INST] This is the second assistant message. </s><s>[INST]"
"This is a user message. [/INST] This is an assistant message. </s><s>[INST] This is the second user message. [/INST] This is the second assistant message. </s><s>[INST] "
)


Expand Down
6 changes: 6 additions & 0 deletions utilization/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def smart_space(parts: List[str], auto_leading_space) -> str:
"assistant_start": "",
"assistant_end": " </s><s>[INST] ",
"auto_leading_space": True,
"final_rstrip": False,
"default_stops": [],
},
"chatml": {
Expand All @@ -91,6 +92,7 @@ def smart_space(parts: List[str], auto_leading_space) -> str:
"assistant_start": "<|im_start|>assistant\n",
"assistant_end": "<|im_end|>\n",
"auto_leading_space": True,
"final_rstrip": False,
"default_stops": ["<|im_end|>"],
},
"zephyr": {
Expand All @@ -101,6 +103,7 @@ def smart_space(parts: List[str], auto_leading_space) -> str:
"assistant_start": "<|assistant|>\n",
"assistant_end": "</s>\n",
"auto_leading_space": True,
"final_rstrip": False,
"default_stops": ["</s>"],
},
"phi3": {
Expand All @@ -111,6 +114,7 @@ def smart_space(parts: List[str], auto_leading_space) -> str:
"assistant_start": "<|assistant|>\n",
"assistant_end": "<|end|>\n",
"auto_leading_space": True,
"final_rstrip": False,
"default_stops": ["<|end|>"],
},
"llama3": {
Expand All @@ -121,6 +125,7 @@ def smart_space(parts: List[str], auto_leading_space) -> str:
"assistant_start": "<|start_header_id|>assistant<|end_header_id|>\n\n",
"assistant_end": "<|eot_id|>",
"auto_leading_space": True,
"final_rstrip": False,
"default_stops": ["<|eot_id|>"],
},
"alpaca": {
Expand All @@ -131,6 +136,7 @@ def smart_space(parts: List[str], auto_leading_space) -> str:
"assistant_start": "### Response:\n",
"assistant_end": "\n\n",
"auto_leading_space": True,
"final_rstrip": False,
"default_stops": ["###"],
}
}

0 comments on commit e932849

Please sign in to comment.