[feature] add the dora adapter #235

Merged · 1 commit · Jul 6, 2024
1 change: 1 addition & 0 deletions README.md
@@ -240,6 +240,7 @@ We fine-tuned multiple LoRA adapters using four A6000 graphics cards with fp32 precision
| ✓ | [QLoRA](https://arxiv.org/abs/2305.14314),NIPS,2023 |
| ✓ | [LoRA+](https://arxiv.org/abs/2402.12354),ICML,2024 |
| ✓ | [VeRA](https://arxiv.org/abs/2310.11454),ICLR,2024 |
| ✓ | [DoRA](https://arxiv.org/abs/2402.09353),ICML,2024 |

### Supported preference alignment algorithms
| | Variant |
61 changes: 61 additions & 0 deletions demo/dora/dora_case_1.yaml
@@ -0,0 +1,61 @@
dispatcher:
  name: "default"
  concurrency_num: 2
datasets:
  - name: "data"
    data: "demo/data.json"
    prompt: "demo/prompt.yaml"
    prompt_type: "instruction"
    preprocess: "shuffle"
adapters:
  - name: "lora_0"
    type: "lora"
    path: "adapters/lora_sft_lora_with_dora_0"
    optimizer: "adamw"
    lr: 3e-4
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
  - name: "dora_0"
    type: "dora"
    path: "adapters/lora_sft_lora_with_dora_1"
    optimizer: "adamw"
    lr: 1e-3
    r: 32
    alpha: 64
    dropout: 0.05
    target_modules:
      q_proj: true
      k_proj: true
      v_proj: true
      o_proj: true
      gate_proj: false
      down_proj: false
      up_proj: false
tasks:
  - type: "train"
    name: "task_0"
    adapter: "lora_0"
    dataset: "data"
    batch_size: 64
    mini_batch_size: 64
    num_epochs: 10
    cutoff_len: 256
    save_step: 2000
  - type: "train"
    name: "task_1"
    adapter: "dora_0"
    dataset: "data"
    batch_size: 64
    mini_batch_size: 64
    num_epochs: 10
    cutoff_len: 256
    save_step: 2000
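The demo config trains a plain LoRA adapter (`lora_0`) and a DoRA adapter (`dora_0`) side by side on the same dataset, with identical rank, alpha, dropout, and target modules; only the adapter `type` and learning rate differ. A quick standalone sanity check of such a config (this snippet is not part of the PR and assumes PyYAML is available):

```python
# Hypothetical sanity check for the demo config; not part of mLoRA itself.
import yaml

with open("demo/dora/dora_case_1.yaml") as f:
    cfg = yaml.safe_load(f)

for adapter in cfg["adapters"]:
    print(adapter["name"], adapter["type"], "r =", adapter["r"], "alpha =", adapter["alpha"])
# Expected output:
#   lora_0 lora r = 32 alpha = 64
#   dora_0 dora r = 32 alpha = 64
```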
2 changes: 1 addition & 1 deletion demo/vera/vera_case_1.yaml
@@ -28,7 +28,7 @@ adapters:
    type: "vera"
    path: "adapters/lora_sft_lora_with_vera_1"
    optimizer: "adamw"
-    lr: 3e-5
+    lr: 3e-3
    r: 1024
    alpha: 64
    dropout: 0.05
2 changes: 2 additions & 0 deletions mlora/config/__init__.py
@@ -1,6 +1,7 @@
from .adapter import (
    ADAPTERCONFIG_CLASS,
    AdapterConfig,
    DoRAConfig,
    LoRAConfig,
    LoRAPlusConfig,
    VeRAConfig,
@@ -30,6 +31,7 @@
    "LoRAConfig",
    "LoRAPlusConfig",
    "VeRAConfig",
    "DoRAConfig",
    "ADAPTERCONFIG_CLASS",
    "OptimizerConfig",
    "LRSchedulerConfig",
9 changes: 9 additions & 0 deletions mlora/config/adapter.py
@@ -130,8 +130,17 @@ def export(self) -> Dict[str, Any]:
        }


class DoRAConfig(LoRAConfig):
    __params_map: Dict[str, str] = {}

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.init(self.__params_map, config)


ADAPTERCONFIG_CLASS = {
    "lora": LoRAConfig,
    "loraplus": LoRAPlusConfig,
    "vera": VeRAConfig,
    "dora": DoRAConfig,
}
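Because `DoRAConfig` subclasses `LoRAConfig` and declares an empty `__params_map`, a `type: "dora"` adapter block accepts exactly the same fields as a LoRA block (`r`, `alpha`, `dropout`, `target_modules`, and so on); the registry above only routes the YAML `type` string to the class. A quick in-repo check (not part of the PR):

```python
from mlora.config import ADAPTERCONFIG_CLASS, DoRAConfig, LoRAConfig

# "dora" resolves to DoRAConfig, which reuses every LoRAConfig parameter.
assert ADAPTERCONFIG_CLASS["dora"] is DoRAConfig
assert issubclass(DoRAConfig, LoRAConfig)
```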
3 changes: 3 additions & 0 deletions mlora/executor/context/__init__.py
@@ -1,6 +1,7 @@
from typing import Dict, Type

from .context import TaskContext
from .dora import InferenceDoRAContext, TrainDoRAContext
from .inference import InferenceTaskContext
from .lora import InferenceLoRAContext, TrainLoRAContext
from .loraplus import TrainLoRAPlusContext
@@ -11,12 +12,14 @@
    "lora": TrainLoRAContext,
    "loraplus": TrainLoRAPlusContext,
    "vera": TrainVeRAContext,
    "dora": TrainDoRAContext,
}

INFERENCECONTEXT_CLASS: Dict[str, Type[InferenceTaskContext]] = {
    "lora": InferenceLoRAContext,
    "loraplus": InferenceLoRAContext,
    "vera": InferenceVeRAContext,
    "dora": InferenceDoRAContext,
}
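The same lookup pattern selects the runtime context: a `dora` adapter is trained with `TrainDoRAContext` and served with `InferenceDoRAContext`, both of which only override `load_weight` to build `DoRA` modules that carry a reference to the frozen base weight. A small in-repo check (not part of the PR; the training map's variable name is not visible in this hunk, so only the inference map is exercised):

```python
from mlora.executor.context import INFERENCECONTEXT_CLASS, InferenceDoRAContext

assert INFERENCECONTEXT_CLASS["dora"] is InferenceDoRAContext
```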


63 changes: 63 additions & 0 deletions mlora/executor/context/dora.py
@@ -0,0 +1,63 @@
from collections import OrderedDict
from typing import override

from mlora.config import DoRAConfig
from mlora.model.args import LinearInfo
from mlora.model.modules import DoRA

from .context import TaskContext
from .lora import InferenceLoRAContext, TrainLoRAContext


def _init_dora_weight(
    context: TaskContext,
    config: DoRAConfig,
    linears_info: OrderedDict[str, LinearInfo],
):
    # init the weight
    for linear_name, linear_info in linears_info.items():
        target_name = linear_name.split(".")[3]
        if target_name not in config.target_:
            continue
        if config.target_[target_name] is not True:
            continue

        context.adapter_model_[linear_name] = DoRA(
            config.name_,
            linear_info.in_dim_,
            linear_info.out_dim_,
            config.r_,
            config.alpha_,
            config.dropout_,
            linear_info.base_weight_,
        )
    for _, module in context.adapter_model_.items():
        module.init_weight(None, None)


class InferenceDoRAContext(InferenceLoRAContext):
    config_: DoRAConfig

    def __init__(
        self, config: DoRAConfig, linears_info: OrderedDict[str, LinearInfo]
    ) -> None:
        super().__init__(config, linears_info)

    @override
    def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
        _init_dora_weight(self, self.config_, linears_info)


class TrainDoRAContext(TrainLoRAContext):
    config_: DoRAConfig

    def __init__(
        self,
        config: DoRAConfig,
        linears_info: OrderedDict[str, LinearInfo],
    ) -> None:
        super().__init__(config, linears_info)

    @override
    def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
        _init_dora_weight(self, self.config_, linears_info)
1 change: 1 addition & 0 deletions mlora/model/args.py
@@ -67,6 +67,7 @@ class LinearInfo:
    name_: str
    in_dim_: int
    out_dim_: int
    base_weight_: torch.nn.Linear


@dataclass
2 changes: 2 additions & 0 deletions mlora/model/modules/__init__.py
@@ -1,6 +1,7 @@
from .adapter import Adapter, AdapterModel
from .attention import Attention
from .decoder import Decoder
from .dora import DoRA
from .embedding import Embedding
from .linear import Linear
from .lora import LoRA, LoRAFunction
@@ -19,6 +20,7 @@
    "LoRA",
    "VeRA",
    "vera_shared_weight",
    "DoRA",
    "LoRAFunction",
    "Attention",
    "MLP",
1 change: 1 addition & 0 deletions mlora/model/modules/attention.py
@@ -179,6 +179,7 @@ def linears_info(self) -> OrderedDict[str, LinearInfo]:
                name_=name,
                in_dim_=module.weight_.in_features,
                out_dim_=module.weight_.out_features,
                base_weight_=module.weight_,
            )

        return ret_val
69 changes: 69 additions & 0 deletions mlora/model/modules/dora.py
@@ -0,0 +1,69 @@
import math
from typing import override

import torch

from .lora import LoRA


class DoRA(LoRA):
    magnitude_: torch.Tensor

    base_weight_: torch.nn.Linear

    def __init__(
        self,
        adapter_name: str,
        in_dim: int,
        out_dim: int,
        r: int,
        alpha: int,
        dropout: float,
        base_weight: torch.nn.Linear,
    ):
        super().__init__(adapter_name, in_dim, out_dim, r, alpha, dropout)
        self.adapter_type_ = "dora"

        # only a reference to the base weight; do not modify it!
        self.base_weight_: torch.nn.Linear = base_weight

        self.magnitude_: torch.Tensor = torch.zeros(
            size=(1, out_dim), device="cpu", requires_grad=False, dtype=torch.float32
        )

    @override
    def init_weight(
        self, lora_a: torch.Tensor | None = None, lora_b: torch.Tensor | None = None
    ):
        with torch.no_grad():
            if lora_a is None:
                torch.nn.init.kaiming_normal_(self.lora_a_, a=math.sqrt(5))
            else:
                self.lora_a_.copy_(lora_a)

            if lora_b is not None:
                self.lora_b_.copy_(lora_b)

            self.magnitude_.copy_(self.get_weight_norm())

    def get_weight_norm(self) -> torch.Tensor:
        with torch.no_grad():
            # the merged weight has shape (out_dim, in_dim)
            lora_weight = self.scaling_ * (self.lora_b_ @ self.lora_a_)
            weight = (
                lora_weight.to(self.base_weight_.weight.device)
                + self.base_weight_.weight
            )
            weight = weight.to(self.lora_a_.device)
            weight_norm: torch.Tensor = torch.linalg.norm(weight, dim=1).to(
                weight.dtype
            )

            assert weight_norm.requires_grad is False
            assert weight_norm.grad_fn is None

        return weight_norm

    @override
    def get_all_tensors(self) -> list[torch.Tensor]:
        return [self.lora_a_, self.lora_b_, self.magnitude_]
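For reference, `DoRA` implements the weight decomposition from the DoRA paper (Liu et al., 2024): the adapted weight is split into a learned per-output-feature magnitude and a direction obtained by normalizing the merged LoRA weight. With $W_0$ the frozen base weight (`base_weight_.weight`), $BA$ the low-rank update (`lora_b_ @ lora_a_`), and $\gamma = \alpha / r$ the scaling (`scaling_`):

$$W' = m \cdot \frac{W_0 + \gamma BA}{\lVert W_0 + \gamma BA \rVert_{\mathrm{row}}}, \qquad m \leftarrow \lVert W_0 + \gamma BA \rVert_{\mathrm{row}} \ \text{at initialization},$$

where $\lVert \cdot \rVert_{\mathrm{row}}$ is the L2 norm over the input dimension for each output feature, exactly what `get_weight_norm` computes via `torch.linalg.norm(weight, dim=1)`. Since `init_weight` copies that norm into `magnitude_` while `lora_b_` keeps its default zero initialization, the adapter leaves the base layer's output unchanged at the start of training.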
58 changes: 55 additions & 3 deletions mlora/model/modules/linear.py
@@ -1,4 +1,4 @@
-from typing import List, MutableMapping, Optional, Tuple
+from typing import Callable, List, MutableMapping, Optional, Tuple

import bitsandbytes
import torch
@@ -8,6 +8,7 @@
from mlora.profiler import nvtx_range, set_backward_tracepoint

from .adapter import Adapter
from .dora import DoRA
from .lora import LoRA, LoRAFunction, get_range_tensor
from .vera import VeRA

@@ -39,8 +40,15 @@ def forward(self, data: torch.Tensor, input_args: ModelData) -> torch.Tensor:
        result = self.weight_.forward(data)
        set_backward_tracepoint(result.grad_fn, "b_linear")

-        result = self.__lora_forward(data, input_args, result)
-        result = self.__vera_forward(data, input_args, result)
+        adapter_func_list: List[Callable] = [
+            self.__lora_forward,
+            self.__vera_forward,
+            self.__dora_forward,
+        ]
+
+        for func in adapter_func_list:
+            result = func(data, input_args, result)

        return result

    def __lora_forward(
@@ -117,6 +125,50 @@ def __vera_forward(

        return result

    def __dora_forward(
        self, data: torch.Tensor, input_args: ModelData, result: torch.Tensor
    ) -> torch.Tensor:
        lora_range = get_range_tensor(data.device, data.shape[0])

        for lora_config in input_args.data_config_:
            adapter_name = lora_config.adapter_name_

            if adapter_name not in self.adapters_ or not isinstance(
                self.adapters_[adapter_name], DoRA
            ):
                continue

            adapter = self.adapters_[adapter_name]

            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_

            with nvtx_range("f_dora"):
                weight_norm = adapter.get_weight_norm()
                mag_norm_scale = (adapter.magnitude_ / weight_norm).view(1, -1)

                dora_data = F.dropout(
                    data[start_idx:end_idx],
                    p=adapter.dropout_,
                    training=True,
                    inplace=False,
                )
                lora_result = dora_data @ adapter.lora_a_.transpose(0, 1)
                lora_result = lora_result @ adapter.lora_b_.transpose(0, 1)
                lora_result = mag_norm_scale * lora_result * adapter.scaling_

                base_result = (
                    result[start_idx:end_idx] * (mag_norm_scale - 1) + lora_result
                )

                result = result.index_copy(
                    dim=0, index=lora_range[start_idx:end_idx], source=base_result
                )

            set_backward_tracepoint(result.grad_fn, "b_dora")

        return result

    def load_adapter(self, adapter: Adapter):
        assert adapter.adapter_name_ not in self.adapters_
        self.adapters_[adapter.adapter_name_] = adapter
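The factor `mag_norm_scale = magnitude_ / weight_norm` used in `__dora_forward` comes from the same decomposition: in activation space, DoRA rescales the combined frozen-base and low-rank projection per output feature by the ratio of the learned magnitude to the current merged-weight norm. A self-contained sketch of that reference computation (tensor names are illustrative rather than mLoRA's API; dropout and mLoRA's per-adapter batch slicing are omitted):

```python
import torch

torch.manual_seed(0)
batch, in_dim, out_dim, r, alpha = 4, 16, 8, 4, 8
scaling = alpha / r

# Frozen base projection and a rank-r LoRA pair (B starts at zero, as usual).
base = torch.nn.Linear(in_dim, out_dim, bias=False)
lora_a = torch.randn(r, in_dim) * 0.02   # (r, in_dim)
lora_b = torch.zeros(out_dim, r)         # (out_dim, r)

# Magnitude is captured at init, when B is still zero, so it equals the row norm of W0.
magnitude = torch.linalg.norm(base.weight + scaling * (lora_b @ lora_a), dim=1).view(1, -1)

# Pretend B has been updated by training, so the merged-weight norm drifts from init.
lora_b = torch.randn(out_dim, r) * 0.02

x = torch.randn(batch, in_dim)
base_out = base(x)                              # frozen projection
lora_out = (x @ lora_a.T) @ lora_b.T * scaling  # low-rank update

weight_norm = torch.linalg.norm(base.weight + scaling * (lora_b @ lora_a), dim=1).view(1, -1)
mag_norm_scale = magnitude / weight_norm        # shape (1, out_dim)
dora_out = mag_norm_scale * (base_out + lora_out)

# Equivalent to projecting with the merged, renormalized, magnitude-scaled weight.
merged = base.weight + scaling * (lora_b @ lora_a)
reference = x @ ((magnitude.T / weight_norm.T) * merged).T
assert torch.allclose(dora_out, reference, atol=1e-5)
```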