Skip to content

Commit

Permalink
Update utils.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
JinZr authored Jan 27, 2025
1 parent d679567 commit 361f3b2
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions egs/wenetspeech4tts/TTS/f5-tts/model/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,8 +38,8 @@ def default(v, d):


def lens_to_mask(
- t: int["b"], length: int | None = None
- ) -> bool["b n"]: # noqa: F722 F821
+ t: int["b"], length: int | None = None # noqa: F722 F821
+ ) -> bool["b n"]:
if not exists(length):
length = t.amax()

Expand All @@ -48,8 +48,8 @@ def lens_to_mask(


def mask_from_start_end_indices(
- seq_len: int["b"], start: int["b"], end: int["b"]
- ): # noqa: F722 F821
+ seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
+ ):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
start_mask = seq[None, :] >= start[:, None]
Expand All @@ -58,8 +58,8 @@ def mask_from_start_end_indices(


def mask_from_frac_lengths(
- seq_len: int["b"], frac_lengths: float["b"]
- ): # noqa: F722 F821
+ seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
+ ):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths

Expand All @@ -71,8 +71,8 @@ def mask_from_frac_lengths(


def maybe_masked_mean(
- t: float["b n d"], mask: bool["b n"] = None
- ) -> float["b d"]: # noqa: F722
+ t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
+ ) -> float["b d"]:
if not exists(mask):
return t.mean(dim=1)

Expand Down

0 comments on commit 361f3b2

Please sign in to comment.