add onnx export code for vqgan model (#830)
* add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fc289b7 · commit 18965de
1 changed file with 162 additions and 0 deletions.
@@ -0,0 +1,162 @@
import torch
import torch.nn.functional as F

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from tools.vqgan.extract_vq import get_model

PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])

class Encoder(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # the spectrogram must return a real-valued tensor for the graph to be exportable
        self.model.spec_transform.spectrogram.return_complex = False

    def forward(self, audios):
        # audio (batch, 1, samples) -> mel spectrogram -> features -> codebook indices
        mels = self.model.spec_transform(audios)
        encoded_features = self.model.backbone(mels)
        indices = self.model.quantizer.encode(encoded_features)
        return indices

class Decoder(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # disable training-only behavior and activation checkpointing in the head
        # so the traced graph is a plain inference path
        self.model.head.training = False
        self.model.head.checkpointing = False

    def get_codes_from_indices(self, cur_index, indices):
        # indices: (batch, frames, quantizers) for one group of the residual FSQ
        batch_size, frames, quantize_dim = indices.shape
        d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]

        # because of quantize dropout, one can pass in indices that are coarse,
        # and the network should still be able to reconstruct
        if (
            quantize_dim
            < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
        ):
            assert (
                self.model.quantizer.residual_fsq.rvqs[cur_index].quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(
                indices,
                (
                    0,
                    self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
                    - quantize_dim,
                ),
                value=-1,
            )

        # take care of quantizer dropout
        mask = indices == -1
        indices = indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = torch.gather(
            self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
            dim=2,
            index=indices.long()
            .permute(2, 0, 1)
            .unsqueeze(-1)
            .repeat(1, 1, 1, d_dim),  # (quantizers, batch, frames, dim)
        )

        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)

        # scale the codes
        scales = (
            self.model.quantizer.residual_fsq.rvqs[cur_index]
            .scales.unsqueeze(1)
            .unsqueeze(1)
        )
        all_codes = all_codes * scales

        # returned shape is (quantizers, batch, frames, dimension)
        return all_codes

    def get_output_from_indices(self, cur_index, indices):
        codes = self.get_codes_from_indices(cur_index, indices)
        codes_summed = codes.sum(dim=0)
        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
            codes_summed
        )

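    # forward() below reconstructs the full latent: each group of the grouped
    # residual FSQ is decoded independently via get_output_from_indices, and the
    # per-group outputs are written into adjacent slices of the feature
    # dimension before the upsampler and the vocoder head.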
    def forward(self, indices) -> torch.Tensor:
        batch_size, _, length = indices.shape
        dims = self.model.quantizer.residual_fsq.dim
        groups = self.model.quantizer.residual_fsq.groups
        dim_per_group = dims // groups

        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)

        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
        z_q = torch.empty((batch_size, length, dims))
        for i in range(groups):
            z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
                self.get_output_from_indices(i, indices[i])
            )

        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
        x = self.model.head(z)
        return x

def main():
    gan_model = get_model(
        "firefly_gan_vq",
        "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        device="cpu",
    )
    enc = Encoder(gan_model)
    dec = Decoder(gan_model)
    audio_example = torch.randn(1, 1, 96000)
    indices = enc(audio_example)

    print(dec(indices).shape)

    # the encoder export is kept disabled in this commit:
    """
    torch.onnx.export(
        enc,
        audio_example,
        "encoder.onnx",
        dynamic_axes={
            "audio": [0, 2],
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["audio"],
        output_names=["prompt"],
    )
    """

    torch.onnx.export(
        dec,
        indices,
        "decoder.onnx",
        dynamic_axes={
            "prompt": [0, 2],  # batch and frame axes are dynamic
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["prompt"],
        output_names=["audio"],
    )

    print(enc(audio_example).shape)
    print(dec(enc(audio_example)).shape)


if __name__ == "__main__":
    main()
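For a quick sanity check of the exported decoder, a minimal sketch along these lines should work. It assumes onnxruntime and numpy are installed and that decoder.onnx was produced by the script above; the (1, 8, 21) prompt shape is an illustrative guess at one second of 8-codebook, 21 Hz codes, not a value taken from the commit.

import numpy as np
import onnxruntime as ort  # pip install onnxruntime

# Load the exported decoder graph.
sess = ort.InferenceSession("decoder.onnx")

# Fabricate a prompt of shape (batch, codebooks, frames); values must lie
# inside the codebook range (1024 entries for this checkpoint), and the
# dtype must be int64 to match the traced indices.
prompt = np.random.randint(0, 1024, size=(1, 8, 21), dtype=np.int64)

# Input/output names follow the export call above.
(audio,) = sess.run(["audio"], {"prompt": prompt})
print(audio.shape)  # expected (1, 1, samples)

Because the export declares axes 0 and 2 of "prompt" as dynamic, the same session should also accept prompts with other batch sizes and frame counts.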