diff --git a/tools/export-onnx.py b/tools/export-onnx.py
new file mode 100644
index 00000000..1758c17b
--- /dev/null
+++ b/tools/export-onnx.py
@@ -0,0 +1,162 @@
+import torch
+import torch.nn.functional as F
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from tools.vqgan.extract_vq import get_model
+
+PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])
+
+
+class Encoder(torch.nn.Module):
+    """Audio -> codebook indices."""
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        # return real-valued spectrograms so the exported graph avoids complex tensors
+        self.model.spec_transform.spectrogram.return_complex = False
+
+    def forward(self, audios):
+        mels = self.model.spec_transform(audios)
+        encoded_features = self.model.backbone(mels)
+        indices = self.model.quantizer.encode(encoded_features)
+        return indices
+
+
+class Decoder(torch.nn.Module):
+    """Codebook indices -> waveform, with the grouped residual FSQ code lookup
+    re-implemented with plain torch ops so it can be traced for ONNX export."""
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        # disable training-only behaviour and gradient checkpointing in the head before tracing
+        self.model.head.training = False
+        self.model.head.checkpointing = False
+
+    def get_codes_from_indices(self, cur_index, indices):
+        # indices: (batch, frame, num_quantizers) for group `cur_index`
+        batch_size, _, quantize_dim = indices.shape
+        rvq = self.model.quantizer.residual_fsq.rvqs[cur_index]
+        d_dim = rvq.codebooks.shape[2]
+
+        # because of quantize dropout, one can pass in indices that are coarse
+        # and the network should be able to reconstruct
+        if quantize_dim < rvq.num_quantizers:
+            assert (
+                rvq.quantize_dropout > 0.0
+            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            indices = F.pad(
+                indices,
+                (0, rvq.num_quantizers - quantize_dim),
+                value=-1,
+            )
+
+        # take care of quantizer dropout
+        mask = indices == -1
+        # have it fetch a dummy code to be masked out later
+        indices = indices.masked_fill(mask, 0)
+
+        # look up the code vectors; the index tensor is laid out as (quantize, batch, frame, dim)
+        all_codes = torch.gather(
+            rvq.codebooks.unsqueeze(1).expand(-1, batch_size, -1, -1),
+            dim=2,
+            index=indices.long()
+            .permute(2, 0, 1)
+            .unsqueeze(-1)
+            .repeat(1, 1, 1, d_dim),
+        )
+
+        # zero out the codes that came from dropout padding
+        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
+
+        # scale the codes
+        scales = rvq.scales.unsqueeze(1).unsqueeze(1)
+        all_codes = all_codes * scales
+
+        # returned shape: (quantize, batch, frame, dimension)
+        return all_codes
+
+    def get_output_from_indices(self, cur_index, indices):
+        codes = self.get_codes_from_indices(cur_index, indices)
+        codes_summed = codes.sum(dim=0)
+        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
+            codes_summed
+        )
+
+    def forward(self, indices) -> torch.Tensor:
+        batch_size, _, length = indices.shape
+        groups = self.model.quantizer.residual_fsq.groups
+
+        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
+        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)
+
+        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
+        # concatenate the per-group outputs rather than writing into a pre-allocated
+        # torch.empty buffer, so tracing does not bake static shapes into the graph
+        z_q = torch.cat(
+            [self.get_output_from_indices(i, indices[i]) for i in range(groups)],
+            dim=-1,
+        )
+
+        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
+        x = self.model.head(z)
+        return x
+
+
+def main():
+    GanModel = get_model(
+        "firefly_gan_vq",
"checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + device="cpu", + ) + enc = Encoder(GanModel) + dec = Decoder(GanModel) + audio_example = torch.randn(1, 1, 96000) + indices = enc(audio_example) + + print(dec(indices).shape) + + """ + torch.onnx.export( + enc, + audio_example, + "encoder.onnx", + dynamic_axes = { + "audio": [0, 2], + }, + do_constant_folding=False, + opset_version=18, + verbose=False, + input_names=["audio"], + output_names=["prompt"] + ) + """ + + torch.onnx.export( + dec, + indices, + "decoder.onnx", + dynamic_axes={ + "prompt": [0, 2], + }, + do_constant_folding=False, + opset_version=18, + verbose=False, + input_names=["prompt"], + output_names=["audio"], + ) + + print(enc(audio_example).shape) + print(dec(enc(audio_example)).shape) + + +main()