diff --git a/README.md b/README.md
index a091163e..92ab4d54 100644
--- a/README.md
+++ b/README.md
@@ -78,9 +78,9 @@ We do not hold any responsibility for any illegal usage of the codebase. Please
 
 [Fish Agent](https://fish.audio/demo/live)
 
-## Quick Start for Local Inference
+## Quick Start for Local Inference
 
-[inference.ipynb](/inference.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/fishaudio/fish-speech/blob/main/inference.ipynb)
 
 ## Videos
 
diff --git a/fish_speech/models/text2semantic/inference.py b/fish_speech/models/text2semantic/inference.py
index 0e9e2dbd..acf31ee8 100644
--- a/fish_speech/models/text2semantic/inference.py
+++ b/fish_speech/models/text2semantic/inference.py
@@ -1026,6 +1026,7 @@ def worker():
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
 @click.option("--chunk-length", type=int, default=100)
+@click.option("--output-dir", type=Path, default="temp")
 def main(
     text: str,
     prompt_text: Optional[list[str]],
@@ -1042,8 +1043,9 @@ def main(
     half: bool,
     iterative_prompt: bool,
     chunk_length: int,
+    output_dir: Path,
 ) -> None:
-
+    os.makedirs(output_dir, exist_ok=True)
     precision = torch.half if half else torch.bfloat16
 
     if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
@@ -1101,8 +1103,9 @@ def main(
             logger.info(f"Sampled text: {response.text}")
         elif response.action == "next":
             if codes:
-                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
-                logger.info(f"Saved codes to codes_{idx}.npy")
+                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
+                np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
+                logger.info(f"Saved codes to {codes_npy_path}")
                 logger.info(f"Next sample")
                 codes = []
                 idx += 1