diff --git a/stanza/models/coref/config.py b/stanza/models/coref/config.py
index 92b7eb0526..1b23d3debb 100644
--- a/stanza/models/coref/config.py
+++ b/stanza/models/coref/config.py
@@ -4,7 +4,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, List
 
 
 @dataclass
@@ -40,6 +40,9 @@ class Config:  # pylint: disable=too-many-instance-attributes, too-few-public-me
     lora_rank: int
     lora_dropout: float
 
+    lora_targets: List[str]
+    lora_fully_tune: List[str]
+
     bert_finetune: bool
     dropout_rate: float
     learning_rate: float
diff --git a/stanza/models/coref/coref_config.toml b/stanza/models/coref/coref_config.toml
index 57d7757939..f5100592b1 100755
--- a/stanza/models/coref/coref_config.toml
+++ b/stanza/models/coref/coref_config.toml
@@ -71,6 +71,9 @@ lora = false
 lora_alpha = 128
 lora_dropout = 0.1
 lora_rank = 64
+lora_targets = []
+lora_fully_tune = []
+
 
 # Training settings
 # ==================
@@ -118,6 +121,8 @@ bert_model = "roberta-large"
 bert_model = "roberta-large"
 bert_learning_rate = 0.00005
 lora = true
+lora_targets = [ "query", "value", "output.dense", "intermediate.dense" ]
+lora_fully_tune = [ "pooler" ]
 
 [roberta_no_finetune]
 bert_model = "roberta-large"
diff --git a/stanza/models/coref/model.py b/stanza/models/coref/model.py
index 2d89f7cf76..3d18328d69 100644
--- a/stanza/models/coref/model.py
+++ b/stanza/models/coref/model.py
@@ -80,11 +80,10 @@ def __init__(self,
             logger.debug("Creating lora adapter with rank %d", self.config.lora_rank)
             self.__peft_config = LoraConfig(inference_mode=False,
                                             r=self.config.lora_rank,
-                                            target_modules=["query", "value",
-                                                            "output.dense", "intermediate.dense"],
+                                            target_modules=self.config.lora_targets,
                                             lora_alpha=self.config.lora_alpha,
                                             lora_dropout=self.config.lora_dropout,
-                                            modules_to_save=["pooler"],
+                                            modules_to_save=self.config.lora_fully_tune,
                                             bias="none")
             self.bert = get_peft_model(self.bert, self.__peft_config)
 
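
For context, a minimal standalone sketch of what the two new settings control, assuming `peft` and `transformers` are installed. The values mirror the `[roberta]` section above; the variable names are illustrative, not Stanza's own:

```python
# Standalone sketch (not Stanza's code): how lora_targets / lora_fully_tune
# feed into peft's LoraConfig. Values copied from the [roberta] section above.
from peft import LoraConfig, get_peft_model
from transformers import AutoModel

bert = AutoModel.from_pretrained("roberta-large")

# The two lists this PR makes configurable instead of hard-coded:
lora_targets = ["query", "value", "output.dense", "intermediate.dense"]  # low-rank adapted
lora_fully_tune = ["pooler"]  # trained and saved in full, no low-rank factorization

peft_config = LoraConfig(inference_mode=False,
                         r=64,            # lora_rank
                         target_modules=lora_targets,
                         lora_alpha=128,
                         lora_dropout=0.1,
                         modules_to_save=lora_fully_tune,
                         bias="none")
bert = get_peft_model(bert, peft_config)
bert.print_trainable_parameters()  # sanity check: only adapters + pooler should train
```

Moving these lists into the config means a backbone whose attention projections are named differently than RoBERTa's can be LoRA-adapted by editing `coref_config.toml` alone, with no code changes in `model.py`.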