Commit 73c889a: review change
yezhengmao1 committed Jul 5, 2024 (1 parent: d83db16)

Showing 11 changed files with 266 additions and 74 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/code-formatter.sh
@@ -0,0 +1,6 @@
#!/bin/bash

black ./mlora
black ./mlora_cli
isort ./mlora --profile black
isort ./mlora_cli --profile black
36 changes: 36 additions & 0 deletions demo/checkpoint/checkpoint_case_1.yaml
@@ -0,0 +1,36 @@
dispatcher:
  name: "default"
  concurrency_num: 1
datasets:
  - name: "data"
    data: "demo/data.json"
    prompt: "demo/prompt.yaml"
    prompt_type: "instruction"
    preprocess: "default"
adapters:
  - name: "lora_0"
    type: "lora"
    path: "adapters/lora_sft_checkpoint"
    optimizer: "adamw"
    lr: 3e-4
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
tasks:
  - type: "train"
    name: "task_0"
    adapter: "lora_0"
    dataset: "data"
    batch_size: 16
    mini_batch_size: 16
    num_epochs: 2
    cutoff_len: 256
    save_step: 5
36 changes: 36 additions & 0 deletions demo/checkpoint/checkpoint_case_2.yaml
@@ -0,0 +1,36 @@
dispatcher:
  name: "default"
  concurrency_num: 1
datasets:
  - name: "data"
    data: "demo/data.json"
    prompt: "demo/prompt.yaml"
    prompt_type: "instruction"
    preprocess: "default"
adapters:
  - name: "lora_0"
    type: "lora"
    path: "adapters/lora_sft_checkpoint"
    optimizer: "adamw"
    lr: 3e-4
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
tasks:
  - type: "train"
    name: "task_0"
    adapter: "lora_0"
    dataset: "data"
    batch_size: 16
    mini_batch_size: 16
    num_epochs: 10
    cutoff_len: 256
    save_step: 10
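
The two demo configs above are identical except for num_epochs (2 vs. 10) and save_step (5 vs. 10), which suggests they exercise the new checkpoint path: a first run that stops early after writing periodic checkpoints, and a second run that picks up the same adapter path ("adapters/lora_sft_checkpoint") for longer training. As a quick sanity check, the relevant fields can be read back with PyYAML. This snippet is only an illustration; it assumes PyYAML is installed, that it runs from the repository root, and that save_step is the checkpoint interval in steps.

# Illustration only: inspect the demo checkpoint config with PyYAML.
import yaml

with open("demo/checkpoint/checkpoint_case_1.yaml") as f:
    cfg = yaml.safe_load(f)

adapter = cfg["adapters"][0]
task = cfg["tasks"][0]
print(f"adapter '{adapter['name']}' saved under '{adapter['path']}'")
print(f"task '{task['name']}': {task['num_epochs']} epochs, "
      f"checkpoint every {task['save_step']} steps")
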
2 changes: 1 addition & 1 deletion mlora/config/lr_scheduler.py
@@ -14,7 +14,7 @@ def __init__(self, config: Dict[str, str]) -> None:
        self.init(self.__params_map, config)

    @abstractmethod
-    def to_fn_parameters(self) -> Dict[str, str]: ...
+    def to_fn_parameters(self) -> Dict[str, Any]: ...


class CosineLRSchedulerConfig(LRSchedulerConfig):
2 changes: 1 addition & 1 deletion mlora/executor/context/inference.py
@@ -24,7 +24,7 @@ def switch_device(self, device: str) -> None:
            return

        for _, adapter in self.adapter_model_.items():
-            self.switch_list_tensor(adapter.get_tensors(), device)
+            self.switch_list_tensor(adapter.get_all_tensors(), device)

        self.device_ = device
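
This rename, together with the matching changes in train.py further down, splits the old get_tensors() accessor in two: get_all_tensors() for moving an adapter between devices and get_trainable_tensors() for building the optimizer. The sketch below is a hypothetical illustration of why that split matters; the ToyAdapter class and its attributes are invented for this example and are not mLoRA's API. Frozen tensors still have to follow the adapter to the new device, but they must not be handed to the optimizer.

# Hypothetical sketch (names are illustrative, not mLoRA's actual API).
from typing import List

import torch


class ToyAdapter:
    def __init__(self, in_dim: int, out_dim: int, r: int) -> None:
        self.lora_a_ = torch.zeros(r, in_dim, requires_grad=True)
        self.lora_b_ = torch.zeros(out_dim, r, requires_grad=True)
        # a frozen tensor that must still follow the adapter across devices
        self.scaling_ = torch.tensor(2.0)

    def get_trainable_tensors(self) -> List[torch.Tensor]:
        return [t for t in (self.lora_a_, self.lora_b_) if t.requires_grad]

    def get_all_tensors(self) -> List[torch.Tensor]:
        return [self.lora_a_, self.lora_b_, self.scaling_]


adapter = ToyAdapter(16, 16, 4)
optimizer = torch.optim.AdamW(adapter.get_trainable_tensors(), lr=3e-4)
print(len(adapter.get_all_tensors()), "tensors to move,",
      len(adapter.get_trainable_tensors()), "to optimize")
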
72 changes: 41 additions & 31 deletions mlora/executor/context/lora.py
@@ -1,5 +1,3 @@
-import logging
-import os
from collections import OrderedDict
from typing import Dict, override

@@ -14,8 +12,10 @@
from .train import TrainTaskContext


-def _load_lora_weight(
-    obj: TaskContext, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]
+def _init_lora_weight(
+    context: TaskContext,
+    config: LoRAConfig,
+    linears_info: OrderedDict[str, LinearInfo],
):
    # init the weight
    for linear_name, linear_info in linears_info.items():
@@ -25,37 +25,16 @@ def _load_lora_weight(
        if config.target_[target_name] is not True:
            continue

-        obj.adapter_model_[linear_name] = LoRA(
+        context.adapter_model_[linear_name] = LoRA(
            config.name_,
            linear_info.in_dim_,
            linear_info.out_dim_,
            config.r_,
            config.alpha_,
            config.dropout_,
        )
-    weight_dict = None
-
-    if os.path.isdir(obj.path_):
-        logging.info(f"Adapter {obj.name_}:{obj.path_} weight exist, load from file.")
-        weight_dict = torch.load(f"{obj.path_}{os.sep}adapter_model.bin")
-        prefix_name = "base_model.model.model."
-    else:
-        logging.info(
-            f"Adapter {obj.name_}:{obj.path_} weight not exist, use the default weight."
-        )
-
-    for name, module in obj.adapter_model_.items():
-        lora_a = (
-            None
-            if weight_dict is None
-            else weight_dict[prefix_name + name + ".lora_A.weight"]
-        )
-        lora_b = (
-            None
-            if weight_dict is None
-            else weight_dict[prefix_name + name + ".lora_B.weight"]
-        )
-        module.init_weight(lora_a, lora_b)
+    for _, module in context.adapter_model_.items():
+        module.init_weight(None, None)


class InferenceLoRAContext(InferenceTaskContext):
@@ -68,23 +47,26 @@ def __init__(

    @override
    def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
-        _load_lora_weight(self, self.config_, linears_info)
+        _init_lora_weight(self, self.config_, linears_info)


class TrainLoRAContext(TrainTaskContext):
    config_: LoRAConfig

    def __init__(
-        self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]
+        self,
+        config: LoRAConfig,
+        linears_info: OrderedDict[str, LinearInfo],
    ) -> None:
        super().__init__(config, linears_info)

        self.loss_fn_ = torch.nn.CrossEntropyLoss()

    @override
    def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
-        _load_lora_weight(self, self.config_, linears_info)
+        _init_lora_weight(self, self.config_, linears_info)

    @override
    def weight_dict(self) -> Dict[str, torch.Tensor]:
        # base_model.model.model.layers.{0}.self_attn.{q_proj}.{lora_A}.weight
        # base_model.model.model.layers.{0}.mlp.{gate_proj}.{lora_A}.weight
@@ -95,3 +77,31 @@ def weight_dict(self) -> Dict[str, torch.Tensor]:
            ret_val[prefix_name + ".lora_B.weight"] = adapter.lora_b_

        return ret_val
+
+    @override
+    def state_dict(self) -> Dict[str, torch.Tensor]:
+        return self.optimizer_.state_dict()
+
+    @override
+    def recover_optimizer(self, state_dict: Dict[str, torch.Tensor]):
+        assert self.optimizer_ is not None
+        self.optimizer_.load_state_dict(state_dict)
+
+    @override
+    def recover_lr(self, last_epoch: int):
+        # last_epoch is increased every time .step() is called on the scheduler,
+        # which is different from the training epoch, so be careful
+        if self.lr_scheduler_ is None:
+            return
+
+        # we recreate the lr scheduler
+        self.create_lr_scheduler(self.config_.lr_scheduler_config_, last_epoch)
+
+    @override
+    def recover_weight(self, weight_dict: Dict[str, torch.Tensor]):
+        assert weight_dict is not None
+        prefix_name = "base_model.model.model."
+        for name, module in self.adapter_model_.items():
+            lora_a = weight_dict[prefix_name + name + ".lora_A.weight"]
+            lora_b = weight_dict[prefix_name + name + ".lora_B.weight"]
+            module.init_weight(lora_a, lora_b)
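
With the file-loading branch removed from _init_lora_weight, adapters always start from default weights, and restoring a run is now the job of the new hooks: recover_weight() for the LoRA matrices, recover_optimizer() for the optimizer state, and recover_lr() for the scheduler position. The sketch below shows one way a caller could wire those hooks together with plain torch.save/torch.load; the file names and the surrounding save/resume functions are assumptions for illustration, not the repository's actual checkpoint layout.

# Hypothetical save/resume wiring around the new hooks; file layout is assumed.
import os

import torch


def save_checkpoint(ctx, step: int, path: str) -> None:
    os.makedirs(path, exist_ok=True)
    # adapter weights, in the same key format that weight_dict() produces
    torch.save(ctx.weight_dict(), os.path.join(path, "adapter_model.bin"))
    # optimizer state plus the scheduler position needed by recover_lr()
    torch.save({"optimizer": ctx.state_dict(), "step": step},
               os.path.join(path, "extra_state.bin"))


def resume_checkpoint(ctx, path: str) -> int:
    ctx.recover_weight(torch.load(os.path.join(path, "adapter_model.bin")))
    extra = torch.load(os.path.join(path, "extra_state.bin"))
    ctx.recover_optimizer(extra["optimizer"])
    ctx.recover_lr(extra["step"])  # number of scheduler .step() calls so far
    return extra["step"]
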
34 changes: 27 additions & 7 deletions mlora/executor/context/train.py
@@ -22,13 +22,14 @@ class TrainTaskContext(TaskContext):
    lr_scheduler_: torch.optim.lr_scheduler.LRScheduler | None

    def __init__(
-        self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo]
+        self,
+        config: AdapterConfig,
+        linears_info: OrderedDict[str, LinearInfo],
    ) -> None:
        super().__init__(config)

        # load the adapter's weight
        self.load_weight(linears_info)
-
        for module in self.adapter_model_.values():
            module.enable_grad()

@@ -38,6 +39,19 @@ def __init__(
    @abstractmethod
    def weight_dict(self) -> Dict[str, torch.Tensor]: ...

+    @abstractmethod
+    def state_dict(self) -> Dict[str, torch.Tensor]: ...
+
+    # recover_optimizer
+    @abstractmethod
+    def recover_optimizer(self, state_dict: Dict[str, torch.Tensor]): ...
+
+    @abstractmethod
+    def recover_lr(self, now_epoch: int): ...
+
+    @abstractmethod
+    def recover_weight(self, weight_dict: Dict[str, torch.Tensor]): ...
+
    def create_optimizer(self, optim_config: OptimizerConfig | None):
        assert optim_config is not None

@@ -46,31 +60,37 @@ def create_optimizer(self, optim_config: OptimizerConfig | None):

        parameters: List[torch.Tensor] = []
        for adapter in self.adapter_model_.values():
-            parameters.extend(adapter.get_tensors())
+            parameters.extend(adapter.get_trainable_tensors())

        self.optimizer_ = OPTIMIZER_CLASS[optimizer_type_](
            parameters, **optim_config.to_fn_parameters()
        )

-    def create_lr_scheduler(self, lr_scheduler_config: LRSchedulerConfig | None):
+    def create_lr_scheduler(
+        self, lr_scheduler_config: LRSchedulerConfig | None, last_epoch: int = -1
+    ):
        assert self.optimizer_ is not None

        if lr_scheduler_config is None:
            self.lr_scheduler_ = None
            return

        lr_scheduler_type_ = lr_scheduler_config.lr_scheduler_
        assert lr_scheduler_type_ in LR_SCHEDULER_CLASS

+        kwargs = lr_scheduler_config.to_fn_parameters()
+        kwargs["last_epoch"] = last_epoch
+
        self.lr_scheduler_ = LR_SCHEDULER_CLASS[lr_scheduler_type_](
-            self.optimizer_, **lr_scheduler_config.to_fn_parameters()  # type: ignore
+            self.optimizer_,
+            **kwargs,  # type: ignore
        )

    def switch_device(self, device: str) -> None:
        if self.device_ == device:
            return

        for _, adapter in self.adapter_model_.items():
-            self.switch_list_tensor(adapter.get_tensors(), device)
+            self.switch_list_tensor(adapter.get_all_tensors(), device)

        self.switch_optimizer(device)
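
The new last_epoch parameter on create_lr_scheduler() is what makes the scheduler resumable: in PyTorch, last_epoch counts scheduler.step() calls, and rebuilding a scheduler with last_epoch >= 0 requires an "initial_lr" entry in the optimizer's param groups, which comes back automatically when the saved optimizer state is loaded (the first scheduler run records it there). The standalone sketch below uses toy values and CosineAnnealingLR to show the mechanics; it is not mLoRA code, and the minus-one adjustment reflects the scheduler constructor's implicit initial step.

# Toy illustration of resuming a PyTorch LR scheduler via last_epoch.
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=3e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)

for _ in range(10):
    opt.step()
    sched.step()

saved_opt = opt.state_dict()           # param_groups now include "initial_lr"
saved_last_epoch = sched.last_epoch    # 10: one per .step() call
print("lr before stopping:", opt.param_groups[0]["lr"])

# --- later, when resuming ---
param2 = torch.nn.Parameter(torch.zeros(1))
opt2 = torch.optim.AdamW([param2], lr=3e-4)
opt2.load_state_dict(saved_opt)        # restores "initial_lr" as well
# the scheduler constructor performs one implicit .step(), so pass the saved
# position minus one to land on the same learning-rate value
sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt2, T_max=100, last_epoch=saved_last_epoch - 1
)
print("lr after resuming:", opt2.param_groups[0]["lr"])
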