fix bugs in onnx tracer (#833)

* add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make return_complex optional

* add vqgan encoder export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test codes

* Fix tracer bugs at padding

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
NaruseMioShirakana and pre-commit-ci[bot] authored Jan 17, 2025
1 parent e908d40 commit 3d2f842
Showing 2 changed files with 24 additions and 5 deletions.
10 changes: 9 additions & 1 deletion fish_speech/models/vqgan/modules/firefly.py
@@ -43,7 +43,15 @@ def get_extra_padding_for_conv1d(
     """See `pad_for_conv1d`."""
     length = x.shape[-1]
     n_frames = (length - kernel_size + padding_total) / stride + 1
-    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    # for tracer, math.ceil will make onnx graph become constant
+    if isinstance(n_frames, torch.Tensor):
+        ideal_length = (torch.ceil(n_frames).long() - 1) * stride + (
+            kernel_size - padding_total
+        )
+    else:
+        ideal_length = (math.ceil(n_frames) - 1) * stride + (
+            kernel_size - padding_total
+        )
     return ideal_length - length


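The guard above is the heart of the tracer fix: math.ceil forces the frame count to a plain Python number, which the tracer then records as a fixed constant in the ONNX graph, while torch.ceil keeps a Ceil node so the extra padding still follows the runtime input length. For readers who want the pattern in isolation, here is a minimal sketch; tracer_safe_ceil and extra_padding are illustrative names, not functions from the repository.

import math

import torch


def tracer_safe_ceil(n):
    # math.ceil() would collapse a traced tensor to a baked-in constant;
    # torch.ceil() keeps the rounding symbolic in the exported graph.
    if isinstance(n, torch.Tensor):
        return torch.ceil(n).long()
    return math.ceil(n)


def extra_padding(length, kernel_size, stride, padding_total):
    # Same arithmetic as get_extra_padding_for_conv1d, usable both in eager
    # mode (length is an int) and under the ONNX tracer (length is a tensor).
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (tracer_safe_ceil(n_frames) - 1) * stride + (
        kernel_size - padding_total
    )
    return ideal_length - length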
19 changes: 15 additions & 4 deletions tools/export-onnx.py
@@ -1,3 +1,4 @@
+import onnxruntime
 import torch
 import torch.nn.functional as F

@@ -20,7 +21,7 @@ def forward(self, audios):
         z = self.model.quantizer.downsample(encoded_features)
         _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
         _, b, l, _ = indices.shape
-        return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
+        return indices.permute(1, 0, 3, 2).long().view(b, -1, l)


class Decoder(torch.nn.Module):
@@ -58,7 +59,7 @@ def get_codes_from_indices(self, cur_index, indices):
         all_codes = torch.gather(
             self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
             dim=2,
-            index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
+            index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
         )

         all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
@@ -111,7 +112,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         audio_example,
         f"{export_prefix}encoder.onnx",
         dynamic_axes={
-            "audio": [0, 2],
+            "audio": {0: "batch_size", 2: "audio_length"},
         },
         do_constant_folding=False,
         opset_version=18,
@@ -125,7 +126,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         indices,
         f"{export_prefix}decoder.onnx",
         dynamic_axes={
-            "prompt": [0, 2],
+            "prompt": {0: "batch_size", 2: "frame_count"},
         },
         do_constant_folding=False,
         opset_version=18,
@@ -134,6 +135,16 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         output_names=["audio"],
     )

+    test_example = torch.randn(1, 1, 96000 * 5)
+    encoder_session = onnxruntime.InferenceSession(f"{export_prefix}encoder.onnx")
+    decoder_session = onnxruntime.InferenceSession(f"{export_prefix}decoder.onnx")
+
+    # check graph has no error
+    onnx_enc_out = encoder_session.run(["prompt"], {"audio": test_example.numpy()})[0]
+    torch_enc_out = enc(test_example)
+    onnx_dec_out = decoder_session.run(["audio"], {"prompt": onnx_enc_out})[0]
+    torch_dec_out = dec(torch_enc_out)
+

 if __name__ == "__main__":
     main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
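Because the padding arithmetic now stays symbolic, the axis names introduced in dynamic_axes ("batch_size"/"audio_length" for the encoder, "batch_size"/"frame_count" for the decoder) should survive into the exported graphs. A hypothetical spot check, not part of the commit, assuming main was run with the "test_" prefix as above:

import onnx

# Load the encoder exported by main(..., "test_") and list its input shape.
encoder_model = onnx.load("test_encoder.onnx")
dims = encoder_model.graph.input[0].type.tensor_type.shape.dim
# dim_param carries a symbolic name set via dynamic_axes; dim_value carries a
# fixed size. Expect something along the lines of ['batch_size', 1, 'audio_length'].
print([d.dim_param or d.dim_value for d in dims])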
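The new block at the end of main only runs both ONNX sessions next to the PyTorch modules; it does not yet assert that the results agree. One way the check could be tightened, sketched here under the assumption that the encoder emits integer codes and the decoder emits float audio (tolerances are illustrative):

import numpy as np

# Codes are discrete indices, so ONNX and PyTorch should match exactly.
np.testing.assert_array_equal(onnx_enc_out, torch_enc_out.detach().numpy())

# Audio is floating point; allow a little per-operator numerical drift.
np.testing.assert_allclose(
    onnx_dec_out,
    torch_dec_out.detach().numpy(),
    rtol=1e-3,
    atol=1e-4,
)
print("ONNX outputs match the PyTorch modules within tolerance")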
