From 40665e1a39d05384401cc510dd4b490042565c00 Mon Sep 17 00:00:00 2001
From: Whale and Dolphin <70465000+Whale-Dolphin@users.noreply.github.com>
Date: Sat, 21 Dec 2024 13:10:18 +0800
Subject: [PATCH] [fix] fix problems to let version 1.5 support SFT (#774)

* [docs] Add docs of Fish Agent.
* [docs] Fix some issues
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [docs] Add Chinese docs for Fish Agent
* [docs] Fix some issues
* [docs] Fix the bug that the Chinese page displays incorrectly
* [docs] Fix bugs in the Chinese docs and add translated agent docs for other languages
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [feature] Change conversation.visualize colors and the semantic encoding method
* [feature] Change the collate_fn tokenizer to FishTokenizer
* [fix] Fix some dimension problems in semantic.py
* [feature] Change conf to tiktoken
* [fix] Fix DDP training problem
* [feature] Use Conversation to replace manual token and label generation
* [fix] Fix the embedding calculation in BaseTransformer forward
* [fix] Use einops to manipulate tensors to avoid bugs
* [fix] Fix bugs in generate and llama for SFT
* [fix] Delete unused code
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: whaledolphin
---
 .gitignore                                 |   1 +
 .../configs/text2semantic_finetune.yaml    |  13 +-
 fish_speech/conversation.py                |  33 ++-
 fish_speech/datasets/semantic.py           | 242 ++++++++----------
 fish_speech/models/text2semantic/llama.py  |  83 +++---
 tools/llama/merge_lora.py                  |  17 +-
 6 files changed, 174 insertions(+), 215 deletions(-)

diff --git a/.gitignore b/.gitignore
index e5305b2f..b9e9fef0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,4 @@ asr-label*
 /example
 /faster_whisper
 /.gradio
+*log
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
index f4c19930..05efad37 100644
--- a/fish_speech/configs/text2semantic_finetune.yaml
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -4,22 +4,25 @@ defaults:
 
 project: text2semantic_finetune_dual_ar
 max_length: 4096
-pretrained_ckpt_path: checkpoints/fish-speech-1.4
+pretrained_ckpt_path: checkpoints/fish-speech-1.5
 
 # Lightning Trainer
 trainer:
   accumulate_grad_batches: 1
   gradient_clip_val: 1.0
   gradient_clip_algorithm: "norm"
-  max_steps: 1000
+  max_steps: 10000
   precision: bf16-true
   limit_val_batches: 10
   val_check_interval: 100
+  # strategy:
+  #   find_unused_parameters: true
+  #   static_graph: true
 
 # Dataset Configuration
 tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: ${pretrained_ckpt_path}
+  _target_: fish_speech.tokenizer.FishTokenizer
+  model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
 
 # Dataset Configuration
 train_dataset:
@@ -47,7 +50,7 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 8
+  batch_size: 4
   tokenizer: ${tokenizer}
   max_length: ${max_length}
 
diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py
index 20d8ab32..e15bea9a 100644
--- a/fish_speech/conversation.py
+++ b/fish_speech/conversation.py
@@ -207,35 +207,34 @@ def visualize(
             tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
         )
 
-        # Colors for alternating tokens
         colors = {
-            "blue": "\033[94m",  # Light blue
-            "cyan": "\033[96m",  # Cyan
-            "green": "\033[92m",  # Light green
"dark_green": "\033[32m", # Dark green + "purple": "\033[95m", + "yellow": "\033[93m", + "red": "\033[91m", + "cyan": "\033[96m", } - blue_idx = 0 - green_idx = 0 + first_idx = 0 + second_idx = 0 - def print_in_blue(x): - nonlocal blue_idx - color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"] + def print_first_group(x): + nonlocal first_idx + color = colors["purple"] if first_idx % 2 == 0 else colors["yellow"] print(f"{color}{x}\033[0m", end="") - blue_idx += 1 + first_idx += 1 - def print_in_green(x): - nonlocal green_idx - color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"] + def print_second_group(x): + nonlocal second_idx + color = colors["red"] if second_idx % 2 == 0 else colors["cyan"] print(f"{color}{x}\033[0m", end="") - green_idx += 1 + second_idx += 1 for tok, lab in zip(encoded.tokens, encoded.labels): val = tokenizer.decode([tok]) if lab == -100: - print_in_green(val) + print_second_group(val) else: - print_in_blue(val) + print_first_group(val) print() diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py index 3c64e010..df3bf9ec 100644 --- a/fish_speech/datasets/semantic.py +++ b/fish_speech/datasets/semantic.py @@ -14,12 +14,18 @@ from lightning import LightningDataModule from torch.distributed import get_rank, get_world_size, is_initialized from torch.utils.data import DataLoader, IterableDataset, get_worker_info -from transformers import AutoTokenizer -from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.conversation import ( + CODEBOOK_PAD_TOKEN_ID, + Conversation, + Message, + TextPart, + VQPart, +) from fish_speech.datasets.protos.text_data_pb2 import SampledData from fish_speech.datasets.protos.text_data_stream import read_pb_stream from fish_speech.text.clean import clean_text +from fish_speech.tokenizer import FishTokenizer from fish_speech.utils import RankedLogger from fish_speech.utils.braceexpand import braceexpand @@ -73,7 +79,7 @@ def __init__( seed: int = 42, interactive_prob: float = 0.5, max_length: int = 1024, - tokenizer: AutoTokenizer = None, + tokenizer: FishTokenizer = None, use_speaker: bool | float = True, causal: bool = True, num_codebooks: Optional[int] = None, @@ -106,9 +112,12 @@ def __init__( self.num_codebooks = num_codebooks self.skip_text_prob = skip_text_prob - self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") self.groups = None + def __iter__(self): + while True: + yield self.augment() + def init_mock_data_server(self): if self.groups is not None: return @@ -148,20 +157,6 @@ def init_mock_data_server(self): Random(self.seed).shuffle(self.groups) self.group_weights = [len(i.sentences) for i in self.groups] - def __iter__(self): - while True: - yield self.augment() - - def tokenize_sentence(self, sentence: str): - sentence = clean_text(sentence) - tokens = self.tokenizer.encode( - f"{sentence}", - max_length=10**6, - add_special_tokens=False, - truncation=False, - ) - return sentence, len(tokens) - def sample_data(self): if self.groups is None: self.init_mock_data_server() @@ -190,155 +185,119 @@ def sample_data(self): samples=samples, ) - def augment(self): - final_text, final_semantic = [], [] - response = self.sample_data() - if len(response.samples) == 0: - # Invalid group - return None - - samples = list(response.samples) - idx = 0 - use_interactive = random.random() < self.interactive_prob - - if use_interactive is False: - # Random sample based on speaker using a truncated normal distribution - a = torch.tensor([0], 
-            torch.nn.init.trunc_normal_(
-                a,
-                mean=self.max_length // 2,
-                std=self.max_length // 4,
-                a=10,
-                b=self.max_length,
-            )
-            remaining_tokens = a.long().item() - 4
-        else:
-            remaining_tokens = self.max_length
-
-        # Use speaker
-        if isinstance(self.use_speaker, float):
-            use_speaker = random.random() < self.use_speaker
-        else:
-            use_speaker = self.use_speaker
-
-        all_tokens, all_labels = [], []
-        while remaining_tokens > 0 and len(samples) > 0:
-            sentence = samples.pop(0)
-
-            text = random.choice(sentence.texts)
-            text, length = self.tokenize_sentence(text)
-            remaining_tokens -= length + len(sentence.semantics[0].values)
-
-            if use_interactive is False:
-                final_text.append(text)
-                final_semantic.append(sentence.semantics)
-            else:
-                # For interactive mode, we only apply speaker for the first sentence
-                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
-                tokens, labels = self.pack_sentences(
-                    sentences=[text],
-                    semantics=[sentence.semantics],
-                    speaker=response.name if use_speaker else None,
-                    skip_text=random.random() < self.skip_text_prob,
-                )
-
-                all_tokens.append(tokens)
-                all_labels.append(labels)
-
-            idx += 1
-
-        if use_interactive is False:
-            tokens, labels = self.pack_sentences(
-                final_text,
-                semantics=final_semantic,
-                speaker=response.name if use_speaker else None,
-            )
-            all_tokens.append(tokens)
-            all_labels.append(labels)
-
-        tokens = torch.cat(all_tokens, dim=1)
-        labels = torch.cat(all_labels, dim=1)
-
-        # Verify that the length is correct
-        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
-
-        data = {"tokens": tokens, "labels": labels}
-
-        return data
-
     def pack_sentences(
         self,
         sentences: list[str],
         semantics: list,
-        speaker: Optional[str] = None,
+        # speaker: Optional[str] = None,
         skip_text: bool = False,
     ):
-        if speaker is None:
-            speaker = "assistant"
+        # if speaker is None:
+        #     speaker = "assistant"
+
+        messages = [
+            Message(
+                role="system",
+                parts=[TextPart(text="Speak out the provided text.")],
+                # add_im_end=False,
+                # cal_loss=True,
+            )
+        ]
 
         cated_sentences = " ".join(sentences)
         if skip_text:
             cated_sentences = "<|skip_text|>"
 
-        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
-        final_text = final_text + f"<|im_start|>{speaker}\n"
+        messages.append(
+            Message(
+                role="user",
+                parts=[TextPart(text=cated_sentences)],
+                # cal_loss=True,
+            )
+        )
 
-        encoded = self.tokenizer.encode(
-            final_text,
-            add_special_tokens=False,
-            truncation=False,
-            max_length=10**6,
+        vq_codes = [x.values for x in semantics[0]]
+        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+        vqpart = VQPart(codes=vq_codes_tensor)
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vqpart],
+                cal_loss=True,
+            )
        )
-        semantic_length = sum([len(i[0].values) for i in semantics])
-        prompt_length = len(encoded)
+
         num_codebooks = (
             len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
         )
 
-        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-        tokens = (
-            encoded
-            + [self.semantic_token_id] * semantic_length
-            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
+        conversation = Conversation(messages=messages)
+        # conversation.visualize(tokenizer=self.tokenizer)
+        encoded = conversation.encode(
+            tokenizer=self.tokenizer,
         )
 
-        # Codebook bos/padding: 0, eos: 1
-        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
-        for segment in semantics:
-            for book_idx, book in zip(range(num_codebooks), segment):
-                for j in book.values:
-                    codes[book_idx].append(int(j) + 1)
+        tokens_raw = encoded.tokens
+        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+        tokens[0] = tokens_raw
 
-        for book in codes:
-            book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(tokens.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        tokens[1:, encoded.vq_mask_tokens] = vq_parts
 
-        tokens = [tokens] + codes
+        labels_raw = encoded.labels
+        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+        labels[0, :] = labels_raw
+        labels[1:, encoded.vq_mask_labels] = vq_parts
+        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
 
-        tokens = torch.tensor(tokens, dtype=torch.long)
-        labels = tokens.clone()
-
-        if skip_text:
-            # If text is not provided, the sentence is used for condition only, all labels are -100
-            torch.fill_(labels, -100)
-            return tokens, labels
-
-        # Mask out the tokens for semantic, predict semantic tokens only
-        # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, :prompt_length] = -100
-
-        tokens = tokens[:, :-1]
-        labels = labels[:, 1:]
+        tokens = tokens.long()
+        labels = labels.long()
 
         # Verify the padding is correct, and the last token is eos
-        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
         assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
 
         return tokens, labels
 
+    def augment(self):
+        response = self.sample_data()
+        if len(response.samples) == 0:
+            # Invalid group
+            return None
+
+        samples = list(response.samples)
+        all_tokens, all_labels = [], []
+
+        while len(samples) > 0:
+            sentence = samples.pop(0)
+            text = clean_text(random.choice(sentence.texts))
+
+            tokens, labels = self.pack_sentences(
+                sentences=[text],
+                semantics=[sentence.semantics],
+                # speaker=response.name if use_speaker else None,
+                skip_text=random.random() < self.skip_text_prob,
+            )
+
+            all_tokens.append(tokens)
+            all_labels.append(labels)
+
+        tokens = torch.cat(all_tokens, dim=1)
+        labels = torch.cat(all_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        data = {"tokens": tokens, "labels": labels}
+
+        return data
+
 
 @dataclass
 class TextDataCollator:
-    tokenizer: AutoTokenizer
+    tokenizer: FishTokenizer
     max_length: int = 1024
 
     def __call__(self, examples):
@@ -388,7 +347,7 @@ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
             _tokens = F.pad(
                 _tokens,
                 (0, max_tokens_length - tokens_length),
-                value=self.tokenizer.eos_token_id,
+                value=self.tokenizer.get_token_id("<|end_of_text|>"),
             )
             _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
             _labels = F.pad(
@@ -446,7 +405,7 @@ def __init__(
         train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
         val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
         batch_size: int = 32,
-        tokenizer: AutoTokenizer = None,
+        tokenizer: FishTokenizer = None,
         max_length: int = 1024,
         num_workers: int = 4,
     ):
@@ -483,14 +442,13 @@ def val_dataloader(self):
     ds = AutoTextSemanticInstructionDataset(
         ["data/protos"],
-        tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
+        tokenizer=FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken"),
         use_speaker=False,
         interactive_prob=1.0,
         skip_text_prob=0.5,
     )
 
     for i in ds:
-        print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
-        # i["labels"][0][i["labels"][0] == -100] = 0
print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) + # Please uncomment line 235 to visualize the tokenized message + print(i) break diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py index 1811f091..be24c598 100644 --- a/fish_speech/models/text2semantic/llama.py +++ b/fish_speech/models/text2semantic/llama.py @@ -167,7 +167,7 @@ class BaseTransformer(nn.Module): def __init__( self, config: BaseModelArgs, - tokenizer: FishTokenizer | AutoTokenizer, + tokenizer: FishTokenizer, init_weights: bool = True, ) -> None: super().__init__() @@ -246,17 +246,24 @@ def setup_caches( dtype=dtype, ) - def embed(self, x: Tensor) -> Tensor: - vocab_embeds = [self.embeddings(x[:, 0])] + def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor: + embeds = [] + semantic_token_ids_tensor = torch.tensor( + self.semantic_token_ids, device=inp.device + ) + for i in range(self.config.num_codebooks): - emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) - semantic_token_ids_tensor = torch.tensor( - self.semantic_token_ids, device=x.device - ) - emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0 + if share_codebook_embeddings: + emb = self.codebook_embeddings( + inp[:, i + 1] + i * self.config.codebook_size + ) + else: + emb = self.codebook_embeddings(inp[:, i + 1]) + embeds.append(emb) - x = torch.stack(vocab_embeds, dim=3) - x = x.sum(dim=3) + vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1) + vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0 + x = self.embeddings(inp[:, 0]) + vq_embeds_sum return x @@ -277,8 +284,14 @@ def forward( # To maintain consistency, key_padding_mask use TRUE to mask out mask = None if key_padding_mask is not None: - mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) - mask = mask & key_padding_mask[:, None, None, :].logical_not() + causal = self.causal_mask[:seq_len, :seq_len] + causal = rearrange(causal, "q k -> 1 1 q k") + + atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s") + atten_mask = atten_mask.logical_not() + mask = causal & atten_mask + + # return freqs_cis, mask for layer in self.layers: if self.config.use_gradient_checkpointing and self.training: @@ -303,36 +316,12 @@ def forward_generate( self, inp: Tensor, input_pos: Optional[Tensor] = None, - vq_masks: Optional[Tensor] = None, # this is not used in fact return_all: bool = False, ) -> BaseTransformerForwardResult: - # This is used for generation, optimized for torch compile - # assert ( - # self.max_seq_len != -1 and self.max_batch_size != -1 - # ), "Please call setup_caches before forward_generate" - - embeds = [] - for i in range(self.config.num_codebooks): - if self.config.share_codebook_embeddings: - _tokens = inp[:, i + 1] + i * self.config.codebook_size - else: - _tokens = inp[:, i + 1] - - emb = self.codebook_embeddings(_tokens) - embeds.append(emb) - - vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1) - # if self.config.use_codebook_mlp: - # vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks - # vq_embeds_sum = self.codebook_mlp(vq_embeds_sum) - - vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & ( - inp[:, 0] <= self.tokenizer.semantic_end_id + x = self.embed( + inp, share_codebook_embeddings=self.config.share_codebook_embeddings ) - vq_embeds_sum[~vq_masks] = 0 - x = self.embeddings(inp[:, 0]) + vq_embeds_sum - if input_pos is None: input_pos = torch.arange(inp.shape[-1], device=x.device) max_seq_len = inp.shape[-1] @@ -401,11 
@@ -401,11 +390,8 @@ def from_pretrained(
             case _:
                 raise ValueError(f"Unknown model type: {config.model_type}")
 
-        if is_agent:
-            tokenizer = AutoTokenizer.from_pretrained(str(path))
-        else:
-            tokenizer_path = str(path) + "/tokenizer.tiktoken"
-            tokenizer = FishTokenizer(tokenizer_path)
+        tokenizer_path = str(path) + "/tokenizer.tiktoken"
+        tokenizer = FishTokenizer(tokenizer_path)
 
         log.info(f"Loading model from {path}, config: {config}")
         model = model_cls(config, tokenizer=tokenizer)
@@ -862,6 +848,17 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+    """
+    Precomputes frequency tensors for complex exponentials (cis)
+
+    Args:
+        seq_len: Length of the sequence for which positional embeddings are needed.
+        n_elem: Number of elements in the frequency tensor.
+        base: Base value for the frequency scaling (default: 10000).
+
+    Returns:
+        A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
+    """
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py
index c1bd3cbd..1d0345c5 100644
--- a/tools/llama/merge_lora.py
+++ b/tools/llama/merge_lora.py
@@ -76,19 +76,20 @@ def merge(lora_config, base_weight, lora_weight, output):
     new_state_dict = torch.load(output / "model.pth", map_location="cpu")
 
     original_keys = set(llama_state_dict_copy.keys())
-    merged_keys = set(new_state_dict.keys())
-
-    assert original_keys == merged_keys, "Keys should be same"
 
+    tolerance = 1e-5
     for key in original_keys:
         diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
-        if diff_l1 != 0:
+        if diff_l1 > tolerance:
+            logger.info(f"Significant difference found in key: {key}")
             break
-    else:
-        logger.error("Merged model is same as the original model")
-        exit(1)
 
-    logger.info("Merged model is different from the original model, check passed")
+    if diff_l1 <= tolerance:
+        logger.warning(
+            "Merged model seems identical to the original model. Further validation might be needed."
+        )
+    else:
+        logger.info("Merged model is different from the original model, check passed")
 
 
 if __name__ == "__main__":
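
Editor's note (appended for review; not part of the patch): the reworked pack_sentences in fish_speech/datasets/semantic.py packs text tokens and VQ codes into a single (num_codebooks + 1, seq_len) grid. Row 0 carries the text-token stream produced by Conversation.encode, rows 1..num_codebooks carry codebook values only at the positions flagged by vq_mask_tokens, and labels hold -100 wherever the loss should be ignored. The standalone sketch below reproduces that layout with dummy IDs and a hand-written mask; the token values, shapes, and the pad id of 0 are illustrative assumptions, not real FishTokenizer output.

# Illustrative sketch only: dummy token IDs and mask, not real FishTokenizer output.
import torch

CODEBOOK_PAD_TOKEN_ID = 0  # assumed pad id for this sketch
num_codebooks = 2

# Pretend the text stream has 6 tokens, of which positions 3-4 are semantic
# placeholders that should carry VQ codes (this plays the role of vq_mask_tokens).
tokens_raw = torch.tensor([7, 8, 9, 100, 100, 5])
vq_mask = torch.tensor([False, False, False, True, True, False])
vq_parts = torch.tensor([[11, 12], [21, 22]])  # (num_codebooks, num_vq_positions)

# Row 0: text tokens; rows 1..num_codebooks: codebook values at masked positions only.
tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.long)
tokens[0] = tokens_raw
tokens[1:, vq_mask] = vq_parts

# Labels mirror the layout; -100 marks positions excluded from the loss.
labels = torch.full_like(tokens, -100)
labels[0] = tokens_raw  # in the patch this row comes from encoded.labels instead
labels[1:, vq_mask] = vq_parts
labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID

# The same invariants the patch asserts before returning.
assert (tokens[1:, ~vq_mask] == CODEBOOK_PAD_TOKEN_ID).all()
assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()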