Skip to content

Commit

Permalink
add config checking
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecovlee committed Jan 6, 2024
1 parent c67b6d1 commit de52c99
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions mlora/modelargs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from transformers.activations import ACT2FN
from typing import List, Dict, Tuple
from dataclasses import dataclass

Expand Down Expand Up @@ -88,6 +89,18 @@ class LoraConfig:
lora_dropout_: float = None
target_modules_: Dict[str, bool] = None

def check(self) -> "LoraConfig":
    """Validate the LoRA hyper-parameters stored on this config.

    Returns:
        This instance, so the call can be chained.

    Raises:
        AssertionError: If a field has the wrong type or an out-of-range
            value. The error is raised explicitly rather than through an
            ``assert`` statement so that validation still runs under
            ``python -O``, and the message names the offending field.
    """
    def require(ok: bool, message: str) -> None:
        if not ok:
            raise AssertionError(message)

    def is_int(value) -> bool:
        # bool is a subclass of int; True/False is never a valid count.
        return isinstance(value, int) and not isinstance(value, bool)

    require(is_int(self.lora_r_) and self.lora_r_ > 0,
            f"lora_r_ must be a positive int, got {self.lora_r_!r}")
    require(is_int(self.lora_alpha_) and self.lora_alpha_ > 0,
            f"lora_alpha_ must be a positive int, got {self.lora_alpha_!r}")
    require(isinstance(self.lora_dropout_, float) and self.lora_dropout_ >= 0,
            f"lora_dropout_ must be a non-negative float, "
            f"got {self.lora_dropout_!r}")
    require(isinstance(self.target_modules_, dict),
            f"target_modules_ must be a dict, got {self.target_modules_!r}")
    for key, value in self.target_modules_.items():
        require(isinstance(key, str) and len(key) > 0,
                f"target_modules_ keys must be non-empty str, got {key!r}")
        require(isinstance(value, bool),
                f"target_modules_[{key!r}] must be a bool, got {value!r}")

    return self

def from_config(self, config: Dict[str, any]) -> "LoraConfig":
self.lora_r_ = config["r"]
self.lora_alpha_ = config["lora_alpha"]
Expand Down Expand Up @@ -130,6 +143,9 @@ def export(self) -> Dict[str, any]:
return config


# Values that MixConfig.check accepts for routing_strategy_.
available_routing_strategies = ["mixtral", "switch"]


@dataclass
class MixConfig(LoraConfig):
# router config
Expand All @@ -145,6 +161,28 @@ class MixConfig(LoraConfig):
jitter_noise_: float = None
dropout_rate_: float = None

def check(self) -> "MixConfig":
    """Validate the base LoRA fields and the mixture-of-experts settings.

    The inherited ``check`` runs first. The router fields shared by every
    strategy are validated next, followed by the fields specific to the
    selected ``routing_strategy_``.

    Returns:
        This instance, so the call can be chained.

    Raises:
        AssertionError: If a field has the wrong type or an out-of-range
            value. The error is raised explicitly rather than through an
            ``assert`` statement so that validation still runs under
            ``python -O``, and the message names the offending field.
    """
    def require(ok: bool, message: str) -> None:
        if not ok:
            raise AssertionError(message)

    def is_positive_int(value) -> bool:
        # bool is a subclass of int; True/False is never a valid count.
        return (isinstance(value, int) and not isinstance(value, bool)
                and value > 0)

    def is_non_negative_float(value) -> bool:
        return isinstance(value, float) and value >= 0

    super().check()
    require(is_non_negative_float(self.router_aux_loss_coef_),
            f"router_aux_loss_coef_ must be a non-negative float, "
            f"got {self.router_aux_loss_coef_!r}")
    require(isinstance(self.routing_strategy_, str)
            and self.routing_strategy_ in available_routing_strategies,
            f"routing_strategy_ must be one of "
            f"{available_routing_strategies}, got {self.routing_strategy_!r}")
    require(is_positive_int(self.num_experts_),
            f"num_experts_ must be a positive int, got {self.num_experts_!r}")
    require(isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN,
            f"act_fn_ must name an activation in ACT2FN, "
            f"got {self.act_fn_!r}")

    if self.routing_strategy_ == "mixtral":
        require(is_positive_int(self.top_k_),
                f"top_k_ must be a positive int, got {self.top_k_!r}")
    elif self.routing_strategy_ == "switch":
        require(is_non_negative_float(self.router_z_loss_coef_),
                f"router_z_loss_coef_ must be a non-negative float, "
                f"got {self.router_z_loss_coef_!r}")
        require(is_positive_int(self.expert_capacity_),
                f"expert_capacity_ must be a positive int, "
                f"got {self.expert_capacity_!r}")
        require(is_non_negative_float(self.jitter_noise_),
                f"jitter_noise_ must be a non-negative float, "
                f"got {self.jitter_noise_!r}")
        require(is_non_negative_float(self.dropout_rate_),
                f"dropout_rate_ must be a non-negative float, "
                f"got {self.dropout_rate_!r}")

    return self

def from_config(self, config: Dict[str, any]) -> "MixConfig":
super().from_config(config)
self.router_aux_loss_coef_ = config.get(
Expand Down Expand Up @@ -184,6 +222,6 @@ def export(self) -> Dict[str, any]:

def lora_config_factory(config: Dict[str, any]) -> LoraConfig:
    """Build and validate the adapter config described by ``config``.

    A ``MixConfig`` is created when ``config["peft_type"]`` equals
    ``"MIXLORA"`` or when a ``"routing_strategy"`` key is present; any
    other input yields a plain ``LoraConfig``. The populated object is
    passed through its ``check()`` method before being returned, so an
    invalid configuration fails here rather than later.

    Args:
        config: Raw configuration mapping.

    Returns:
        The populated and validated config object.
    """
    is_mix = (config.get("peft_type") == "MIXLORA"
              or "routing_strategy" in config)
    config_cls = MixConfig if is_mix else LoraConfig
    # Always return through check(); an earlier unchecked return would
    # make the validation unreachable.
    return config_cls().from_config(config).check()

0 comments on commit de52c99

Please sign in to comment.