PPO task for mlora_train
ck-gyj committed Dec 7, 2024
1 parent b3dc8af commit 0bc7106
Showing 13 changed files with 716 additions and 37 deletions.
Empty file added: config.sliding_window
26 changes: 26 additions & 0 deletions demo/data1.json
@@ -0,0 +1,26 @@
[
    {
        "instruction": "Could you a b c d",
        "instruction_paraphrased": "Tell",
        "chosen": "mLoRA, short for Multi-LoRA Fine-Tune",
        "reject": "mLoRA is an open-source"
    },
    {
        "instruction": "Could you a b c d",
        "instruction_paraphrased": "Tell",
        "chosen": "mLoRA, short for Multi-LoRA Fine-Tune",
        "reject": "mLoRA is an open-source"
    },
    {
        "instruction": "Could you a b c d",
        "instruction_paraphrased": "Tell",
        "chosen": "mLoRA, short for Multi-LoRA Fine-Tune",
        "reject": "mLoRA is an open-source"
    },
    {
        "instruction": "Could you a b c d",
        "instruction_paraphrased": "Tell",
        "chosen": "mLoRA, short for Multi-LoRA Fine-Tune",
        "reject": "mLoRA is an open-source"
    }
]
2 changes: 2 additions & 0 deletions demo/generate_prompt.yaml
@@ -0,0 +1,2 @@
template: |
  {{ data_point['instruction']+'_'+data_point['chosen']+'_'+data_point['reject'] }}
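As a quick sanity check, here is a minimal sketch of rendering this template against the first entry of demo/data1.json. The `{{ ... }}` syntax suggests a Jinja-style engine; using jinja2 directly is an assumption, not necessarily what mLoRA does internally.

```python
import json

from jinja2 import Template  # assumption: a Jinja-compatible renderer

template = Template(
    "{{ data_point['instruction']+'_'+data_point['chosen']+'_'+data_point['reject'] }}"
)

with open("demo/data1.json") as f:
    data = json.load(f)

# -> "Could you a b c d_mLoRA, short for Multi-LoRA Fine-Tune_mLoRA is an open-source"
print(template.render(data_point=data[0]))
```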
64 changes: 64 additions & 0 deletions demo/ppo/ppo_case1.yaml
@@ -0,0 +1,64 @@
dispatcher:
  name: "default"
  concurrency_num: 2
datasets:
  - name: "ppo_data"
    data: "demo/data1.json"
    prompt: "demo/generate_prompt.yaml"
    prompt_type: "instruction"
    preprocess: "default"
adapters:
  - name: "lora_ppo_critic"
    type: "lora"
    path: "adapters/lora_ppo_critic"
    optimizer: "adamw"
    lr: 3e-4
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
  - name: "lora_ppo_actor"
    type: "lora"
    path: "adapters/lora_ppo_actor"
    optimizer: "adamw"
    lr: 3e-4
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
tasks:
  - type: "ppo"
    name: "task_0"
    adapter: ["lora_ppo_critic", "lora_ppo_actor"]
    dataset: "ppo_data"
    batch_size: 4
    mini_batch_size: 4
    num_epochs: 100
    cutoff_len: 256
    save_step: 2000
    gamma: 0.99
    lamdb: 0.99
    entropy_coef: 0
    entropy_coef_decay: 0.99
    clip_rate: 0.2
    K_epochs: 2
    T_horizon: 2048
    optim_num: 2
    loss_type1: "mse"
    loss_type2: "adv_loss"
    adv_normalization: True
    generate_num: 10
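For readers unfamiliar with the PPO knobs in the `tasks` block: `gamma` is the discount factor, `lamdb` the GAE smoothing factor, and `adv_normalization` standardizes the advantages. A standalone sketch of Generalized Advantage Estimation under those conventions (illustrative only, not mLoRA code):

```python
import torch


def gae_advantages(
    rewards: torch.Tensor,
    values: torch.Tensor,
    gamma: float = 0.99,
    lam: float = 0.99,
    normalize: bool = True,
) -> torch.Tensor:
    # `values` carries one extra bootstrap entry beyond the last reward.
    adv = torch.zeros_like(rewards)
    running = torch.tensor(0.0)
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        adv[t] = running
    if normalize:
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    return adv


# Example: three reward steps plus a bootstrap value of 0.
print(gae_advantages(torch.tensor([1.0, 0.0, 1.0]),
                     torch.tensor([0.5, 0.4, 0.6, 0.0])))
```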
2 changes: 2 additions & 0 deletions mlora/config/__init__.py
@@ -17,6 +17,7 @@
    DPOTaskConfig,
    TaskConfig,
    TrainTaskConfig,
    PPOTaskConfig,
)

__all__ = [
@@ -37,4 +38,5 @@
"ADAPTERCONFIG_CLASS",
"OptimizerConfig",
"LRSchedulerConfig",
"PPOTaskConfig",
]
51 changes: 50 additions & 1 deletion mlora/config/task.py
@@ -27,7 +27,8 @@ def __init__(
        super().__init__(config)
        self.init(self.__params_map, config)

-        self.adapter_ = adapters[config["adapter"]]
+        # PPO tasks pass a list of adapter names: [critic, actor].
+        if isinstance(config["adapter"], str):
+            self.adapter_ = adapters[config["adapter"]]
+        else:
+            self.adapter_ = adapters[config["adapter"][0]]
        self.dataset_: DatasetConfig | None = datasets[config["dataset"]]


@@ -147,10 +148,58 @@ def __init__(
        self.lambda_ = float(self.lambda_)
        self.temperature_ = float(self.temperature_)
class PPOTaskConfig(TrainTaskConfig):
    gamma_: float
    lamdb_: float
    entropy_coef_: float
    entropy_coef_decay_: float
    K_epochs_: int
    T_horizon_: int
    optim_num_: int
    loss_type1_: str
    loss_type2_: str
    clip_rate_: float
    adapter__: AdapterConfig
    generate_num_: int

    __params_map: Dict[str, str] = {
        "gamma_": "gamma",
        "lamdb_": "lamdb",
        "entropy_coef_": "entropy_coef",
        "entropy_coef_decay_": "entropy_coef_decay",
        "K_epochs_": "K_epochs",
        "T_horizon_": "T_horizon",
        "optim_num_": "optim_num",
        "loss_type1_": "loss_type1",
        "loss_type2_": "loss_type2",
        "clip_rate_": "clip_rate",
        "generate_num_": "generate_num",
    }

    def __init__(
        self,
        config: Dict[str, str],
        adapters: Mapping[str, AdapterConfig],
        datasets: Mapping[str, DatasetConfig],
    ):
        super().__init__(config, adapters, datasets)
        self.init(self.__params_map, config)

        self.gamma_ = float(self.gamma_)
        self.lamdb_ = float(self.lamdb_)
        self.entropy_coef_ = float(self.entropy_coef_)
        self.entropy_coef_decay_ = float(self.entropy_coef_decay_)
        self.clip_rate_ = float(self.clip_rate_)
        self.K_epochs_ = int(self.K_epochs_)
        self.T_horizon_ = int(self.T_horizon_)
        self.optim_num_ = int(self.optim_num_)
        # The second adapter in the list is the actor; the first (the
        # critic) was already bound to self.adapter_ by the base class.
        self.adapter__ = adapters[config["adapter"][1]]
        self.generate_num_ = int(self.generate_num_)


TASKCONFIG_CLASS: Dict[str, Type[TaskConfig]] = {
    "train": TrainTaskConfig,
    "dpo": DPOTaskConfig,
    "cpo": CPOTaskConfig,
    "cit": CITTaskConfig,
    "ppo": PPOTaskConfig,
}
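A hedged sketch of how this registry is presumably consumed when the `tasks:` section of demo/ppo/ppo_case1.yaml is parsed. The `adapters` and `datasets` mappings are assumed to have been built earlier from the same file; the dict below mirrors the YAML task entry.

```python
from mlora.config.task import TASKCONFIG_CLASS

task_dict = {
    "type": "ppo",
    "name": "task_0",
    "adapter": ["lora_ppo_critic", "lora_ppo_actor"],
    "dataset": "ppo_data",
    # ... remaining PPO hyperparameters from the YAML ...
}

config_cls = TASKCONFIG_CLASS[task_dict["type"]]  # -> PPOTaskConfig
task_config = config_cls(task_dict, adapters, datasets)  # assumed in scope
# task_config.adapter_  is the critic adapter (index 0);
# task_config.adapter__ is the actor adapter (index 1).
```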
24 changes: 18 additions & 6 deletions mlora/executor/dispatcher/dispatcher.py
@@ -114,43 +114,55 @@ def _dispatch_task_out(self):
            self.done_event_.notify(task)

    def _align_batch_tokens(
-        self, batch_tokens: List[Tokens], configs: List[MLoRADataConfig]
-    ) -> Tuple[List[Tokens], List[Masks]]:
+        self,
+        batch_tokens: List[Tokens],
+        batch_label: List[Tokens],
+        configs: List[MLoRADataConfig],
+    ) -> Tuple[List[Tokens], List[Tokens], List[Masks]]:
        max_seq_len = max(map(lambda x: len(x), batch_tokens))
+        max_seq_len = max(max_seq_len, max(map(lambda x: len(x), batch_label)))
        max_seq_len = math.ceil(max_seq_len / 8) * 8

        batch_masks: List[Masks] = []

        for data_config in configs:
            s_idx = data_config.batch_start_idx_
            e_idx = data_config.batch_end_idx_
+            label_s_idx = data_config.label_start_idx
+            label_e_idx = data_config.label_end_idx
            batch_tokens[s_idx:e_idx], masks = data_config.expand_fn_(
                batch_tokens[s_idx:e_idx], max_seq_len
            )
+            # Labels are expanded to the same aligned length as the tokens.
+            batch_label[label_s_idx:label_e_idx], _ = data_config.expand_fn_(
+                batch_label[label_s_idx:label_e_idx], max_seq_len
+            )
            batch_masks.extend(masks)

-        return batch_tokens, batch_masks
+        return batch_tokens, batch_label, batch_masks
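The alignment rule above pads both token and label sequences to one shared length, rounded up to a multiple of 8 (a tensor-core-friendly size). A self-contained sketch of that behavior; right padding and a pad id of 0 are assumptions, and the real `expand_fn_` may differ:

```python
import math
from typing import List, Tuple


def align(
    batch_tokens: List[List[int]],
    batch_label: List[List[int]],
    pad_id: int = 0,
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    max_len = max(
        max(len(t) for t in batch_tokens),
        max(len(lbl) for lbl in batch_label),
    )
    max_len = math.ceil(max_len / 8) * 8  # round up to a multiple of 8
    masks = [[1] * len(t) + [0] * (max_len - len(t)) for t in batch_tokens]
    tokens = [t + [pad_id] * (max_len - len(t)) for t in batch_tokens]
    labels = [lbl + [pad_id] * (max_len - len(lbl)) for lbl in batch_label]
    return tokens, labels, masks


tokens, labels, masks = align([[1, 2, 3]], [[4, 5]])
assert len(tokens[0]) == len(labels[0]) == len(masks[0]) == 8
```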

    def data(self) -> MLoRAData | None:
        self._dispatch_task_in()

        batch_tokens: List[Tokens] = []
        batch_masks: List[Masks] = []
        data_configs: List[MLoRADataConfig] = []
        batch_label: List[Tokens] = []

        # get all train data
        start_idx: int = 0
        label_start_idx: int = 0

        for task in self.running_:
-            data, data_config = task.data(start_idx)
+            data, label, data_config = task.data(start_idx, label_start_idx)
            data_configs.extend(data_config)
            batch_tokens.extend(data)
+            batch_label.extend(label)
            start_idx = start_idx + len(data)
+            label_start_idx = label_start_idx + len(label)

        # post process this batch data
-        batch_tokens, batch_masks = self._align_batch_tokens(batch_tokens, data_configs)
+        batch_tokens, batch_label, batch_masks = self._align_batch_tokens(
+            batch_tokens, batch_label, data_configs
+        )

        return MLoRAData(
-            batch_tokens=batch_tokens, batch_mask=batch_masks, data_config=data_configs
+            batch_tokens=batch_tokens,
+            batch_mask=batch_masks,
+            data_config=data_configs,
+            label=batch_label,
        )

    def step(self):
41 changes: 22 additions & 19 deletions mlora/executor/executor.py
@@ -1,9 +1,9 @@
import logging
+import copy
from typing import Callable, Dict, Optional

import torch

-import mlora.profiler
from mlora.config import MLoRAConfig, TaskConfig
from mlora.model.args import MLoRAData
from mlora.model.llm import LLMModel
@@ -12,15 +12,18 @@
from .dispatcher import DISPATCHER_CLASS, Dispatcher
from .task import Task

import inspect
import os

class Executor:
    model_: LLMModel
    tokenizer_: Tokenizer

    dispatcher_: Dispatcher
    batch_data_: torch.Tensor

    def __init__(
-        self, model: LLMModel, tokenizer: Tokenizer, config: MLoRAConfig
+        self, model: LLMModel, tokenizer: Tokenizer, config: MLoRAConfig,
    ) -> None:
        self.model_ = model
        self.tokenizer_ = tokenizer
@@ -97,23 +100,33 @@ def notify_terminate_task(self, task_name: str):

    def execute(self) -> None:
        mm_collect_step = 0

        while not self.dispatcher_.is_done():
            data: MLoRAData | None = self.dispatcher_.data()
            assert data is not None

            torch.cuda.reset_peak_memory_stats(device=self.model_.device_)

            batch_size = data.batch_size()
            token_len = data.token_len()
+            batch_tokens_ = data.batch_tokens_

-            output = self.model_.forward(data.model_data())
-            labels = torch.tensor(data.batch_tokens_, dtype=torch.long)
+            # output1: actor (policy) logits, output2: critic value estimates
+            output1, output2 = self.model_.forward(data.model_data())
+            output2 = output2.squeeze(dim=-1)
+            labels = torch.tensor(data.label, dtype=torch.long)

            total_loss: Optional[torch.Tensor] = None

+            def do_thing(config):
+                loss = None
+                if config.task_type_ == "ppo":
+                    r = self.model_.calculate_reward(
+                        batch_tokens_[config.batch_start_idx_ : config.batch_end_idx_],
+                        labels[config.label_start_idx : config.label_end_idx],
+                    )
+                    loss = config.loss_fn_(output1, output2, False, batch_tokens_, labels, r)
+                else:
+                    loss = config.loss_fn_(output1, labels, torch.tensor(data.batch_mask_))
+                return loss

            for config in data.data_config_:
-                loss = config.loss_fn_(output, labels, torch.tensor(data.batch_mask_))
+                loss = do_thing(config)
                if loss is None:
                    continue
                total_loss = loss if total_loss is None else total_loss + loss
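The PPO branch of `loss_fn_` receives the policy output, critic values, tokens, labels, and the computed reward; its implementation is not part of this diff. As orientation only, a standalone sketch of a conventional combined loss matching `loss_type1: "mse"` and `loss_type2: "adv_loss"` (critic MSE plus clipped surrogate, with the entropy bonus that `entropy_coef` would scale):

```python
import torch
import torch.nn.functional as F


def ppo_loss(logp_new, logp_old, values, returns, advantages,
             clip_rate=0.2, entropy=None, entropy_coef=0.0):
    # Clipped surrogate objective for the actor ("adv_loss").
    ratio = torch.exp(logp_new - logp_old)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_rate, 1.0 + clip_rate) * advantages
    actor_loss = -torch.min(surr1, surr2).mean()
    # MSE critic loss toward empirical returns ("mse").
    critic_loss = F.mse_loss(values, returns)
    loss = actor_loss + critic_loss
    if entropy is not None:
        # Exploration bonus; entropy_coef is 0 in ppo_case1.yaml.
        loss = loss - entropy_coef * entropy.mean()
    return loss
```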
@@ -124,14 +137,4 @@ def execute(self) -> None:
            self.dispatcher_.step()
            mm_collect_step += 1

-            mlora.profiler.metric_log_dict(
-                "memory",
-                {
-                    "batch_size": batch_size,
-                    "token_len": token_len,
-                    "memory": torch.cuda.max_memory_allocated(
-                        device=self.model_.device_
-                    ),
-                },
-                mm_collect_step,
-            )

3 changes: 3 additions & 0 deletions mlora/executor/task/__init__.py
@@ -4,6 +4,7 @@
from .cit_task import CITTask
from .cpo_task import CPOTask
from .dpo_task import DPOTask
from .ppo_task import PPOTask
from .task import Task
from .train_task import TrainTask

@@ -12,6 +13,7 @@
"dpo": DPOTask,
"cpo": CPOTask,
"cit": CITTask,
"ppo": PPOTask,
}


@@ -32,5 +34,6 @@ def register_task_class(type_name: str, task: Type[Task]):
"DPOTask",
"CPOTask",
"CITTask",
"PPOTask",
"register_task_class",
]
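For completeness, a short sketch of extending this registry at runtime via `register_task_class`; `MyTask` is hypothetical and would need to implement whatever abstract methods `Task` requires.

```python
from mlora.executor.task import Task, register_task_class


class MyTask(Task):  # hypothetical custom task type
    ...


register_task_class("my_task", MyTask)
# A config entry with `type: "my_task"` would now dispatch to MyTask.
```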