diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
index 01c3d7a2..19ea435c 100644
--- a/fish_speech/utils/spectrogram.py
+++ b/fish_speech/utils/spectrogram.py
@@ -20,6 +20,7 @@ def __init__(
         self.hop_length = hop_length
         self.center = center
         self.mode = mode
+        self.return_complex = True
 
         self.register_buffer("window", torch.hann_window(win_length), persistent=False)
 
@@ -46,10 +47,11 @@ def forward(self, y: Tensor) -> Tensor:
             pad_mode="reflect",
             normalized=False,
             onesided=True,
-            return_complex=True,
+            return_complex=self.return_complex,
         )
 
-        spec = torch.view_as_real(spec)
+        if self.return_complex:
+            spec = torch.view_as_real(spec)
 
         if self.mode == "pow2_sqrt":
             spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
diff --git a/tools/export-onnx.py b/tools/export-onnx.py
index 1758c17b..2ffab88c 100644
--- a/tools/export-onnx.py
+++ b/tools/export-onnx.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn.functional as F
-from einx import get_at
 
 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from tools.vqgan.extract_vq import get_model
@@ -17,8 +16,11 @@ def __init__(self, model):
     def forward(self, audios):
         mels = self.model.spec_transform(audios)
         encoded_features = self.model.backbone(mels)
-        indices = self.model.quantizer.encode(encoded_features)
-        return indices
+
+        z = self.model.quantizer.downsample(encoded_features)
+        _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
+        _, b, l, _ = indices.shape
+        return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
 
 
 class Decoder(torch.nn.Module):
@@ -30,12 +32,9 @@ def __init__(self, model):
 
     def get_codes_from_indices(self, cur_index, indices):
-        batch_size, quantize_dim, q_dim = indices.shape
+        _, quantize_dim, _ = indices.shape
 
         d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
 
-        # because of quantize dropout, one can pass in indices that are coarse
-        # and the network should be able to reconstruct
-
         if (
             quantize_dim
             < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
@@ -53,26 +52,17 @@ def get_codes_from_indices(self, cur_index, indices):
                 value=-1,
             )
 
-        # take care of quantizer dropout
-
         mask = indices == -1
-        indices = indices.masked_fill(
-            mask, 0
-        )  # have it fetch a dummy code to be masked out later
+        indices = indices.masked_fill(mask, 0)
 
         all_codes = torch.gather(
             self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
             dim=2,
-            index=indices.long()
-            .permute(2, 0, 1)
-            .unsqueeze(-1)
-            .repeat(1, 1, 1, d_dim),  # q, batch_size, frame, dim
+            index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
         )
 
         all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
 
-        # scale the codes
-
         scales = (
             self.model.quantizer.residual_fsq.rvqs[cur_index]
             .scales.unsqueeze(1)
@@ -80,8 +70,6 @@ def get_codes_from_indices(self, cur_index, indices):
         )
         all_codes = all_codes * scales
 
-        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
-
         return all_codes
 
     def get_output_from_indices(self, cur_index, indices):
@@ -112,39 +100,30 @@ def forward(self, indices) -> torch.Tensor:
         return x
 
 
-def main():
-    GanModel = get_model(
-        "firefly_gan_vq",
-        "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-        device="cpu",
-    )
+def main(firefly_gan_vq_path, llama_path, export_prefix):
+    GanModel = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
     enc = Encoder(GanModel)
     dec = Decoder(GanModel)
 
     audio_example = torch.randn(1, 1, 96000)
     indices = enc(audio_example)
-
-    print(dec(indices).shape)
-
-    """
     torch.onnx.export(
         enc,
         audio_example,
-        "encoder.onnx",
-        dynamic_axes = {
+        f"{export_prefix}encoder.onnx",
+        dynamic_axes={
             "audio": [0, 2],
         },
         do_constant_folding=False,
         opset_version=18,
         verbose=False,
         input_names=["audio"],
-        output_names=["prompt"]
+        output_names=["prompt"],
    )
-    """
     torch.onnx.export(
         dec,
         indices,
-        "decoder.onnx",
+        f"{export_prefix}decoder.onnx",
         dynamic_axes={
             "prompt": [0, 2],
         },
@@ -155,8 +134,6 @@ def main():
         output_names=["audio"],
     )
 
 
-    print(enc(audio_example).shape)
-    print(dec(enc(audio_example)).shape)
-
-main()
+if __name__ == "__main__":
+    main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
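
A quick way to sanity-check the exported graphs is to chain them with onnxruntime. The sketch below is a hypothetical smoke test, not part of this diff: it assumes the script was run with `export_prefix="test_"` as in the `__main__` block above, and that the decoder's input tensor is named `"prompt"` (the decoder's `input_names` line falls outside the hunks shown here).

```python
# Hypothetical smoke test for the exported ONNX graphs (assumptions noted above).
import numpy as np
import onnxruntime as ort

enc_sess = ort.InferenceSession("test_encoder.onnx")
dec_sess = ort.InferenceSession("test_decoder.onnx")

# Same example shape as the export script: (batch, 1, samples).
audio = np.random.randn(1, 1, 96000).astype(np.float32)

# Encoder: waveform -> codebook indices; output name "prompt" per the export call.
(indices,) = enc_sess.run(None, {"audio": audio})

# Decoder: indices -> reconstructed waveform; output name "audio" per the export call.
# Assumes the decoder input is named "prompt", matching its dynamic_axes entry.
(restored,) = dec_sess.run(None, {"prompt": indices})

print(indices.shape, restored.shape)
```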