refactor the model init
yezhengmao1 committed Jan 26, 2024
1 parent 04bf63d commit fc42eb4
Showing 6 changed files with 207 additions and 150 deletions.
2 changes: 1 addition & 1 deletion mlora.py
@@ -161,7 +161,7 @@ def get_accumulation_steps(config: Dict[str, any]) -> Dict[str, int]:
 def train(config: Dict[str, any], llm_model: mlora.LLMModel, dispatcher: mlora.Dispatcher):
     # the train params per lora model
     all_train_paramas: Dict[str, List[torch.Tensor]
-                            ] = llm_model.get_train_paramas(config)
+                            ] = llm_model.get_train_paramas()
     all_optimizer: Dict[str, torch.optim.Optimizer] = get_optimizer(
         config, all_train_paramas)
     accumulation_step: Dict[str, int] = get_accumulation_steps(config)
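The call-site change above is the caller's half of this refactor: get_train_paramas() takes no config argument anymore, since the model already knows which adapters it holds. As a minimal sketch (not the repository's code), a caller could build one optimizer per adapter from the returned dict; the build_optimizers helper and the "lr" key are assumptions for illustration, while config["lora"] and the "name" key follow the config layout visible in the removed ChatGLM code further down.

from typing import Any, Dict, List

import torch


def build_optimizers(config: Dict[str, Any],
                     all_train_params: Dict[str, List[torch.Tensor]]
                     ) -> Dict[str, torch.optim.Optimizer]:
    # one AdamW optimizer per LoRA adapter; the trainable tensors come from
    # llm_model.get_train_paramas(), which no longer needs the config
    optimizers: Dict[str, torch.optim.Optimizer] = {}
    for lora_config in config["lora"]:
        name = lora_config["name"]
        # "lr" is an assumed per-adapter config key in this sketch
        optimizers[name] = torch.optim.AdamW(all_train_params[name],
                                             lr=lora_config.get("lr", 3e-4))
    return optimizers

Used as build_optimizers(config, llm_model.get_train_paramas()), this plays the same role as the get_optimizer(config, all_train_paramas) call in the hunk above.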
57 changes: 37 additions & 20 deletions mlora/LoraLiner.py
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 import bitsandbytes
 
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 
 
 class Lora():
@@ -35,6 +35,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 
 
 class Linear():
+    # the weight just wraps the module from LlamaForCausalLM
     def __init__(self, weight: torch.nn.Module, device: str = None):
         if device is None:
             self.device_ = weight.device
@@ -54,36 +55,52 @@ def __init__(self, weight: torch.nn.Module, device: str = None):
         self.enable_lora_: bool = False
         self.loras_: Dict[str, Lora] = {}
 
-    def init_lora_weight(self, adapter_name: str, r: int, alpha: int, dropout: float,
-                         lora_a: Optional[torch.Tensor] = None,
-                         lora_b: Optional[torch.Tensor] = None):
+    def init_lora_weight(self,
+                         adapter_name: str,
+                         r: int,
+                         alpha: int,
+                         dropout: float,
+                         lora_tensor: Tuple[Optional[torch.Tensor],
+                                            Optional[torch.Tensor]] = (None, None)):
+        # if the lora_tensor is not (None, None), use it to init the lora weight
+        assert isinstance(lora_tensor, Tuple)
+        assert len(lora_tensor) == 2
+        assert type(lora_tensor[0]) == type(lora_tensor[1])
+
         if adapter_name not in self.loras_:
             self.loras_[adapter_name] = Lora(adapter_name)
+            self.loras_[adapter_name].set_parameter(r, alpha, dropout)
 
         if isinstance(self.weight_, bitsandbytes.nn.Linear4bit):
-            out_dim = self.weight_.out_features
-            in_dim = self.weight_.in_features
+            out_dim, in_dim = self.weight_.out_features, self.weight_.in_features
         else:
             out_dim, in_dim = self.weight_.weight.shape
 
-        self.loras_[adapter_name].set_parameter(r, alpha, dropout)
-
-        if lora_a is not None:
-            self.loras_[adapter_name].lora_a_ = lora_a.to(
-                device=self.device_).to(torch.float32).requires_grad_(True)
-        else:
-            self.loras_[adapter_name].lora_a_ = torch.zeros(
+        def random_init_lora_a_tensor(lora: Lora):
+            lora.__dict__["lora_a_"] = torch.zeros(
                 size=(r, in_dim), device=self.device_, requires_grad=True, dtype=torch.float32)
-            torch.nn.init.kaiming_normal_(
-                self.loras_[adapter_name].lora_a_, a=math.sqrt(5))
+            torch.nn.init.kaiming_normal_(lora.lora_a_, a=math.sqrt(5))
 
-        if lora_b is not None:
-            self.loras_[adapter_name].lora_b_ = lora_b.to(
-                device=self.device_).to(torch.float32).requires_grad_(True)
-        else:
-            self.loras_[adapter_name].lora_b_ = torch.zeros(
+        def zero_init_lora_b_tensor(lora: Lora):
+            lora.__dict__["lora_b_"] = torch.zeros(
                 size=(out_dim, r), device=self.device_, requires_grad=True, dtype=torch.float32)
 
+        def replace_init_lora_tensor(lora: Lora, lora_a: torch.Tensor, lora_b: torch.Tensor):
+            lora.__dict__["lora_a_"] = lora_a.to(device=self.device_).to(
+                torch.float32).detach().requires_grad_(True)
+            lora.__dict__["lora_b_"] = lora_b.to(device=self.device_).to(
+                torch.float32).detach().requires_grad_(True)
+
+        # ensure it's None, so we can use the __dict__ to init it
+        assert self.loras_[adapter_name].lora_a_ is None
+        assert self.loras_[adapter_name].lora_b_ is None
+
+        if lora_tensor == (None, None):
+            random_init_lora_a_tensor(self.loras_[adapter_name])
+            zero_init_lora_b_tensor(self.loras_[adapter_name])
+        else:
+            replace_init_lora_tensor(self.loras_[adapter_name], *lora_tensor)
+
         self.enable_lora_ = True
 
     def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.Tensor:
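The refactored init_lora_weight folds the former lora_a/lora_b arguments into a single lora_tensor tuple: the default (None, None) means a fresh adapter (Kaiming-initialized A, zero-initialized B, so the initial LoRA delta is zero), while a pair of tensors restores saved weights, and the type assert rejects passing only one of the two. A rough usage sketch, where linear, the adapter names, and the checkpoint paths are illustrative:

import torch

# fresh adapter: lora_tensor defaults to (None, None),
# so lora_a_ is Kaiming-initialized and lora_b_ starts at zero
linear.init_lora_weight("task_a", r=8, alpha=16, dropout=0.05)

# restore an adapter: pass both tensors together
lora_a = torch.load("task_b_lora_a.bin")   # expected shape (r, in_dim)
lora_b = torch.load("task_b_lora_b.bin")   # expected shape (out_dim, r)
linear.init_lora_weight("task_b", r=8, alpha=16, dropout=0.05,
                        lora_tensor=(lora_a, lora_b))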
5 changes: 3 additions & 2 deletions mlora/model.py
@@ -117,11 +117,12 @@ def forward(self, input: MultiLoraBatchData):
         pass
 
     @abstractclassmethod
-    def get_train_paramas(self, config: Dict[str, str]) -> Dict[str, List[torch.Tensor]]:
+    def get_train_paramas(self) -> Dict[str, List[torch.Tensor]]:
         pass
 
     @abstractclassmethod
-    def init_lora_weight(self, adapter_name: str,
+    def init_lora_weight(self,
+                         adapter_name: str,
                          r: int,
                          lora_alpha: int,
                          lora_dropout: float,
24 changes: 2 additions & 22 deletions mlora/model_chatglm.py
@@ -268,28 +268,8 @@ def init_lora_weight(self, adapter_name: str,
             transformer_layer.init_lora_layer_weight(
                 adapter_name, r, lora_alpha, lora_dropout, target, weight)
 
-    def get_train_paramas(self, config: Dict[str, str]) -> Dict[str, List[torch.Tensor]]:
-        train_paramas = {}
-
-        for transformer_layer in self.layers_:
-            for lora_config in config["lora"]:
-                adapter_name = lora_config["name"]
-                if adapter_name not in train_paramas:
-                    train_paramas[adapter_name] = []
-
-                lora_layer_list = [transformer_layer.query_key_value_.loras_,
-                                   transformer_layer.dense_.loras_,
-                                   transformer_layer.dense_h_to_4h_.loras_,
-                                   transformer_layer.dense_4h_to_h_.loras_]
-
-                for lora_layer in lora_layer_list:
-                    if adapter_name in lora_layer:
-                        train_paramas[adapter_name].append(
-                            lora_layer[adapter_name].lora_a_)
-                        train_paramas[adapter_name].append(
-                            lora_layer[adapter_name].lora_b_)
-
-        return train_paramas
+    def get_train_paramas(self) -> Dict[str, List[torch.Tensor]]:
+        pass
 
     def get_lora_weight_dict(self, lora_name: str) -> Tuple[Dict[str, torch.Tensor], List[str]]:
         # return the lora weight and target_module's name
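The ChatGLM implementation of get_train_paramas is reduced to a stub here; the removed body needed the config only to learn the adapter names, which the wrapped layers' loras_ dicts already provide. A config-free method body might look roughly like the following sketch (not part of this commit):

    def get_train_paramas(self) -> Dict[str, List[torch.Tensor]]:
        train_paramas: Dict[str, List[torch.Tensor]] = {}

        for transformer_layer in self.layers_:
            # the same projection wrappers the removed code iterated over
            lora_layer_list = [transformer_layer.query_key_value_.loras_,
                               transformer_layer.dense_.loras_,
                               transformer_layer.dense_h_to_4h_.loras_,
                               transformer_layer.dense_4h_to_h_.loras_]
            for lora_layer in lora_layer_list:
                # adapter names now come from the layers themselves, not the config
                for adapter_name, lora in lora_layer.items():
                    train_paramas.setdefault(adapter_name, []).extend(
                        [lora.lora_a_, lora.lora_b_])

        return train_paramas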