Add dropout options to mitigate overfitting
leng-yue committed Jan 5, 2024
1 parent 38c599b commit 39f6902
Showing 3 changed files with 12 additions and 2 deletions.
1 change: 1 addition & 0 deletions fish_speech/configs/text2semantic_finetune.yaml
@@ -59,6 +59,7 @@ model:
norm_eps: 1e-5
num_codebooks: 4 # single codebook
codebook_size: 168 # codebook size 160 + 2 special tokens
+ dropout: 0.1 # For small datasets, dropout helps to prevent overfitting

optimizer:
_target_: torch.optim.AdamW
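
The new dropout key maps onto the dropout field added to ModelArgs in llama.py further down. As a rough sketch of how the YAML value ends up on the model arguments, assuming ModelArgs is a plain dataclass (the real config is instantiated through the _target_ machinery, and the explicit ModelArgs(dropout=0.1) call here is only illustrative):

from dataclasses import dataclass

# Only the fields relevant to this commit are shown; the other ModelArgs fields are omitted.
@dataclass
class ModelArgs:
    norm_eps: float = 1e-5
    num_codebooks: int = 4      # single codebook
    codebook_size: int = 168    # codebook size 160 + 2 special tokens
    dropout: float = 0.0        # new field; the default 0.0 keeps the previous no-dropout behavior

# Roughly what the YAML override amounts to at model construction time:
args = ModelArgs(dropout=0.1)
assert args.dropout == 0.1
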
1 change: 1 addition & 0 deletions fish_speech/configs/text2semantic_finetune_lora.yaml
@@ -59,6 +59,7 @@ model:
norm_eps: 1e-5
num_codebooks: 4 # single codebook
codebook_size: 168 # codebook size 160 + 2 special tokens
+ dropout: 0.1 # For small datasets, dropout helps to prevent overfitting

lora_config:
_target_: fish_speech.models.text2semantic.lit_module.LoraConfig
12 changes: 10 additions & 2 deletions fish_speech/models/text2semantic/llama.py
@@ -31,6 +31,7 @@ class ModelArgs:
rope_base: float = 10000
norm_eps: float = 1e-5
max_seq_len: int = 2048
+ dropout: float = 0.0

# Additional decoding heads
codebook_size: int = 160
@@ -260,6 +261,7 @@ def __init__(self, config: ModelArgs):
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None

+ self.dropout = config.dropout
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
@@ -301,7 +303,13 @@ def forward(

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
- y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+ y = F.scaled_dot_product_attention(
+     q,
+     k,
+     v,
+     attn_mask=mask,
+     dropout_p=self.dropout if self.training else 0.0,
+ )

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
else:
@@ -311,7 +319,7 @@ def forward(

# We don't need to transpose q, k, v here because flash_attn_varlen_func
attn_output = self._flash_attention_forward(
- q, k, v, mask, seqlen, dropout=0.0
+ q, k, v, mask, seqlen, dropout=self.dropout if self.training else 0.0
)

y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
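
For context, F.scaled_dot_product_attention applies dropout to the attention weights only when dropout_p is nonzero, so gating it on self.training (toggled by model.train() / model.eval()) disables it at inference while leaving training stochastic. A small stand-alone sketch of that behavior; the tensor shapes are arbitrary and not taken from the model:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)

# Training path: a nonzero dropout_p randomly zeroes attention weights, so repeated calls differ.
y_train = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)

# Inference path: dropout_p=0.0 is deterministic and matches the pre-change behavior.
y_eval = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
assert torch.allclose(y_eval, F.scaled_dot_product_attention(q, k, v, dropout_p=0.0))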
