From 361f3b2061debc4a2e1b8f0249a4444ab168e576 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 27 Jan 2025 16:02:06 +0800 Subject: [PATCH] Update utils.py --- egs/wenetspeech4tts/TTS/f5-tts/model/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py index 09a46c3e53..038ed0315d 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py @@ -38,8 +38,8 @@ def default(v, d): def lens_to_mask( - t: int["b"], length: int | None = None -) -> bool["b n"]: # noqa: F722 F821 + t: int["b"], length: int | None = None # noqa: F722 F821 +) -> bool["b n"]: if not exists(length): length = t.amax() @@ -48,8 +48,8 @@ def lens_to_mask( def mask_from_start_end_indices( - seq_len: int["b"], start: int["b"], end: int["b"] -): # noqa: F722 F821 + seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821 +): max_seq_len = seq_len.max().item() seq = torch.arange(max_seq_len, device=start.device).long() start_mask = seq[None, :] >= start[:, None] @@ -58,8 +58,8 @@ def mask_from_start_end_indices( def mask_from_frac_lengths( - seq_len: int["b"], frac_lengths: float["b"] -): # noqa: F722 F821 + seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821 +): lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths @@ -71,8 +71,8 @@ def mask_from_frac_lengths( def maybe_masked_mean( - t: float["b n d"], mask: bool["b n"] = None -) -> float["b d"]: # noqa: F722 + t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821 +) -> float["b d"]: if not exists(mask): return t.mean(dim=1)