From 3d2f842e64c16a58ae3d53779871468ef2977675 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?=
 =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?=
 =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?=
 <40709280+NaruseMioShirakana@users.noreply.github.com>
Date: Fri, 17 Jan 2025 14:06:33 +0800
Subject: [PATCH] fix bugs in onnx tracer (#833)

* add onnx export code for vqgan model

add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make return_complex optional

make return_complex optional

* add vqgan encoder export

add vqgan encoder export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test codes

add test codes

* Fix tracer bugs at padding

Fix tracer bugs at padding

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 fish_speech/models/vqgan/modules/firefly.py | 10 +++++++++-
 tools/export-onnx.py                        | 19 +++++++++++++++----
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py
index 91fc9118..c49612ac 100644
--- a/fish_speech/models/vqgan/modules/firefly.py
+++ b/fish_speech/models/vqgan/modules/firefly.py
@@ -43,7 +43,15 @@ def get_extra_padding_for_conv1d(
     """See `pad_for_conv1d`."""
     length = x.shape[-1]
     n_frames = (length - kernel_size + padding_total) / stride + 1
-    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    # for tracer, math.ceil will make onnx graph become constant
+    if isinstance(n_frames, torch.Tensor):
+        ideal_length = (torch.ceil(n_frames).long() - 1) * stride + (
+            kernel_size - padding_total
+        )
+    else:
+        ideal_length = (math.ceil(n_frames) - 1) * stride + (
+            kernel_size - padding_total
+        )
     return ideal_length - length
 
 
diff --git a/tools/export-onnx.py b/tools/export-onnx.py
index 2ffab88c..046449f6 100644
--- a/tools/export-onnx.py
+++ b/tools/export-onnx.py
@@ -1,3 +1,4 @@
+import onnxruntime
 import torch
 import torch.nn.functional as F
 
@@ -20,7 +21,7 @@ def forward(self, audios):
         z = self.model.quantizer.downsample(encoded_features)
         _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
         _, b, l, _ = indices.shape
-        return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
+        return indices.permute(1, 0, 3, 2).long().view(b, -1, l)
 
 
 class Decoder(torch.nn.Module):
@@ -58,7 +59,7 @@ def get_codes_from_indices(self, cur_index, indices):
         all_codes = torch.gather(
             self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
             dim=2,
-            index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
+            index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
         )
 
         all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
@@ -111,7 +112,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         audio_example,
         f"{export_prefix}encoder.onnx",
         dynamic_axes={
-            "audio": [0, 2],
+            "audio": {0: "batch_size", 2: "audio_length"},
         },
         do_constant_folding=False,
         opset_version=18,
@@ -125,7 +126,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         indices,
         f"{export_prefix}decoder.onnx",
         dynamic_axes={
-            "prompt": [0, 2],
+            "prompt": {0: "batch_size", 2: "frame_count"},
         },
         do_constant_folding=False,
         opset_version=18,
@@ -134,6 +135,16 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         output_names=["audio"],
     )
 
+    test_example = torch.randn(1, 1, 96000 * 5)
+    encoder_session = onnxruntime.InferenceSession(f"{export_prefix}encoder.onnx")
+    decoder_session = onnxruntime.InferenceSession(f"{export_prefix}decoder.onnx")
+
+    # check graph has no error
+    onnx_enc_out = encoder_session.run(["prompt"], {"audio": test_example.numpy()})[0]
+    torch_enc_out = enc(test_example)
+    onnx_dec_out = decoder_session.run(["audio"], {"prompt": onnx_enc_out})[0]
+    torch_dec_out = dec(torch_enc_out)
+
 
 if __name__ == "__main__":
     main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
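
Note on the verification block added at the end of tools/export-onnx.py: it runs the encoder and decoder through both onnxruntime and PyTorch, but it never compares or inspects the results, so it only confirms that the exported graphs execute. The sketch below is not part of the patch; it assumes the script was run with the default "test_" prefix (producing test_encoder.onnx and test_decoder.onnx) and reuses the "audio"/"prompt" input and output names from the export calls above. It feeds two different input lengths to check that the dynamic axes and the torch.ceil padding path really keep the length dimension dynamic.

import numpy as np
import onnxruntime
import torch

# Hypothetical post-export sanity check; file names assume export_prefix="test_".
encoder_session = onnxruntime.InferenceSession("test_encoder.onnx")
decoder_session = onnxruntime.InferenceSession("test_decoder.onnx")

# Two different lengths exercise the dynamic "audio_length" axis that the
# math.ceil -> torch.ceil change is meant to keep symbolic in the traced graph.
for num_samples in (96000, 96000 * 5):
    audio = torch.randn(1, 1, num_samples).numpy()
    prompt = encoder_session.run(["prompt"], {"audio": audio})[0]
    restored = decoder_session.run(["audio"], {"prompt": prompt})[0]

    # The encoder emits integer codebook indices; the decoder returns a waveform.
    assert np.issubdtype(prompt.dtype, np.integer)
    print(f"{num_samples} samples -> codes {prompt.shape}, audio {restored.shape}")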