Update README.zh.md #181

Merged 1 commit on Dec 14, 2023
103 changes: 97 additions & 6 deletions README.zh.md
@@ -35,12 +35,13 @@
- [2.2 Base Model](#22基座模型)
- [3. Usage](#三使用方法)
- [3.1 Environment Preparation](#31环境准备)
- [3.2 Data Preparation](#32数据准备)
- [3.3 Model Fine-tuning](#33模型微调)
- [3.4 Model Prediction](#34模型预测)
- [3.5 Model Weights](#35模型权重)
- [3.5.1 Merging the Base Model and Fine-tuned Weights](#351-模型和微调权重合并)
- [3.6 Model Evaluation](#36模型评估)
- [3.2 Quick Start](#32快速开始)
- [3.3 Data Preparation](#33数据准备)
- [3.4 Model Fine-tuning](#34模型微调)
- [3.5 Model Prediction](#35模型预测)
- [3.6 Model Weights](#36模型权重)
- [3.6.1 Merging the Base Model and Fine-tuned Weights](#361-模型和微调权重合并)
- [3.7 Model Evaluation](#37模型评估)
- [4. Roadmap](#四发展路线)
- [5. Contributing](#五贡献)
- [6. Acknowledgements](#六感谢)
@@ -136,6 +137,96 @@ poetry run sh dbgpt_hub/scripts/gen_train_eval_data.sh
```
The project's data-processing code already has built-in support for the `chase`, `cosql`, and `sparc` datasets. After downloading them from the links above into the data directory, just uncomment the corresponding entries of `SQL_DATA_INFO` in `dbgpt_hub/configs/config.py`.
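Conceptually, this preprocessing turns each raw (question, SQL) record into an instruction-tuning sample. A minimal illustrative sketch, assuming Spider-style field names and a generic instruction/input/output target format (not the project's actual implementation):

```python
import json

def to_sft_sample(record, schema_prompt):
    """Turn one Spider-style (question, SQL) record into an SFT sample.

    Illustrative only: input field names mirror the Spider JSON; the
    instruction/input/output shape is an assumed, generic SFT format.
    """
    return {
        "instruction": f"Given the database schema, write the SQL query.\n{schema_prompt}",
        "input": record["question"],
        "output": record["query"],
    }

record = {
    "db_id": "concert_singer",
    "question": "How many singers do we have?",
    "query": "SELECT count(*) FROM singer",
}
sample = to_sft_sample(record, "Table singer(singer_id, name, country, age)")
print(json.dumps(sample, ensure_ascii=False, indent=2))
```

The real pipeline additionally serializes the database schema from `tables_file` into the prompt and handles multi-turn dialogues when `is_multiple_turn` is set.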

### 3.2 Quick Start

First, install `dbgpt-hub` with:

`pip install dbgpt-hub`

Then specify the parameters and run the entire Text2SQL fine-tuning workflow in a few lines of code:
```python
from dbgpt_hub.data_process import preprocess_sft_data
from dbgpt_hub.train import start_sft
from dbgpt_hub.predict import start_predict
from dbgpt_hub.eval import start_evaluate

# Configure paths and parameters for the training and dev sets
data_folder = "dbgpt_hub/data"
data_info = [
    {
        "data_source": "spider",
        "train_file": ["train_spider.json", "train_others.json"],
        "dev_file": ["dev.json"],
        "tables_file": "tables.json",
        "db_id_name": "db_id",
        "is_multiple_turn": False,
        "train_output": "spider_train.json",
        "dev_output": "spider_dev.json",
    }
]

# Configure fine-tuning parameters
train_args = {
    "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
    "do_train": True,
    "dataset": "example_text2sql_train",
    "max_source_length": 2048,
    "max_target_length": 512,
    "finetuning_type": "lora",
    "lora_target": "q_proj,v_proj",
    "template": "llama2",
    "lora_rank": 64,
    "lora_alpha": 32,
    "output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
    "overwrite_cache": True,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 16,
    "lr_scheduler_type": "cosine_with_restarts",
    "logging_steps": 50,
    "save_steps": 2000,
    "learning_rate": 2e-4,
    "num_train_epochs": 8,
    "plot_loss": True,
    "bf16": True,
}

# Configure prediction parameters
predict_args = {
    "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
    "template": "llama2",
    "finetuning_type": "lora",
    "checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
    "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
    "predict_out_dir": "dbgpt_hub/output/",
    "predicted_out_filename": "pred_sql.sql",
}

# Configure evaluation parameters
evaluate_args = {
    "input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
    "gold": "./dbgpt_hub/data/eval_data/gold.txt",
    "gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
    "db": "./dbgpt_hub/data/spider/database",
    "table": "./dbgpt_hub/data/eval_data/tables.json",
    "table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
    "etype": "exec",
    "plug_value": True,
    "keep_distinct": False,
    "progress_bar_for_each_datapoint": False,
    "natsql": False,
}

# Run the entire fine-tuning workflow
preprocess_sft_data(
    data_folder=data_folder,
    data_info=data_info
)

start_sft(train_args)
start_predict(predict_args)
start_evaluate(evaluate_args)
```
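With `etype` set to `exec`, evaluation scores execution accuracy: a prediction counts as correct when it returns the same result set as the gold SQL, even if the query text differs. A self-contained sketch of that idea using SQLite (illustrative only, not dbgpt_hub's actual evaluator, which runs against the Spider databases configured above):

```python
import sqlite3

def exec_match(conn, pred_sql, gold_sql):
    """Order-insensitive comparison of the two queries' result sets."""
    pred = sorted(conn.execute(pred_sql).fetchall())
    gold = sorted(conn.execute(gold_sql).fetchall())
    return pred == gold

# Build a tiny in-memory database to evaluate against
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE singer (name TEXT, age INTEGER)")
conn.executemany("INSERT INTO singer VALUES (?, ?)", [("Alice", 30), ("Bob", 25)])

# Different SQL text, same result set: counts as correct
print(exec_match(conn, "SELECT count(*) FROM singer",
                 "SELECT count(name) FROM singer"))  # True
# Different result sets: counts as wrong
print(exec_match(conn, "SELECT name FROM singer WHERE age > 26",
                 "SELECT name FROM singer"))  # False
```

This is why execution accuracy is more forgiving than exact string match, but it can also over-credit queries that coincidentally return the same rows on a particular database instance.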

### 3.3 Model Fine-tuning
