From 7902e408c85b37193ce3f3c2361ca3c76be0d533 Mon Sep 17 00:00:00 2001 From: Whale and Dolphin <70465000+Whale-Dolphin@users.noreply.github.com> Date: Sat, 21 Dec 2024 23:27:38 +0800 Subject: [PATCH] [feature]add dataset classs (#775) * [feature]add dataset classs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../configs/text2semantic_finetune.yaml | 4 +- fish_speech/datasets/semantic.py | 259 +++++++++++++++--- 2 files changed, 224 insertions(+), 39 deletions(-) diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml index 05efad37..ef0d945a 100644 --- a/fish_speech/configs/text2semantic_finetune.yaml +++ b/fish_speech/configs/text2semantic_finetune.yaml @@ -26,7 +26,7 @@ tokenizer: # Dataset Configuration train_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset proto_files: - data/protos tokenizer: ${tokenizer} @@ -36,7 +36,7 @@ train_dataset: interactive_prob: 0.7 val_dataset: - _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset proto_files: - data/protos tokenizer: ${tokenizer} diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py index df3bf9ec..84032d03 100644 --- a/fish_speech/datasets/semantic.py +++ b/fish_speech/datasets/semantic.py @@ -13,7 +13,7 @@ from huggingface_hub import HfApi from lightning import LightningDataModule from torch.distributed import get_rank, get_world_size, is_initialized -from torch.utils.data import DataLoader, IterableDataset, get_worker_info +from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info from fish_speech.conversation import ( CODEBOOK_PAD_TOKEN_ID, @@ -59,7 +59,7 @@ def split_by_rank_worker(files): return files -class AutoTextSemanticInstructionDataset(IterableDataset): +class AutoTextSemanticInstructionIterableDataset(IterableDataset): """ Auto Augment Dataset by Speaker @@ -295,6 +295,214 @@ def augment(self): return data +class AutoTextSemanticInstructionDataset(Dataset): + """ + Auto Augment Dataset by Speaker + + 1. Random concatenate multiple sentences from the same speaker to form a longer sentence + 2. Automatically normalize the text + + For interactive mode, we use the following format (multiple sequences): + [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + + For non-interactive mode, we use the following format (one long sequence): + [INST] text [/INST] ... 
+ """ + + def __init__( + self, + proto_files: list[str], + seed: int = 42, + interactive_prob: float = 0.5, + max_length: int = 1024, + tokenizer: FishTokenizer = None, + use_speaker: bool | float = True, + causal: bool = True, + num_codebooks: Optional[int] = None, + skip_text_prob: float = 0.0, + ): + """ + Args: + proto_files: proto buf files if using local data + seed: random seed + interactive_prob: probability to use interactive mode + max_length: max length of the text + tokenizer: tokenizer + use_speaker: include speaker information in the prompt + causal: use causal sampling when using local data, disable will lead to random sampling + num_codebooks: number of codebooks, if None, it will be automatically detected + skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode + """ + super().__init__() + + assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" + + self.seed = seed + self.max_length = max_length + self.tokenizer = tokenizer + self.interactive_prob = interactive_prob + self.use_speaker = use_speaker + self.proto_files = proto_files + self.causal = causal + self.num_codebooks = num_codebooks + self.skip_text_prob = skip_text_prob + + self.data = [] + self._init_data() + + def _init_data(self): + expanded_proto_files = [] + for filename in self.proto_files: + for i in braceexpand(filename): + i = Path(i) + if i.is_file(): + expanded_proto_files.append(i) + elif i.is_dir(): + expanded_proto_files.extend(i.rglob("*.proto")) + expanded_proto_files.extend(i.rglob("*.protos")) + else: + raise ValueError(f"{i} is not a file or directory") + + expanded_proto_files = sorted(expanded_proto_files) + Random(self.seed).shuffle(expanded_proto_files) + + groups = [] + shard_proto_files = split_by_rank_worker(expanded_proto_files) + log.info( + f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" + ) + + count = 0 + for filename in shard_proto_files: + with open(filename, "rb") as f: + for text_data in read_pb_stream(f): + groups.append(text_data) + count += 1 + + log.info(f"Read total {count} groups of data") + + for group in groups: + if len(group.sentences) == 0: + continue + + samples = list(group.sentences) + for sentence in samples: + text = clean_text(random.choice(sentence.texts)) + + tokens, labels = self.pack_sentences( + sentences=[text], + semantics=[sentence.semantics], + skip_text=random.random() < self.skip_text_prob, + ) + + self.data.append({"tokens": tokens, "labels": labels}) + + random.Random(self.seed).shuffle(self.data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def pack_sentences( + self, + sentences: list[str], + semantics: list, + skip_text: bool = False, + ): + messages = [ + Message( + role="system", + parts=[TextPart(text="Speak out the provided text.")], + ) + ] + + cated_sentences = " ".join(sentences) + if skip_text: + cated_sentences = "<|skip_text|>" + + messages.append( + Message( + role="user", + parts=[TextPart(text=cated_sentences)], + ) + ) + + vq_codes = [x.values for x in semantics[0]] + vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32) + vqpart = VQPart(codes=vq_codes_tensor) + messages.append( + Message( + role="assistant", + parts=[TextPart(text="<|voice|>"), vqpart], + cal_loss=True, + ) + ) + + num_codebooks = ( + len(semantics[0]) if self.num_codebooks is None else self.num_codebooks + ) + + conversation = Conversation(messages=messages) + encoded = conversation.encode( + tokenizer=self.tokenizer, 
+ ) + + tokens_raw = encoded.tokens + tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int) + tokens[0] = tokens_raw + + vq_parts = encoded.vq_parts + vq_parts = [part.to(tokens.device) for part in vq_parts] + vq_parts = torch.cat(vq_parts, dim=1) + tokens[1:, encoded.vq_mask_tokens] = vq_parts + + labels_raw = encoded.labels + labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int) + labels[0, :] = labels_raw + labels[1:, encoded.vq_mask_labels] = vq_parts + labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID + + tokens = tokens.long() + labels = labels.long() + + assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all() + assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all() + + return tokens, labels + + +class InterleaveDataset(IterableDataset): + def __init__( + self, + datasets: list[IterableDataset], + probabilities: list[float], + seed: int = 42, + ): + super().__init__() + + self.datasets = datasets + self.probabilities = probabilities + self.seed = seed + + def __iter__(self): + rng = np.random.default_rng(self.seed) + dataset_iterators = [iter(dataset) for dataset in self.datasets] + + while True: + # Random choice one + dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) + dataset_iterator = dataset_iterators[dataset_idx] + + try: + yield next(dataset_iterator) + except StopIteration: + # Exhausted, create a new iterator + dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) + yield next(dataset_iterators[dataset_idx]) + + @dataclass class TextDataCollator: tokenizer: FishTokenizer @@ -369,41 +577,19 @@ def batchify(self, examples, tokens_key="tokens", labels_key="labels"): } -class InterleaveDataset(IterableDataset): - def __init__( - self, - datasets: list[IterableDataset], - probabilities: list[float], - seed: int = 42, - ): - super().__init__() - - self.datasets = datasets - self.probabilities = probabilities - self.seed = seed - - def __iter__(self): - rng = np.random.default_rng(self.seed) - dataset_iterators = [iter(dataset) for dataset in self.datasets] - - while True: - # Random choice one - dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) - dataset_iterator = dataset_iterators[dataset_idx] - - try: - yield next(dataset_iterator) - except StopIteration: - # Exhausted, create a new iterator - dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) - yield next(dataset_iterators[dataset_idx]) - - class SemanticDataModule(LightningDataModule): def __init__( self, - train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], - val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + train_dataset: Union[ + AutoTextSemanticInstructionDataset, + AutoTextSemanticInstructionIterableDataset, + InterleaveDataset, + ], + val_dataset: Union[ + AutoTextSemanticInstructionDataset, + AutoTextSemanticInstructionIterableDataset, + InterleaveDataset, + ], batch_size: int = 32, tokenizer: FishTokenizer = None, max_length: int = 1024, @@ -448,7 +634,6 @@ def val_dataloader(self): skip_text_prob=0.5, ) - for i in ds: + for i in range(100): # Please uncomment line 235 to visualize the tokenized message - print(i) - break + print(ds[i])
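
Not part of the patch, just a reference sketch for reviewers: one way the new map-style AutoTextSemanticInstructionDataset could be wired into a plain DataLoader. The tokenizer is assumed to be an already-constructed FishTokenizer, "data/protos" matches the path used in text2semantic_finetune.yaml, the batch size, worker count, and probability values are illustrative only, and TextDataCollator is assumed to be usable directly as collate_fn.

from torch.utils.data import DataLoader

from fish_speech.datasets.semantic import (
    AutoTextSemanticInstructionDataset,
    TextDataCollator,
)


def build_train_loader(tokenizer, batch_size: int = 8, num_workers: int = 2):
    # Map-style dataset: _init_data() reads and packs every proto group up front,
    # so the dataset has a real __len__ and each item is a {"tokens", "labels"} dict.
    dataset = AutoTextSemanticInstructionDataset(
        proto_files=["data/protos"],
        tokenizer=tokenizer,
        max_length=1024,
        interactive_prob=0.7,   # illustrative, mirrors the finetune config
        skip_text_prob=0.5,     # illustrative, mirrors the __main__ smoke test
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,  # sampler-based shuffling is possible here, unlike the iterable variant
        collate_fn=TextDataCollator(tokenizer, max_length=1024),  # assumed usable as collate_fn
        num_workers=num_workers,
    )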
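
A second illustrative sketch (also not part of the patch): mixing two streams of the renamed AutoTextSemanticInstructionIterableDataset through the relocated InterleaveDataset. The "data/extra_protos" path is hypothetical and the 0.8/0.2 split is only an example; the probabilities are passed straight to numpy's Generator.choice, so they must sum to 1.

from fish_speech.datasets.semantic import (
    AutoTextSemanticInstructionIterableDataset,
    InterleaveDataset,
)


def build_mixed_stream(tokenizer):
    primary = AutoTextSemanticInstructionIterableDataset(
        proto_files=["data/protos"],
        tokenizer=tokenizer,
        interactive_prob=0.7,
    )
    # Hypothetical second source, shown only to illustrate the mixing API.
    extra = AutoTextSemanticInstructionIterableDataset(
        proto_files=["data/extra_protos"],
        tokenizer=tokenizer,
        interactive_prob=0.7,
    )
    # InterleaveDataset samples a member dataset per step and transparently
    # re-creates any member iterator that raises StopIteration.
    return InterleaveDataset(datasets=[primary, extra], probabilities=[0.8, 0.2])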