[pre-commit.ci] pre-commit autoupdate (#80)
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/psf/black: 23.12.1 → 24.1.1](psf/black@23.12.1...24.1.1)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pre-commit-ci[bot] authored Mar 4, 2024
1 parent 0847f68 commit b6e02e0
Showing 10 changed files with 46 additions and 20 deletions.
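For anyone reproducing this change set locally, the sketch below is a rough Python equivalent of the two steps pre-commit.ci performs for a commit like this one; it assumes pre-commit is installed in the active environment and simply shells out to its CLI.

import subprocess

# Bump the `rev:` pins in .pre-commit-config.yaml (here, psf/black 23.12.1 -> 24.1.1).
subprocess.run(["pre-commit", "autoupdate"], check=True)

# Re-run every hook on the whole tree; Black exits non-zero when it reformats
# files, so a non-zero return code is not treated as an error here.
result = subprocess.run(["pre-commit", "run", "--all-files"])
print("hooks modified files" if result.returncode else "working tree already clean")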
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -9,7 +9,7 @@ repos:
         args: [--profile=black]
 
   - repo: https://github.com/psf/black
-    rev: 23.12.1
+    rev: 24.1.1
     hooks:
       - id: black
6 changes: 3 additions & 3 deletions fish_speech/datasets/text.py
@@ -325,9 +325,9 @@ def augment(self):
 
         tokens = torch.tensor(tokens, dtype=torch.long)
         labels = tokens.clone()
-        labels[
-            1:, : len(encoded) + 1
-        ] = -100  # Mask out the <s> tokens for semantic
+        labels[1:, : len(encoded) + 1] = (
+            -100
+        )  # Mask out the <s> tokens for semantic
 
         return {
             "tokens": tokens[:, :-1],
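The hunk above, like most hunks in this commit, only changes where Black 24.1.1 breaks lines, not what the code does. The stdlib-only sketch below checks this for the reformatted assignment; the two source strings are copied from the diff, and labels/encoded are just undefined names as far as parsing is concerned.

import ast

# Old (Black 23.12.1) and new (Black 24.1.1) spellings of the same statement.
old_style = (
    "labels[\n"
    "    1:, : len(encoded) + 1\n"
    "] = -100  # Mask out the <s> tokens for semantic\n"
)
new_style = (
    "labels[1:, : len(encoded) + 1] = (\n"
    "    -100\n"
    ")  # Mask out the <s> tokens for semantic\n"
)

# Comments, redundant parentheses, and line breaks vanish at the AST level,
# so both spellings parse to identical trees.
assert ast.dump(ast.parse(old_style)) == ast.dump(ast.parse(new_style))
print("reformatting is purely cosmetic")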
8 changes: 5 additions & 3 deletions fish_speech/models/vqgan/modules/encoders.py
@@ -303,9 +303,11 @@ def __init__(
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
         )
         self.conv_out = nn.Sequential(
-            nn.Upsample(scale_factor=downsample, mode="nearest")
-            if downsample > 1
-            else nn.Identity(),
+            (
+                nn.Upsample(scale_factor=downsample, mode="nearest")
+                if downsample > 1
+                else nn.Identity()
+            ),
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
         )
 
6 changes: 5 additions & 1 deletion fish_speech/text/parser.py
@@ -236,4 +236,8 @@ def g2p(text, order=None):
     )
     print(segments)
 
-    print(clean_text("测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。<p:123> <p:aH>"))
+    print(
+        clean_text(
+            "测试一下 Hugging face, BGM声音很大吗?那我改一下. 世界、こんにちは。<p:123> <p:aH>"
+        )
+    )
5 changes: 4 additions & 1 deletion fish_speech/text/tone_sandhi.py
@@ -107,7 +107,10 @@ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
         # 个做量词
         elif (
             ge_idx >= 1
-            and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
+            and (
+                word[ge_idx - 1].isnumeric()
+                or word[ge_idx - 1] in "几有两半多各整每做是"
+            )
         ) or word == "个":
             finals[ge_idx] = finals[ge_idx][:-1] + "5"
         else:
10 changes: 7 additions & 3 deletions fish_speech/utils/rich_utils.py
@@ -43,9 +43,13 @@ def print_config_tree(
 
     # add fields from `print_order` to queue
     for field in print_order:
-        queue.append(field) if field in cfg else log.warning(
-            f"Field '{field}' not found in config. "
-            + f"Skipping '{field}' config printing..."
+        (
+            queue.append(field)
+            if field in cfg
+            else log.warning(
+                f"Field '{field}' not found in config. "
+                + f"Skipping '{field}' config printing..."
+            )
         )
 
     # add all the other fields to queue (not specified in `print_order`)
12 changes: 9 additions & 3 deletions fish_speech/webui/app.py
@@ -157,10 +157,14 @@ def build_model_config_block():
         llama_ckpt_path = gr.Textbox(
             label="Llama 模型路径", value="checkpoints/text2semantic-400m-v0.2-4k.pth"
         )
-        llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
+        llama_config_name = gr.Textbox(
+            label="Llama 配置文件", value="text2semantic_finetune"
+        )
         tokenizer = gr.Textbox(label="Tokenizer", value="fishaudio/speech-lm-v1")
 
-        vqgan_ckpt_path = gr.Textbox(label="VQGAN 模型路径", value="checkpoints/vqgan-v1.pth")
+        vqgan_ckpt_path = gr.Textbox(
+            label="VQGAN 模型路径", value="checkpoints/vqgan-v1.pth"
+        )
         vqgan_config_name = gr.Textbox(label="VQGAN 配置文件", value="vqgan_pretrain")
 
         load_model_btn = gr.Button(value="加载模型", variant="primary")
@@ -269,7 +273,9 @@ def inference(
 
     with gr.Row():
         with gr.Tab(label="合成参数"):
-            gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")
+            gr.Markdown(
+                "配置常见合成参数. 自动音素会在推理时自动将文本转换为音素."
+            )
 
             input_mode = gr.Dropdown(
                 choices=["文本", "自动音素"],
12 changes: 8 additions & 4 deletions tools/llama/generate.py
@@ -115,9 +115,9 @@ def decode_one_token(
         codebooks.append(
             sample(
                 logits.codebook_logits[:, :, i],
-                previous_tokens=previous_tokens[i + 1]
-                if previous_tokens is not None
-                else None,
+                previous_tokens=(
+                    previous_tokens[i + 1] if previous_tokens is not None else None
+                ),
                 **sampling_kwargs,
             )[0]
         )
@@ -362,7 +362,11 @@ def load_model(config_name, checkpoint_path, device, precision):
 
 
 @click.command()
-@click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
+@click.option(
+    "--text",
+    type=str,
+    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
 @click.option("--prompt-text", type=str, default=None)
 @click.option(
     "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
4 changes: 3 additions & 1 deletion tools/merge_asr_files.py
@@ -48,4 +48,6 @@ def merge_and_delete_files(save_dir, original_files):
 
 
 if __name__ == "__main__":
-    merge_and_delete_files("/home/spicysama/fish-speech/data/demo/首次揭秘B站百大是怎么选出来的")
+    merge_and_delete_files(
+        "/home/spicysama/fish-speech/data/demo/首次揭秘B站百大是怎么选出来的"
+    )
1 change: 1 addition & 0 deletions tools/whisper_asr.py
@@ -21,6 +21,7 @@
 Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
 """
+
 from pathlib import Path
 
 import click
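The single added line in tools/whisper_asr.py comes from the 2024 stable style rule that puts a blank line between a module docstring and the code that follows it. A quick way to see the rule in action, assuming black>=24.1.1 is importable, is Black's format_str API on a made-up two-line module.

import black

# Hypothetical input: a module docstring immediately followed by an import.
src = '"""Transcribe audio files with Whisper."""\nfrom pathlib import Path\n'

# With Black 24.1.x the output should gain a blank line after the docstring;
# Black 23.x left the two lines adjacent.
print(black.format_str(src, mode=black.Mode()))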
