From 91bf1500312cc32859a95a0afabc805ea7711f84 Mon Sep 17 00:00:00 2001 From: Deep Chavda <82630272+Deepchavda007@users.noreply.github.com> Date: Thu, 23 Jan 2025 20:52:08 +0530 Subject: [PATCH] feat(demo): Add Colab badge and ensure output directory is created (#844) * feat(demo): Add Colab Demo link and ensure output directory creation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update README.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 4 ++-- fish_speech/models/text2semantic/inference.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a091163e..92ab4d54 100644 --- a/README.md +++ b/README.md @@ -78,9 +78,9 @@ We do not hold any responsibility for any illegal usage of the codebase. Please [Fish Agent](https://fish.audio/demo/live) -## Quick Start for Local Inference +## Quick Start for Local Inference -[inference.ipynb](/inference.ipynb) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/fishaudio/fish-speech/blob/main/inference.ipynb) ## Videos diff --git a/fish_speech/models/text2semantic/inference.py b/fish_speech/models/text2semantic/inference.py index 0e9e2dbd..acf31ee8 100644 --- a/fish_speech/models/text2semantic/inference.py +++ b/fish_speech/models/text2semantic/inference.py @@ -1026,6 +1026,7 @@ def worker(): @click.option("--half/--no-half", default=False) @click.option("--iterative-prompt/--no-iterative-prompt", default=True) @click.option("--chunk-length", type=int, default=100) +@click.option("--output-dir", type=Path, default="temp") def main( text: str, prompt_text: Optional[list[str]], @@ -1042,8 +1043,9 @@ def main( half: bool, iterative_prompt: bool, chunk_length: int, + output_dir: Path, ) -> None: - + os.makedirs(output_dir, exist_ok=True) precision = torch.half if half else torch.bfloat16 if prompt_text is not None and len(prompt_text) != len(prompt_tokens): @@ -1101,8 +1103,9 @@ def main( logger.info(f"Sampled text: {response.text}") elif response.action == "next": if codes: - np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy()) - logger.info(f"Saved codes to codes_{idx}.npy") + codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy") + np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy()) + logger.info(f"Saved codes to {codes_npy_path}") logger.info(f"Next sample") codes = [] idx += 1