diff --git a/tools/schema.py b/tools/schema.py index ac8b5c9c..d3b805d3 100644 --- a/tools/schema.py +++ b/tools/schema.py @@ -1,10 +1,11 @@ +import base64 import os import queue from dataclasses import dataclass from typing import Literal import torch -from pydantic import BaseModel, Field, conint, conlist +from pydantic import BaseModel, Field, conint, conlist, model_validator from pydantic.functional_validators import SkipValidation from typing_extensions import Annotated @@ -140,6 +141,19 @@ class ServeReferenceAudio(BaseModel): audio: bytes text: str + @model_validator(mode="before") + def decode_audio(cls, values): + audio = values.get("audio") + if ( + isinstance(audio, str) and len(audio) > 255 + ): # Check if audio is a string (Base64) + try: + values["audio"] = base64.b64decode(audio) + except Exception as e: + # If the audio is not a valid base64 string, we will just ignore it and let the server handle it + pass + return values + def __repr__(self) -> str: return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"