diff --git a/pyproject.toml b/pyproject.toml
index d894afcf..74ae9f43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,13 +46,13 @@ dependencies = [
     "ormsgpack",
     "tiktoken>=0.8.0",
     "pydantic==2.9.2",
+    "cachetools",
 ]
 
 [project.optional-dependencies]
 stable = [
     "torch<=2.4.1",
     "torchaudio",
-    "cachetools",
 ]
 
 [build-system]
diff --git a/tools/api_client.py b/tools/api_client.py
index 90d7b29b..c6cb6a45 100644
--- a/tools/api_client.py
+++ b/tools/api_client.py
@@ -79,7 +79,7 @@ def parse_args():
     parser.add_argument(
         "--max_new_tokens",
         type=int,
-        default=0,
+        default=1024,
        help="Maximum new tokens to generate. \n0 means no limit.",
     )
     parser.add_argument(
diff --git a/tools/download_models.py b/tools/download_models.py
index 9e79c34c..e14a0991 100644
--- a/tools/download_models.py
+++ b/tools/download_models.py
@@ -22,14 +22,14 @@ def check_and_download_files(repo_id, file_list, local_dir):
 
 
 # 1st
-repo_id_1 = "fishaudio/fish-speech-1.4"
-local_dir_1 = "./checkpoints/fish-speech-1.4"
+repo_id_1 = "fishaudio/fish-speech-1.5"
+local_dir_1 = "./checkpoints/fish-speech-1.5"
 files_1 = [
+    "gitattributes",
     "model.pth",
     "README.md",
-    "special_tokens_map.json",
-    "tokenizer_config.json",
-    "tokenizer.json",
+    "special_tokens.json",
+    "tokenizer.tiktoken",
     "config.json",
     "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
 ]
diff --git a/tools/inference_engine/__init__.py b/tools/inference_engine/__init__.py
index 2eb3396a..2c3e476c 100644
--- a/tools/inference_engine/__init__.py
+++ b/tools/inference_engine/__init__.py
@@ -109,8 +109,7 @@ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, No
                         audio=(sample_rate, segment),
                         error=None,
                     )
-                else:
-                    segments.append(segment)
+                segments.append(segment)
             else:
                 break
 
diff --git a/tools/inference_engine/reference_loader.py b/tools/inference_engine/reference_loader.py
index 4b560393..6369f1db 100644
--- a/tools/inference_engine/reference_loader.py
+++ b/tools/inference_engine/reference_loader.py
@@ -85,7 +85,6 @@ def load_by_hash(
         # If the references are not already loaded, encode them
         prompt_tokens.append(
             self.encode_reference(
-                decoder_model=self.decoder_model,
                 reference_audio=ref.audio,
                 enable_reference_audio=True,
             )
diff --git a/tools/inference_engine/utils.py b/tools/inference_engine/utils.py
index b49e37bf..6a6a5ae3 100644
--- a/tools/inference_engine/utils.py
+++ b/tools/inference_engine/utils.py
@@ -11,7 +11,7 @@
 @dataclass
 class InferenceResult:
     code: Literal["header", "segment", "error", "final"]
-    audio: Optional[Tuple[int, np.ndarray]]
+    audio: Optional[Tuple[int, np.ndarray | bytes]]
     error: Optional[Exception]
 
 
@@ -25,7 +25,7 @@ def normalize_text(user_input: str, use_normalization: bool) -> str:
 
 def wav_chunk_header(
     sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
-) -> np.ndarray:
+) -> bytes:
     buffer = io.BytesIO()
 
     with wave.open(buffer, "wb") as wav_file:
@@ -36,7 +36,4 @@ def wav_chunk_header(
     wav_header_bytes = buffer.getvalue()
     buffer.close()
 
-    # Convert to numpy array
-    wav_header = np.frombuffer(wav_header_bytes, dtype=np.uint8)
-
-    return wav_header
+    return wav_header_bytes
diff --git a/tools/run_webui.py b/tools/run_webui.py
index 5844b72c..ab6f84e3 100644
--- a/tools/run_webui.py
+++ b/tools/run_webui.py
@@ -87,7 +87,7 @@ def parse_args():
                 text="Hello world.",
                 references=[],
                 reference_id=None,
-                max_new_tokens=0,
+                max_new_tokens=1024,
                 chunk_length=200,
                 top_p=0.7,
                 repetition_penalty=1.5,
diff --git a/tools/server/inference.py b/tools/server/inference.py
index d5e95483..2cfdceea 100644
--- a/tools/server/inference.py
+++ b/tools/server/inference.py
@@ -14,6 +14,7 @@ def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
     Wrapper for the inference function.
     Used in the API server.
     """
+    count = 0
     for result in engine.inference(req):
         match result.code:
             case "header":
@@ -27,15 +28,18 @@ def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
                 )
 
             case "segment":
+                count += 1
                 if isinstance(result.audio, tuple):
                     yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
 
             case "final":
+                count += 1
                 if isinstance(result.audio, tuple):
                     yield result.audio[1]
                 return None # Stop the generator
 
-    raise HTTPException(
-        HTTPStatus.INTERNAL_SERVER_ERROR,
-        content="No audio generated, please check the input text.",
-    )
+    if count == 0:
+        raise HTTPException(
+            HTTPStatus.INTERNAL_SERVER_ERROR,
+            content="No audio generated, please check the input text.",
+        )
diff --git a/tools/server/model_manager.py b/tools/server/model_manager.py
index c3f0a896..cb70da6a 100644
--- a/tools/server/model_manager.py
+++ b/tools/server/model_manager.py
@@ -113,10 +113,10 @@ def warm_up(self, tts_inference_engine) -> None:
             text="Hello world.",
             references=[],
             reference_id=None,
-            max_new_tokens=0,
+            max_new_tokens=1024,
             chunk_length=200,
             top_p=0.7,
-            repetition_penalty=1.5,
+            repetition_penalty=1.2,
             temperature=0.7,
             format="wav",
         )
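
Note: as a quick standalone sanity check of the wav_chunk_header rework (the helper now returns the raw WAV header bytes instead of wrapping them in an np.ndarray), the sketch below reproduces the new function. The setnchannels/setsampwidth/setframerate calls are assumed from the elided middle of the hunk and are not part of this diff.

import io
import wave


def wav_chunk_header(
    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
) -> bytes:
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        # Assumed setup calls (elided in the hunk): configure the stream so
        # that closing the writer with zero frames flushes a complete header.
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()

    return wav_header_bytes


if __name__ == "__main__":
    header = wav_chunk_header()
    assert isinstance(header, bytes)  # previously np.frombuffer(..., dtype=np.uint8)
    assert header[:4] == b"RIFF" and header[8:12] == b"WAVE"

This is presumably also why InferenceResult.audio is widened to Optional[Tuple[int, np.ndarray | bytes]] in tools/inference_engine/utils.py: the "header" result can now carry these raw bytes alongside the np.ndarray segments.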