From b8bdcd454ccaa03d1250d88bdcd756390d7f3b35 Mon Sep 17 00:00:00 2001 From: MithrilMan Date: Sat, 21 Dec 2024 16:26:15 +0100 Subject: [PATCH] Fix ServeReferenceAudio to allow base64 reference data in json (#777) * Update schema.py to fix ServeStreamResponse * Update schema.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tools/schema.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tools/schema.py b/tools/schema.py index ac8b5c9c..d3b805d3 100644 --- a/tools/schema.py +++ b/tools/schema.py @@ -1,10 +1,11 @@ +import base64 import os import queue from dataclasses import dataclass from typing import Literal import torch -from pydantic import BaseModel, Field, conint, conlist +from pydantic import BaseModel, Field, conint, conlist, model_validator from pydantic.functional_validators import SkipValidation from typing_extensions import Annotated @@ -140,6 +141,19 @@ class ServeReferenceAudio(BaseModel): audio: bytes text: str + @model_validator(mode="before") + def decode_audio(cls, values): + audio = values.get("audio") + if ( + isinstance(audio, str) and len(audio) > 255 + ): # Check if audio is a string (Base64) + try: + values["audio"] = base64.b64decode(audio) + except Exception as e: + # If the audio is not a valid base64 string, we will just ignore it and let the server handle it + pass + return values + def __repr__(self) -> str: return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"