[feature] add dataset class (#775)
* [feature] add dataset class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Whale-Dolphin and pre-commit-ci[bot] authored Dec 21, 2024
1 parent b8bdcd4 commit 7902e40
Showing 2 changed files with 224 additions and 39 deletions.
4 changes: 2 additions & 2 deletions fish_speech/configs/text2semantic_finetune.yaml
@@ -26,7 +26,7 @@ tokenizer:

# Dataset Configuration
train_dataset:
-_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
proto_files:
- data/protos
tokenizer: ${tokenizer}
@@ -36,7 +36,7 @@ train_dataset:
interactive_prob: 0.7

val_dataset:
-_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
proto_files:
- data/protos
tokenizer: ${tokenizer}
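For context, a minimal sketch of how a `_target_` entry like the one above is typically resolved, assuming Hydra-style instantiation as implied by the config format; the actual training entrypoint is not part of this diff and the snippet below is illustrative only.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical snippet: load the finetune config and build the configured datasets.
cfg = OmegaConf.load("fish_speech/configs/text2semantic_finetune.yaml")
# After this change, train/val now resolve to AutoTextSemanticInstructionIterableDataset;
# recursive instantiation also builds the nested ${tokenizer} target.
train_dataset = instantiate(cfg.train_dataset)
val_dataset = instantiate(cfg.val_dataset)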
259 changes: 222 additions & 37 deletions fish_speech/datasets/semantic.py
@@ -13,7 +13,7 @@
from huggingface_hub import HfApi
from lightning import LightningDataModule
from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info

from fish_speech.conversation import (
CODEBOOK_PAD_TOKEN_ID,
@@ -59,7 +59,7 @@ def split_by_rank_worker(files):
return files


-class AutoTextSemanticInstructionDataset(IterableDataset):
+class AutoTextSemanticInstructionIterableDataset(IterableDataset):
"""
Auto Augment Dataset by Speaker
@@ -295,6 +295,214 @@ def augment(self):
return data


class AutoTextSemanticInstructionDataset(Dataset):
"""
Auto Augment Dataset by Speaker
1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
2. Automatically normalize the text
For interactive mode, we use the following format (multiple sequences):
<s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
For non-interactive mode, we use the following format (one long sequence):
<s> [INST] text [/INST] ... </s>
"""

def __init__(
self,
proto_files: list[str],
seed: int = 42,
interactive_prob: float = 0.5,
max_length: int = 1024,
tokenizer: FishTokenizer = None,
use_speaker: bool | float = True,
causal: bool = True,
num_codebooks: Optional[int] = None,
skip_text_prob: float = 0.0,
):
"""
Args:
proto_files: protobuf files if using local data
seed: random seed
interactive_prob: probability to use interactive mode
max_length: max length of the text
tokenizer: tokenizer
use_speaker: include speaker information in the prompt
causal: use causal sampling when using local data; disabling it leads to random sampling
num_codebooks: number of codebooks, if None, it will be automatically detected
skip_text_prob: probability to skip the text (audio only); this only applies to interactive mode
"""
super().__init__()

assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

self.seed = seed
self.max_length = max_length
self.tokenizer = tokenizer
self.interactive_prob = interactive_prob
self.use_speaker = use_speaker
self.proto_files = proto_files
self.causal = causal
self.num_codebooks = num_codebooks
self.skip_text_prob = skip_text_prob

self.data = []
self._init_data()

def _init_data(self):
expanded_proto_files = []
for filename in self.proto_files:
for i in braceexpand(filename):
i = Path(i)
if i.is_file():
expanded_proto_files.append(i)
elif i.is_dir():
expanded_proto_files.extend(i.rglob("*.proto"))
expanded_proto_files.extend(i.rglob("*.protos"))
else:
raise ValueError(f"{i} is not a file or directory")

expanded_proto_files = sorted(expanded_proto_files)
Random(self.seed).shuffle(expanded_proto_files)

groups = []
shard_proto_files = split_by_rank_worker(expanded_proto_files)
log.info(
f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
)

count = 0
for filename in shard_proto_files:
with open(filename, "rb") as f:
for text_data in read_pb_stream(f):
groups.append(text_data)
count += 1

log.info(f"Read total {count} groups of data")

for group in groups:
if len(group.sentences) == 0:
continue

samples = list(group.sentences)
for sentence in samples:
text = clean_text(random.choice(sentence.texts))

tokens, labels = self.pack_sentences(
sentences=[text],
semantics=[sentence.semantics],
skip_text=random.random() < self.skip_text_prob,
)

self.data.append({"tokens": tokens, "labels": labels})

random.Random(self.seed).shuffle(self.data)

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return self.data[idx]

def pack_sentences(
self,
sentences: list[str],
semantics: list,
skip_text: bool = False,
):
messages = [
Message(
role="system",
parts=[TextPart(text="Speak out the provided text.")],
)
]

cated_sentences = " ".join(sentences)
if skip_text:
cated_sentences = "<|skip_text|>"

messages.append(
Message(
role="user",
parts=[TextPart(text=cated_sentences)],
)
)

vq_codes = [x.values for x in semantics[0]]
vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
vqpart = VQPart(codes=vq_codes_tensor)
messages.append(
Message(
role="assistant",
parts=[TextPart(text="<|voice|>"), vqpart],
cal_loss=True,
)
)

num_codebooks = (
len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
)

conversation = Conversation(messages=messages)
encoded = conversation.encode(
tokenizer=self.tokenizer,
)

tokens_raw = encoded.tokens
tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
tokens[0] = tokens_raw

vq_parts = encoded.vq_parts
vq_parts = [part.to(tokens.device) for part in vq_parts]
vq_parts = torch.cat(vq_parts, dim=1)
tokens[1:, encoded.vq_mask_tokens] = vq_parts

labels_raw = encoded.labels
labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
labels[0, :] = labels_raw
labels[1:, encoded.vq_mask_labels] = vq_parts
labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID

tokens = tokens.long()
labels = labels.long()

assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

return tokens, labels
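# Layout note (descriptive only): both returned tensors have shape
# (num_codebooks + 1, seq_len). Row 0 of `tokens` carries the text token ids of
# the encoded conversation; rows 1..num_codebooks carry the VQ codes at the
# positions flagged by vq_mask_tokens and CODEBOOK_PAD_TOKEN_ID elsewhere.
# `labels` mirrors this, with -100 (the ignore index) outside the supervised
# positions and CODEBOOK_PAD_TOKEN_ID forced onto the last column of the
# codebook rows, as the asserts above check.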

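Illustrative usage only (not part of this commit): since the class above is a map-style Dataset, it can be indexed directly and wrapped in a shuffling DataLoader; the values below are hypothetical.

from torch.utils.data import DataLoader

ds = AutoTextSemanticInstructionDataset(
    proto_files=["data/protos"],   # same path used in the finetune config
    tokenizer=tokenizer,           # an already-constructed FishTokenizer
    max_length=1024,
    interactive_prob=0.7,
)
print(len(ds))                     # __len__ is available for map-style datasets
sample = ds[0]                     # {"tokens": ..., "labels": ...}
loader = DataLoader(
    ds,
    batch_size=8,
    shuffle=True,                  # random access makes shuffling possible
    collate_fn=TextDataCollator(tokenizer=tokenizer),  # other collator fields left at their defaults
)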

class InterleaveDataset(IterableDataset):
def __init__(
self,
datasets: list[IterableDataset],
probabilities: list[float],
seed: int = 42,
):
super().__init__()

self.datasets = datasets
self.probabilities = probabilities
self.seed = seed

def __iter__(self):
rng = np.random.default_rng(self.seed)
dataset_iterators = [iter(dataset) for dataset in self.datasets]

while True:
# Randomly pick one dataset according to the configured probabilities
dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
dataset_iterator = dataset_iterators[dataset_idx]

try:
yield next(dataset_iterator)
except StopIteration:
# Exhausted, create a new iterator
dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
yield next(dataset_iterators[dataset_idx])

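Illustrative sketch (not part of this commit) of how InterleaveDataset mixes sources; dataset_a and dataset_b are hypothetical iterable datasets yielding the same sample dicts.

mixed = InterleaveDataset(
    datasets=[dataset_a, dataset_b],
    probabilities=[0.8, 0.2],   # passed to np.random choice via p=, so they should sum to 1
    seed=42,
)
it = iter(mixed)
samples = [next(it) for _ in range(10)]   # on average ~80% of samples come from dataset_a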

@dataclass
class TextDataCollator:
tokenizer: FishTokenizer
@@ -369,41 +577,19 @@ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
}


class InterleaveDataset(IterableDataset):
def __init__(
self,
datasets: list[IterableDataset],
probabilities: list[float],
seed: int = 42,
):
super().__init__()

self.datasets = datasets
self.probabilities = probabilities
self.seed = seed

def __iter__(self):
rng = np.random.default_rng(self.seed)
dataset_iterators = [iter(dataset) for dataset in self.datasets]

while True:
# Random choice one
dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
dataset_iterator = dataset_iterators[dataset_idx]

try:
yield next(dataset_iterator)
except StopIteration:
# Exhausted, create a new iterator
dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
yield next(dataset_iterators[dataset_idx])


class SemanticDataModule(LightningDataModule):
def __init__(
self,
-train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
-val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+train_dataset: Union[
+AutoTextSemanticInstructionDataset,
+AutoTextSemanticInstructionIterableDataset,
+InterleaveDataset,
+],
+val_dataset: Union[
+AutoTextSemanticInstructionDataset,
+AutoTextSemanticInstructionIterableDataset,
+InterleaveDataset,
+],
batch_size: int = 32,
tokenizer: FishTokenizer = None,
max_length: int = 1024,
@@ -448,7 +634,6 @@ def val_dataloader(self):
skip_text_prob=0.5,
)

-for i in ds:
+for i in range(100):
# Please uncomment line 235 to visualize the tokenized message
-print(i)
-break
+print(ds[i])

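The smoke test changes accordingly: the new class is map-style, so the script can index samples directly instead of pulling from an iterator and breaking early. A hedged sketch of the practical difference between the two variants, reusing the `ds` built above:

# Map-style variant (added in this commit): supports len() and random access.
print(len(ds), ds[0]["tokens"].shape)

# Iterable variant (the renamed existing class): no __len__; consume via an iterator,
# e.g. sample = next(iter(iterable_ds)), and keep shuffle=False in the DataLoader.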