Skip to content

Commit

Permalink
add onnx export code for vqgan model (#830)
Browse files Browse the repository at this point in the history
* add onnx export code for vqgan model

add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
NaruseMioShirakana and pre-commit-ci[bot] authored Jan 16, 2025
1 parent fc289b7 commit 18965de
Showing 1 changed file with 162 additions and 0 deletions.
162 changes: 162 additions & 0 deletions tools/export-onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import torch
import torch.nn.functional as F
from einx import get_at

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from tools.vqgan.extract_vq import get_model

PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])


class Encoder(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.model.spec_transform.spectrogram.return_complex = False

def forward(self, audios):
mels = self.model.spec_transform(audios)
encoded_features = self.model.backbone(mels)
indices = self.model.quantizer.encode(encoded_features)
return indices


class Decoder(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.model.head.training = False
self.model.head.checkpointing = False

def get_codes_from_indices(self, cur_index, indices):

batch_size, quantize_dim, q_dim = indices.shape
d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]

# because of quantize dropout, one can pass in indices that are coarse
# and the network should be able to reconstruct

if (
quantize_dim
< self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
):
assert (
self.model.quantizer.residual_fsq.rvqs[cur_index].quantize_dropout > 0.0
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
indices = F.pad(
indices,
(
0,
self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
- quantize_dim,
),
value=-1,
)

# take care of quantizer dropout

mask = indices == -1
indices = indices.masked_fill(
mask, 0
) # have it fetch a dummy code to be masked out later

all_codes = torch.gather(
self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
dim=2,
index=indices.long()
.permute(2, 0, 1)
.unsqueeze(-1)
.repeat(1, 1, 1, d_dim), # q, batch_size, frame, dim
)

all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)

# scale the codes

scales = (
self.model.quantizer.residual_fsq.rvqs[cur_index]
.scales.unsqueeze(1)
.unsqueeze(1)
)
all_codes = all_codes * scales

# if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

return all_codes

def get_output_from_indices(self, cur_index, indices):
codes = self.get_codes_from_indices(cur_index, indices)
codes_summed = codes.sum(dim=0)
return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
codes_summed
)

def forward(self, indices) -> torch.Tensor:
batch_size, _, length = indices.shape
dims = self.model.quantizer.residual_fsq.dim
groups = self.model.quantizer.residual_fsq.groups
dim_per_group = dims // groups

# indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)

# z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
z_q = torch.empty((batch_size, length, dims))
for i in range(groups):
z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
self.get_output_from_indices(i, indices[i])
)

z = self.model.quantizer.upsample(z_q.transpose(1, 2))
x = self.model.head(z)
return x


def main():
GanModel = get_model(
"firefly_gan_vq",
"checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
device="cpu",
)
enc = Encoder(GanModel)
dec = Decoder(GanModel)
audio_example = torch.randn(1, 1, 96000)
indices = enc(audio_example)

print(dec(indices).shape)

"""
torch.onnx.export(
enc,
audio_example,
"encoder.onnx",
dynamic_axes = {
"audio": [0, 2],
},
do_constant_folding=False,
opset_version=18,
verbose=False,
input_names=["audio"],
output_names=["prompt"]
)
"""

torch.onnx.export(
dec,
indices,
"decoder.onnx",
dynamic_axes={
"prompt": [0, 2],
},
do_constant_folding=False,
opset_version=18,
verbose=False,
input_names=["prompt"],
output_names=["audio"],
)

print(enc(audio_example).shape)
print(dec(enc(audio_example)).shape)


main()

0 comments on commit 18965de

Please sign in to comment.