Skip to content
This repository has been archived by the owner on Nov 11, 2023. It is now read-only.

Commit

Permalink
Accelerate up random slice segments
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue authored Nov 10, 2023
1 parent 730930d commit 0ee0b08
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions modules/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,19 @@ def rand_gumbel_like(x):


def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
# Slice segments
gather_indices = ids_str[:, None, None] + torch.arange(
segment_size, device=x.device
)
return torch.gather(x, 2, gather_indices)


def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str

Expand Down

0 comments on commit 0ee0b08

Please sign in to comment.