
Commit

add onnx export code for vqgan encoder (#831)
* add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make return_complex optional

* add vqgan encoder export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
NaruseMioShirakana and pre-commit-ci[bot] authored Jan 17, 2025
1 parent 18965de commit e908d40
Showing 2 changed files with 20 additions and 41 deletions.
6 changes: 4 additions & 2 deletions fish_speech/utils/spectrogram.py
@@ -20,6 +20,7 @@ def __init__(
self.hop_length = hop_length
self.center = center
self.mode = mode
+ self.return_complex = True

self.register_buffer("window", torch.hann_window(win_length), persistent=False)

@@ -46,10 +47,11 @@ def forward(self, y: Tensor) -> Tensor:
pad_mode="reflect",
normalized=False,
onesided=True,
- return_complex=True,
+ return_complex=self.return_complex,
)

- spec = torch.view_as_real(spec)
+ if self.return_complex:
+     spec = torch.view_as_real(spec)

if self.mode == "pow2_sqrt":
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
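Note: the return_complex switch above exists so an export script can flip torch.stft into its real-valued output mode, since torch.onnx.export does not support complex-valued tensors; when the flag is False the (real, imag) pair is already the last dimension, so view_as_real is skipped. A minimal standalone sketch of that pattern follows; the function name and FFT sizes are placeholders, not the repository's class.

import torch


def spectrogram_sketch(y, n_fft=2048, hop_length=512, win_length=2048, return_complex=True):
    # Placeholder re-implementation of the forward pass patched above.
    window = torch.hann_window(win_length)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=return_complex,
    )
    if return_complex:
        # complex output -> expand into a trailing (real, imag) dimension
        spec = torch.view_as_real(spec)
    # with return_complex=False, torch.stft already returns that layout
    # "pow2_sqrt" mode: magnitude spectrogram
    return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

An export script could set the module's return_complex attribute to False before tracing, which is presumably why the flag was made optional here.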
55 changes: 16 additions & 39 deletions tools/export-onnx.py
@@ -1,6 +1,5 @@
import torch
import torch.nn.functional as F
- from einx import get_at

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from tools.vqgan.extract_vq import get_model
@@ -17,8 +16,11 @@ def __init__(self, model):
def forward(self, audios):
mels = self.model.spec_transform(audios)
encoded_features = self.model.backbone(mels)
- indices = self.model.quantizer.encode(encoded_features)
- return indices

+ z = self.model.quantizer.downsample(encoded_features)
+ _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
+ _, b, l, _ = indices.shape
+ return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
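Note: the rewritten Encoder.forward skips quantizer.encode and reads the raw residual_fsq indices, then flattens the group and quantizer axes into a single (batch, codebooks, frames) tensor so the traced graph emits one plain integer output. A toy check of that reshape, assuming indices comes back as (groups, batch, frames, quantizers-per-group), which is an inference from the permute above rather than something this diff states:

import torch

g, b, l, r = 2, 1, 21, 4                       # toy sizes: groups, batch, frames, quantizers
indices = torch.randint(0, 1024, (g, b, l, r))

flat = indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
print(flat.shape)  # torch.Size([1, 8, 21]) -> one row per codebook, one column per frame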


class Decoder(torch.nn.Module):
@@ -30,12 +32,9 @@ def __init__(self, model):

def get_codes_from_indices(self, cur_index, indices):

- batch_size, quantize_dim, q_dim = indices.shape
+ _, quantize_dim, _ = indices.shape
d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]

- # because of quantize dropout, one can pass in indices that are coarse
- # and the network should be able to reconstruct

if (
quantize_dim
< self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
@@ -53,35 +52,24 @@ def get_codes_from_indices(self, cur_index, indices):
value=-1,
)

- # take care of quantizer dropout

mask = indices == -1
- indices = indices.masked_fill(
- mask, 0
- ) # have it fetch a dummy code to be masked out later
+ indices = indices.masked_fill(mask, 0)

all_codes = torch.gather(
self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
dim=2,
- index=indices.long()
- .permute(2, 0, 1)
- .unsqueeze(-1)
- .repeat(1, 1, 1, d_dim),  # q, batch_size, frame, dim
+ index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
)

all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)

- # scale the codes

scales = (
self.model.quantizer.residual_fsq.rvqs[cur_index]
.scales.unsqueeze(1)
.unsqueeze(1)
)
all_codes = all_codes * scales

- # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

return all_codes
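Note: get_codes_from_indices fetches code vectors with a plain torch.gather over the codebooks buffer (consistent with dropping the einx import above), which keeps the exported graph to standard ops. A self-contained toy version of that lookup; the sizes are invented, and batch size 1 matches the export example:

import torch

num_q, codebook_size, dim = 4, 1024, 32                     # invented sizes
codebooks = torch.randn(num_q, codebook_size, dim)          # (q, c, d)
indices = torch.randint(0, codebook_size, (1, 21, num_q))   # (b, l, q), batch of 1

codes = torch.gather(
    codebooks.unsqueeze(1),                                  # (q, 1, c, d)
    dim=2,
    index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, dim),  # (q, b, l, d)
)
print(codes.shape)  # torch.Size([4, 1, 21, 32]) -> one code vector per quantizer and frame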

def get_output_from_indices(self, cur_index, indices):
@@ -112,39 +100,30 @@ def forward(self, indices) -> torch.Tensor:
return x


- def main():
- GanModel = get_model(
- "firefly_gan_vq",
- "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
- device="cpu",
- )
+ def main(firefly_gan_vq_path, llama_path, export_prefix):
+ GanModel = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
enc = Encoder(GanModel)
dec = Decoder(GanModel)
audio_example = torch.randn(1, 1, 96000)
indices = enc(audio_example)

- print(dec(indices).shape)

- """
torch.onnx.export(
enc,
audio_example,
- "encoder.onnx",
- dynamic_axes = {
+ f"{export_prefix}encoder.onnx",
+ dynamic_axes={
"audio": [0, 2],
},
do_constant_folding=False,
opset_version=18,
verbose=False,
input_names=["audio"],
- output_names=["prompt"]
+ output_names=["prompt"],
)
- """

torch.onnx.export(
dec,
indices,
- "decoder.onnx",
+ f"{export_prefix}decoder.onnx",
dynamic_axes={
"prompt": [0, 2],
},
@@ -155,8 +134,6 @@ def main():
output_names=["audio"],
)

- print(enc(audio_example).shape)
- print(dec(enc(audio_example)).shape)


- main()
+ if __name__ == "__main__":
+     main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
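Note: with the arguments passed here, the decoder graph lands in test_decoder.onnx. A hedged smoke test with onnxruntime follows; the codebook count of 8 and the int64 index dtype are assumptions about the traced example input, not something this commit fixes.

import numpy as np
import onnxruntime as ort

# Assumed file name from export_prefix="test_".
session = ort.InferenceSession("test_decoder.onnx")

# "prompt" was exported with dynamic batch (axis 0) and frame count (axis 2);
# axis 1 (number of codebooks) stays fixed at trace time.
prompt = np.random.randint(0, 1024, size=(1, 8, 21), dtype=np.int64)
(audio,) = session.run(["audio"], {"prompt": prompt})
print(audio.shape)  # roughly (batch, 1, samples)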
