From e908d40b9028a9bec23a152d8becca19b20cb836 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?=
 =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?=
 =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?=
 <40709280+NaruseMioShirakana@users.noreply.github.com>
Date: Fri, 17 Jan 2025 13:10:52 +0800
Subject: [PATCH] add onnx export code for vqgan encoder (#831)

* add onnx export code for vqgan model

add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make return_complex optional

make return_complex optional

* add vqgan encoder export

add vqgan encoder export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 fish_speech/utils/spectrogram.py |  6 ++--
 tools/export-onnx.py             | 55 ++++++++++----------
 2 files changed, 20 insertions(+), 41 deletions(-)

diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py
index 01c3d7a2..19ea435c 100644
--- a/fish_speech/utils/spectrogram.py
+++ b/fish_speech/utils/spectrogram.py
@@ -20,6 +20,7 @@ def __init__(
         self.hop_length = hop_length
         self.center = center
         self.mode = mode
+        self.return_complex = True
 
         self.register_buffer("window", torch.hann_window(win_length), persistent=False)
 
@@ -46,10 +47,11 @@ def forward(self, y: Tensor) -> Tensor:
             pad_mode="reflect",
             normalized=False,
             onesided=True,
-            return_complex=True,
+            return_complex=self.return_complex,
         )
 
-        spec = torch.view_as_real(spec)
+        if self.return_complex:
+            spec = torch.view_as_real(spec)
 
         if self.mode == "pow2_sqrt":
             spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
diff --git a/tools/export-onnx.py b/tools/export-onnx.py
index 1758c17b..2ffab88c 100644
--- a/tools/export-onnx.py
+++ b/tools/export-onnx.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn.functional as F
-from einx import get_at
 
 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from tools.vqgan.extract_vq import get_model
@@ -17,8 +16,11 @@ def __init__(self, model):
     def forward(self, audios):
         mels = self.model.spec_transform(audios)
         encoded_features = self.model.backbone(mels)
-        indices = self.model.quantizer.encode(encoded_features)
-        return indices
+
+        z = self.model.quantizer.downsample(encoded_features)
+        _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
+        _, b, l, _ = indices.shape
+        return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
 
 
 class Decoder(torch.nn.Module):
@@ -30,12 +32,9 @@ def __init__(self, model):
 
     def get_codes_from_indices(self, cur_index, indices):
 
-        batch_size, quantize_dim, q_dim = indices.shape
+        _, quantize_dim, _ = indices.shape
         d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
 
-        # because of quantize dropout, one can pass in indices that are coarse
-        # and the network should be able to reconstruct
-
         if (
             quantize_dim
             < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
@@ -53,26 +52,17 @@
             value=-1,
         )
 
-        # take care of quantizer dropout
-
         mask = indices == -1
-        indices = indices.masked_fill(
-            mask, 0
-        )  # have it fetch a dummy code to be masked out later
+        indices = indices.masked_fill(mask, 0)
 
         all_codes = torch.gather(
            self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
            dim=2,
-            index=indices.long()
-            .permute(2, 0, 1)
-            .unsqueeze(-1)
-            .repeat(1, 1, 1, d_dim),  # q, batch_size, frame, dim
+            index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
        )
 
         all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
 
-        # scale the codes
-
         scales = (
             self.model.quantizer.residual_fsq.rvqs[cur_index]
             .scales.unsqueeze(1)
@@ -80,8 +70,6 @@
         )
         all_codes = all_codes * scales
 
-        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
-
         return all_codes
 
     def get_output_from_indices(self, cur_index, indices):
@@ -112,39 +100,30 @@ def forward(self, indices) -> torch.Tensor:
         return x
 
 
-def main():
-    GanModel = get_model(
-        "firefly_gan_vq",
-        "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-        device="cpu",
-    )
+def main(firefly_gan_vq_path, llama_path, export_prefix):
+    GanModel = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
     enc = Encoder(GanModel)
     dec = Decoder(GanModel)
     audio_example = torch.randn(1, 1, 96000)
     indices = enc(audio_example)
-
-    print(dec(indices).shape)
-
-    """
     torch.onnx.export(
         enc,
         audio_example,
-        "encoder.onnx",
-        dynamic_axes = {
+        f"{export_prefix}encoder.onnx",
+        dynamic_axes={
             "audio": [0, 2],
         },
         do_constant_folding=False,
         opset_version=18,
         verbose=False,
         input_names=["audio"],
-        output_names=["prompt"]
+        output_names=["prompt"],
     )
-    """
 
     torch.onnx.export(
         dec,
         indices,
-        "decoder.onnx",
+        f"{export_prefix}decoder.onnx",
         dynamic_axes={
             "prompt": [0, 2],
         },
@@ -155,8 +134,6 @@
         output_names=["audio"],
     )
 
-    print(enc(audio_example).shape)
-    print(dec(enc(audio_example)).shape)
-
 
-main()
+if __name__ == "__main__":
+    main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
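For reference, a minimal sketch of how the exported graphs could be exercised with onnxruntime. This is not part of the patch: it assumes main() was run with the defaults above (so the files are named test_encoder.onnx and test_decoder.onnx in the working directory) and relies only on the input/output names declared in the torch.onnx.export calls ("audio" and "prompt").

# Illustrative round-trip check for the exported ONNX models.
import numpy as np
import onnxruntime as ort

# Load the exported encoder and decoder (filenames assume export_prefix="test_").
enc_sess = ort.InferenceSession("test_encoder.onnx")
dec_sess = ort.InferenceSession("test_decoder.onnx")

# Dummy mono audio batch: (batch, channel, samples), matching the export example shape.
audio = np.random.randn(1, 1, 96000).astype(np.float32)

# Encoder: raw audio -> codebook indices ("prompt").
(prompt,) = enc_sess.run(["prompt"], {"audio": audio})

# Decoder: codebook indices -> reconstructed audio.
(restored,) = dec_sess.run(["audio"], {"prompt": prompt})

print(prompt.shape, restored.shape)

Because both exports declare dynamic axes on batch and time, the same sessions should accept other audio lengths and prompt lengths without re-exporting.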