Skip to content

Commit

Permalink
[improvement] add user-friendly method to generate the prompt (#221)
Browse files Browse the repository at this point in the history
* use the Jinja template language to generate prompts
  • Loading branch information
yezhengmao1 authored Jun 23, 2024
1 parent b6b158a commit 15f065b
Show file tree
Hide file tree
Showing 29 changed files with 338 additions and 188 deletions.
3 changes: 2 additions & 1 deletion demo/cpo/cpo_case_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ dispatcher:
datasets:
- name: "cpo_data"
data: "demo/data.json"
prompt: "demo/preference_template.json"
prompt: "demo/preference_prompt.yaml"
prompt_type: "preference"
preprocess: "default"
adapters:
- name: "lora_cpo"
Expand Down
3 changes: 2 additions & 1 deletion demo/dpo/dpo_case_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ dispatcher:
datasets:
- name: "dpo_data"
data: "demo/data.json"
prompt: "demo/preference_template.json"
prompt: "demo/preference_prompt.yaml"
prompt_type: "preference"
preprocess: "default"
adapters:
- name: "lora_dpo"
Expand Down
3 changes: 2 additions & 1 deletion demo/dpo/dpo_case_2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ dispatcher:
datasets:
- name: "dpo_data"
data: "demo/data.json"
prompt: "demo/preference_template.json"
prompt: "demo/preference_prompt.yaml"
prompt_type: "preference"
preprocess: "default"
adapters:
- name: "lora_dpo"
Expand Down
3 changes: 2 additions & 1 deletion demo/dpo/dpo_case_3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ dispatcher:
datasets:
- name: "dpo_data"
data: "demo/data.json"
prompt: "demo/preference_template.json"
prompt: "demo/preference_prompt.yaml"
prompt_type: "preference"
preprocess: "default"
adapters:
- name: "lora_dpo"
Expand Down
3 changes: 2 additions & 1 deletion demo/lora/lora_case_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ dispatcher:
datasets:
- name: "data"
data: "demo/data.json"
prompt: "demo/template.json"
prompt: "demo/prompt.yaml"
prompt_type: "instruction"
preprocess: "shuffle"
adapters:
- name: "lora_0"
Expand Down
3 changes: 2 additions & 1 deletion demo/loraplus/loraplus_case_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ dispatcher:
datasets:
- name: "data"
data: "demo/data.json"
prompt: "demo/template.json"
prompt: "demo/prompt.yaml"
prompt_type: "instruction"
preprocess: "shuffle"
adapters:
- name: "lora_0"
Expand Down
13 changes: 13 additions & 0 deletions demo/preference_prompt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
template: |
### Instruction:
{{ data_point['instruction'] + '\n'}}
{% if 'input' in data_point %}
### Input:
{{ data_point['input'] + '\n'}}
{% endif %}
### Output:
{% if is_chosen %}
{{ data_point['chosen'] + '\n'}}
{% else %}
{{ data_point['reject'] + '\n'}}
{% endif %}
6 changes: 0 additions & 6 deletions demo/preference_template.json

This file was deleted.

9 changes: 9 additions & 0 deletions demo/prompt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
template: |
### Instruction:
{{ data_point['instruction'] + '\n'}}
{% if 'input' in data_point %}
### Input:
{{ data_point['input'] + '\n'}}
{% endif %}
### Output:
{{ data_point['chosen'] + '\n'}}
6 changes: 0 additions & 6 deletions demo/template.json

This file was deleted.

63 changes: 49 additions & 14 deletions mlora/cli/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import requests
from InquirerPy import inquirer
from InquirerPy import inquirer, separator
from rich import print
from rich.table import Table
from rich.box import ASCII
Expand All @@ -14,17 +14,21 @@ def list_dataset(obj):
ret = json.loads(ret.text)

table = Table(show_header=True, show_lines=True, box=ASCII)
table.add_column("name", justify="left")
table.add_column("train", justify="center")
table.add_column("prompt", justify="center")
table.add_column("name", justify="center")
table.add_column("train data name", justify="center")
table.add_column("prompt data name", justify="center")
table.add_column("prompter", justify="center")
table.add_column("preprocess", justify="center")

obj.ret_ = []

for item in ret:
item = json.loads(item)
table.add_row(item["name"], item["train"],
item["prompt"], item["preprocess"])
table.add_row(item["name"],
item["data_name"],
item["prompt_name"],
item["prompt_type"],
item["preprocess"])
obj.ret_.append(item["name"])

obj.pret_ = table
Expand All @@ -34,9 +38,9 @@ def create_dataset(obj):
name = inquirer.text(
message="name:").execute()

list_file(obj, "train")
all_train = [item["name"] for item in obj.ret_]
if len(all_train) == 0:
list_file(obj, "data")
all_train_data = [item["name"] for item in obj.ret_]
if len(all_train_data) == 0:
print("no train data, please upload one")
return

Expand All @@ -47,28 +51,57 @@ def create_dataset(obj):
return

use_train = inquirer.select(
message="train data file:", choices=all_train).execute()
message="train data file:", choices=[separator.Separator(),
*all_train_data,
separator.Separator()]).execute()
use_prompt = inquirer.select(
message="prompt template file:", choices=all_prompt).execute()
message="prompt template file:", choices=[separator.Separator(),
*all_prompt,
separator.Separator()]).execute()
use_preprocess = inquirer.select(
message="data preprocessing:", choices=["default", "shuffle", "sort"]).execute()
message="data preprocessing:", choices=[separator.Separator(),
"default",
"shuffle",
"sort",
separator.Separator()]).execute()

ret = requests.post(url() + "/dataset", json={
"name": name,
"train": use_train,
"prompt": use_prompt,
"data_name": use_train,
"prompt_name": use_prompt,
"preprocess": use_preprocess
})

print(json.loads(ret.text))


def showcase_dataset(obj):
list_dataset(obj)
all_dataset = obj.ret_

if len(all_dataset) == 0:
print("no dataset, please create one")
return

use_dataset = inquirer.select(
message="dataset name:", choices=[separator.Separator(),
*all_dataset,
separator.Separator()]).execute()

ret = requests.get(url() + f"/showcase?name={use_dataset}")
ret = json.loads(ret.text)

print(ret)


def help_dataset(_):
print("Usage of dataset:")
print(" ls")
print(" list all the dataset.")
print(" create")
print(" create a new dataset.")
print(" showcase")
print(" display training data composed of prompt and dataset.")


def do_dataset(obj, args):
Expand All @@ -79,5 +112,7 @@ def do_dataset(obj, args):
return print(obj.pret_)
elif args[0] == "create":
return create_dataset(obj)
elif args[0] == "showcase":
return showcase_dataset(obj)

help_dataset(None)
54 changes: 41 additions & 13 deletions mlora/cli/file.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import json
import requests
from InquirerPy import inquirer, validator
from InquirerPy import inquirer, validator, separator
from rich import print
from rich.table import Table
from rich.box import ASCII


from .setting import url

g_file_type_map = {
"train data": "data",
"prompt data": "prompt"
}


def list_file(obj, file_type: str):
ret = requests.get(url() + f"/{file_type}")
Expand All @@ -16,29 +21,47 @@ def list_file(obj, file_type: str):
table = Table(show_header=True, show_lines=True, box=ASCII)
table.add_column("name", justify="center")
table.add_column("file", justify="center")
if file_type == "prompt":
table.add_column("prompter", justify="center")

for item in ret:
table.add_row(item["name"], item["file"])
row_data = [item["name"], item["file"]["file_path"]]
if file_type == "prompt":
row_data.append(item["file"]["prompt_type"])
table.add_row(*row_data)

obj.ret_ = ret
obj.pret_ = table


def upload_file():
file_type = inquirer.select(
message="type:", choices=["train", "prompt"]).execute()
name = inquirer.text(
message="name:",
validate=validator.EmptyInputValidator("name should not be empty")).execute()
path = inquirer.filepath(
message="file path:",
default="/",
validate=validator.PathValidator(
is_file=True, message="input is not a file"),
only_files=True).execute()

ret = requests.post(
url() + f"/{file_type}?name={name}", files={"data_file": open(path, "rb")})
file_type = inquirer.select(message="file type:",
choices=[separator.Separator(),
*g_file_type_map.keys(),
separator.Separator()]).execute()
file_type = g_file_type_map[file_type]

post_url = url() + f"/{file_type}?name={name}"

if file_type == "prompt":
prompt_type = inquirer.select(message="prompter type:",
choices=[separator.Separator(),
"instruction",
"preference",
separator.Separator()]).execute()
post_url += f"&prompt_type={prompt_type}"

path = inquirer.filepath(message="file path:",
default="/",
validate=validator.PathValidator(is_file=True,
message="input is not a file"),
only_files=True).execute()

ret = requests.post(post_url, files={"data_file": open(path, "rb")})

print(json.loads(ret.text))

Expand All @@ -57,7 +80,12 @@ def do_file(obj, args):
if args[0] == "ls":
# to chose file type
file_type = inquirer.select(
message="type:", choices=["train", "prompt"]).execute()
message="type:",
choices=[separator.Separator(),
*g_file_type_map.keys(),
separator.Separator()]
).execute()
file_type = g_file_type_map[file_type]
list_file(obj, file_type)
return print(obj.pret_)
elif args[0] == "upload":
Expand Down
2 changes: 2 additions & 0 deletions mlora/config/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ class DatasetConfig(DictConfig):
name_: str = ""
data_path_: str = ""
prompt_path_: str = ""
prompt_type_: str = ""
preprocess_: str = "shuffle"

__params_map: Dict[str, str] = {
"name_": "name",
"data_path_": "data",
"prompt_path_": "prompt",
"prompt_type_": "prompt_type",
"preprocess_": "preprocess",
}

Expand Down
1 change: 1 addition & 0 deletions mlora/executor/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"loraplus": InferenceLoRAContext
}


__all__ = [
"TRAINCONTEXT_CLASS",
"INFERENCECONTEXT_CLASS",
Expand Down
11 changes: 3 additions & 8 deletions mlora/executor/task/cpo_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from mlora.model.args import LinearInfo, Tokens, MLoRADataConfig
from mlora.model.tokenizer import Tokenizer
from mlora.config import CPOTaskConfig
from mlora.prompter import PreferenceDataPrompter

import torch
import logging
Expand All @@ -15,10 +13,6 @@
class CPOTask(TrainTask):
now_epoch_: int = 0

def __init__(self, config: CPOTaskConfig, llm_name: str) -> None:
super().__init__(config, llm_name)
self.prompter_ = PreferenceDataPrompter(config.dataset_.prompt_path_)

@override
def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer):
self.tokenizer_ = tokenizer
Expand All @@ -42,13 +36,14 @@ def __cpo_loss_hinge(self, logits: torch.Tensor) -> torch.Tensor:
@override
def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]:
logging.info(
f'Adapter {self.context_.name_} epoch: {self.now_epoch_}/{self.config_.num_epochs_}'
f'Adapter {self.context_.name_} epoch: {
self.now_epoch_}/{self.config_.num_epochs_}'
f' iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}')
data_idx_s = self.now_data_idx_
data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_

# get the train raw string
batch_str = self.prompter_.generate_prompt_batch(
batch_str = self.prompter_.generate_prompt(
self.data_[data_idx_s:data_idx_e])

# convert the string to tokens
Expand Down
9 changes: 1 addition & 8 deletions mlora/executor/task/dpo_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from mlora.model.tokenizer import Tokenizer
from mlora.model.modules import AdapterModel
from mlora.config import DPOTaskConfig
from mlora.prompter import PreferenceDataPrompter
from mlora.model.args import Tokens, MLoRADataConfig, LinearInfo
from mlora.executor.context import TaskContext, INFERENCECONTEXT_CLASS

Expand All @@ -17,11 +15,6 @@
class DPOTask(TrainTask):
ref_context_: TaskContext = None

def __init__(self, config: DPOTaskConfig, llm_name: str) -> None:
super().__init__(config, llm_name)

self.prompter_ = PreferenceDataPrompter(config.dataset_.prompt_path_)

@override
def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer):
self.tokenizer_ = tokenizer
Expand Down Expand Up @@ -88,7 +81,7 @@ def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]:

# 0...mid is chosen data
# mid.end is reject data
batch_str = self.prompter_.generate_prompt_batch(
batch_str = self.prompter_.generate_prompt(
self.data_[data_idx_s:data_idx_e])

assert len(batch_str) % 2 == 0
Expand Down
Loading

0 comments on commit 15f065b

Please sign in to comment.