From 7902e408c85b37193ce3f3c2361ca3c76be0d533 Mon Sep 17 00:00:00 2001
From: Whale and Dolphin <70465000+Whale-Dolphin@users.noreply.github.com>
Date: Sat, 21 Dec 2024 23:27:38 +0800
Subject: [PATCH] [feature] add dataset class (#775)
* [feature] add dataset class
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.../configs/text2semantic_finetune.yaml | 4 +-
fish_speech/datasets/semantic.py | 259 +++++++++++++++---
2 files changed, 224 insertions(+), 39 deletions(-)
diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml
index 05efad37..ef0d945a 100644
--- a/fish_speech/configs/text2semantic_finetune.yaml
+++ b/fish_speech/configs/text2semantic_finetune.yaml
@@ -26,7 +26,7 @@ tokenizer:
# Dataset Configuration
train_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
proto_files:
- data/protos
tokenizer: ${tokenizer}
@@ -36,7 +36,7 @@ train_dataset:
interactive_prob: 0.7
val_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
proto_files:
- data/protos
tokenizer: ${tokenizer}
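
For context, the `_target_` keys above are resolved at runtime by Hydra-style instantiation, so this two-line change is all that is needed to switch training to the new iterable class. A minimal sketch, assuming the project wires these configs through hydra-core (which the `_target_`/`${...}` conventions suggest):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("fish_speech/configs/text2semantic_finetune.yaml")
    # Recursive instantiation also builds the nested ${tokenizer} node,
    # yielding an AutoTextSemanticInstructionIterableDataset ready for training.
    train_dataset = instantiate(cfg.train_dataset)
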
diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py
index df3bf9ec..84032d03 100644
--- a/fish_speech/datasets/semantic.py
+++ b/fish_speech/datasets/semantic.py
@@ -13,7 +13,7 @@
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from fish_speech.conversation import (
CODEBOOK_PAD_TOKEN_ID,
@@ -59,7 +59,7 @@ def split_by_rank_worker(files):
return files
-class AutoTextSemanticInstructionDataset(IterableDataset):
+class AutoTextSemanticInstructionIterableDataset(IterableDataset):
"""
Auto Augment Dataset by Speaker
@@ -295,6 +295,214 @@ def augment(self):
return data
+class AutoTextSemanticInstructionDataset(Dataset):
+ """
+ Auto Augment Dataset by Speaker
+
+    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
+ 2. Automatically normalize the text
+
+ For interactive mode, we use the following format (multiple sequences):
+ [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+
+ For non-interactive mode, we use the following format (one long sequence):
+ [INST] text [/INST] ...
+ """
+
+ def __init__(
+ self,
+ proto_files: list[str],
+ seed: int = 42,
+ interactive_prob: float = 0.5,
+ max_length: int = 1024,
+ tokenizer: FishTokenizer = None,
+ use_speaker: bool | float = True,
+ causal: bool = True,
+ num_codebooks: Optional[int] = None,
+ skip_text_prob: float = 0.0,
+ ):
+ """
+ Args:
+            proto_files: protobuf files if using local data
+            seed: random seed
+            interactive_prob: probability of using interactive mode
+            max_length: max length of the text
+            tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            causal: use causal sampling when using local data; disabling it leads to random sampling
+            num_codebooks: number of codebooks; if None, it is detected automatically
+            skip_text_prob: probability of skipping the text (audio only); only applies to interactive mode
+ """
+ super().__init__()
+
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+ self.seed = seed
+ self.max_length = max_length
+ self.tokenizer = tokenizer
+ self.interactive_prob = interactive_prob
+ self.use_speaker = use_speaker
+ self.proto_files = proto_files
+ self.causal = causal
+ self.num_codebooks = num_codebooks
+ self.skip_text_prob = skip_text_prob
+
+ self.data = []
+ self._init_data()
+
+ def _init_data(self):
+ expanded_proto_files = []
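+        # braceexpand lets proto_files use patterns such as "data/protos-{a,b}"
+        # (an illustrative pattern); each expanded entry may be a single file or
+        # a directory that is scanned recursively below.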
+ for filename in self.proto_files:
+ for i in braceexpand(filename):
+ i = Path(i)
+ if i.is_file():
+ expanded_proto_files.append(i)
+ elif i.is_dir():
+ expanded_proto_files.extend(i.rglob("*.proto"))
+ expanded_proto_files.extend(i.rglob("*.protos"))
+ else:
+ raise ValueError(f"{i} is not a file or directory")
+
+ expanded_proto_files = sorted(expanded_proto_files)
+ Random(self.seed).shuffle(expanded_proto_files)
+
+ groups = []
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
+ log.info(
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+ )
+
+ count = 0
+ for filename in shard_proto_files:
+ with open(filename, "rb") as f:
+ for text_data in read_pb_stream(f):
+ groups.append(text_data)
+ count += 1
+
+ log.info(f"Read total {count} groups of data")
+
+        # Use a seeded RNG so text selection and skip decisions are reproducible
+        # and actually governed by the documented `seed` argument.
+        rng = Random(self.seed)
+        for group in groups:
+            if len(group.sentences) == 0:
+                continue
+
+            samples = list(group.sentences)
+            for sentence in samples:
+                text = clean_text(rng.choice(sentence.texts))
+
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    skip_text=rng.random() < self.skip_text_prob,
+                )
+
+                self.data.append({"tokens": tokens, "labels": labels})
+
+        rng.shuffle(self.data)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+ def pack_sentences(
+ self,
+ sentences: list[str],
+ semantics: list,
+ skip_text: bool = False,
+ ):
+ messages = [
+ Message(
+ role="system",
+ parts=[TextPart(text="Speak out the provided text.")],
+ )
+ ]
+
+ cated_sentences = " ".join(sentences)
+ if skip_text:
+ cated_sentences = "<|skip_text|>"
+
+ messages.append(
+ Message(
+ role="user",
+ parts=[TextPart(text=cated_sentences)],
+ )
+ )
+
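+        # Stack the per-codebook code streams of the first sentence into a
+        # (num_codebooks, n_frames) integer tensor.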
+ vq_codes = [x.values for x in semantics[0]]
+ vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+ vqpart = VQPart(codes=vq_codes_tensor)
+ messages.append(
+ Message(
+ role="assistant",
+ parts=[TextPart(text="<|voice|>"), vqpart],
+ cal_loss=True,
+ )
+ )
+
+ num_codebooks = (
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+ )
+
+ conversation = Conversation(messages=messages)
+ encoded = conversation.encode(
+ tokenizer=self.tokenizer,
+ )
+
+ tokens_raw = encoded.tokens
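+        # Layout: row 0 carries the text token ids; rows 1..num_codebooks carry
+        # the VQ codes at positions flagged by vq_mask_tokens and stay at 0
+        # everywhere else (the asserts below rely on CODEBOOK_PAD_TOKEN_ID == 0).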
+ tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+ tokens[0] = tokens_raw
+
+ vq_parts = encoded.vq_parts
+ vq_parts = [part.to(tokens.device) for part in vq_parts]
+ vq_parts = torch.cat(vq_parts, dim=1)
+ tokens[1:, encoded.vq_mask_tokens] = vq_parts
+
+ labels_raw = encoded.labels
+ labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+ labels[0, :] = labels_raw
+ labels[1:, encoded.vq_mask_labels] = vq_parts
+ labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
+
+ tokens = tokens.long()
+ labels = labels.long()
+
+ assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
+ assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+ return tokens, labels
+
+
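
A hedged usage sketch of the new map-style class (the tokenizer and collator arguments are illustrative): unlike the iterable variant it supports len(), random access, and shuffle=True in a DataLoader.

    from torch.utils.data import DataLoader

    ds = AutoTextSemanticInstructionDataset(
        proto_files=["data/protos"],
        tokenizer=tokenizer,  # an already-constructed FishTokenizer
        interactive_prob=0.7,
    )
    loader = DataLoader(
        ds,
        batch_size=8,
        shuffle=True,  # map-style datasets allow shuffling; IterableDataset does not
        collate_fn=TextDataCollator(tokenizer, max_length=1024),
    )
    sample = ds[0]  # {"tokens": (num_codebooks + 1, T), "labels": same shape}
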
+class InterleaveDataset(IterableDataset):
+ def __init__(
+ self,
+ datasets: list[IterableDataset],
+ probabilities: list[float],
+ seed: int = 42,
+ ):
+ super().__init__()
+
+ self.datasets = datasets
+ self.probabilities = probabilities
+ self.seed = seed
+
+ def __iter__(self):
+ rng = np.random.default_rng(self.seed)
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+ while True:
+            # Randomly pick one dataset according to the mixing probabilities
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+ dataset_iterator = dataset_iterators[dataset_idx]
+
+ try:
+ yield next(dataset_iterator)
+ except StopIteration:
+ # Exhausted, create a new iterator
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+ yield next(dataset_iterators[dataset_idx])
+
+
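
A short sketch of InterleaveDataset's sampling behaviour (ds_a and ds_b are placeholder iterable datasets): the probabilities are fed to np.random.choice, so they must sum to 1, and iteration never terminates on its own because exhausted children are restarted, so consumers should bound it explicitly.

    from itertools import islice

    mixed = InterleaveDataset(datasets=[ds_a, ds_b], probabilities=[0.8, 0.2])
    for sample in islice(mixed, 1000):  # draw a bounded number of samples
        ...
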
@dataclass
class TextDataCollator:
tokenizer: FishTokenizer
@@ -369,41 +577,19 @@ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
}
-class InterleaveDataset(IterableDataset):
- def __init__(
- self,
- datasets: list[IterableDataset],
- probabilities: list[float],
- seed: int = 42,
- ):
- super().__init__()
-
- self.datasets = datasets
- self.probabilities = probabilities
- self.seed = seed
-
- def __iter__(self):
- rng = np.random.default_rng(self.seed)
- dataset_iterators = [iter(dataset) for dataset in self.datasets]
-
- while True:
- # Random choice one
- dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
- dataset_iterator = dataset_iterators[dataset_idx]
-
- try:
- yield next(dataset_iterator)
- except StopIteration:
- # Exhausted, create a new iterator
- dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
- yield next(dataset_iterators[dataset_idx])
-
-
class SemanticDataModule(LightningDataModule):
def __init__(
self,
- train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
- val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+ train_dataset: Union[
+ AutoTextSemanticInstructionDataset,
+ AutoTextSemanticInstructionIterableDataset,
+ InterleaveDataset,
+ ],
+ val_dataset: Union[
+ AutoTextSemanticInstructionDataset,
+ AutoTextSemanticInstructionIterableDataset,
+ InterleaveDataset,
+ ],
batch_size: int = 32,
tokenizer: FishTokenizer = None,
max_length: int = 1024,
@@ -448,7 +634,6 @@ def val_dataloader(self):
skip_text_prob=0.5,
)
- for i in ds:
+    for i in range(min(100, len(ds))):
# Please uncomment line 235 to visualize the tokenized message
- print(i)
- break
+ print(ds[i])
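
For a quick sanity check, the updated smoke test (which indexes the map-style dataset instead of pulling a single item from an iterator) can be run with `python -m fish_speech.datasets.semantic`, assuming it sits under the file's `if __name__ == "__main__":` guard.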