Skip to content

Commit

Permalink
making LoRA layer config changeable
Browse files Browse the repository at this point in the history
  • Loading branch information
Jemoka committed Nov 15, 2023
1 parent 87a28d0 commit e396123
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
5 changes: 4 additions & 1 deletion stanza/models/coref/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from dataclasses import dataclass
from typing import Dict
from typing import Dict, List


@dataclass
Expand Down Expand Up @@ -40,6 +40,9 @@ class Config: # pylint: disable=too-many-instance-attributes, too-few-public-me
lora_rank: int
lora_dropout: float

lora_targets: List[str]
lora_fully_tune: List[str]

bert_finetune: bool
dropout_rate: float
learning_rate: float
Expand Down
5 changes: 5 additions & 0 deletions stanza/models/coref/coref_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ lora = false
lora_alpha = 128
lora_dropout = 0.1
lora_rank = 64
lora_targets = []
lora_fully_tune = []


# Training settings ==================

Expand Down Expand Up @@ -118,6 +121,8 @@ bert_model = "roberta-large"
bert_model = "roberta-large"
bert_learning_rate = 0.00005
lora = true
lora_targets = [ "query", "value", "output.dense", "intermediate.dense" ]
lora_fully_tune = [ "pooler" ]

[roberta_no_finetune]
bert_model = "roberta-large"
Expand Down
5 changes: 2 additions & 3 deletions stanza/models/coref/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,10 @@ def __init__(self,
logger.debug("Creating lora adapter with rank %d", self.config.lora_rank)
self.__peft_config = LoraConfig(inference_mode=False,
r=self.config.lora_rank,
target_modules=["query", "value",
"output.dense", "intermediate.dense"],
target_modules=self.config.lora_targets,
lora_alpha=self.config.lora_alpha,
lora_dropout=self.config.lora_dropout,
modules_to_save=["pooler"],
modules_to_save=self.config.lora_fully_tune,
bias="none")

self.bert = get_peft_model(self.bert, self.__peft_config)
Expand Down

0 comments on commit e396123

Please sign in to comment.