-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathexport.py
91 lines (75 loc) · 2.44 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import os.path
import torch
from ddsp.vocoder import load_model
class DDSPWrapper(torch.nn.Module):
    """Thin wrapper that adapts a DDSP vocoder's interface for export.

    Appends a trailing singleton dimension to ``f0`` before the forward
    call and flattens the wrapped model's nested return value into a
    plain ``(signal, s_h, s_n)`` tuple.
    """

    def __init__(self, module, device):
        super().__init__()
        # Hold the wrapped vocoder and move all parameters to `device`.
        self.model = module
        self.to(device)

    def forward(self, mel, f0):
        # (..., T) -> (..., T, 1): give f0 an explicit last dimension,
        # matching what the wrapped model is called with elsewhere in this file.
        f0 = f0.unsqueeze(-1)
        signal, _, (s_h, s_n) = self.model(mel, f0)
        return signal, s_h, s_n
def parse_args(args=None, namespace=None):
    """Parse and validate the export CLI options.

    At least one of ``--traced`` / ``--onnx`` must be supplied; otherwise
    the parser exits with a usage error (SystemExit).
    """
    parser = argparse.ArgumentParser(
        description='Export model to standalone PyTorch traced module or ONNX format'
    )
    parser.add_argument(
        '-m', '--model_path',
        type=str,
        required=True,
        help='path to model file',
    )
    parser.add_argument(
        '--traced',
        action='store_true',
        required=False,
        help='export to traced module format',
    )
    parser.add_argument(
        '--onnx',
        action='store_true',
        required=False,
        help='export to ONNX format',
    )
    parsed = parser.parse_args(args=args, namespace=namespace)
    # An export run with no target format makes no sense — bail out early.
    if not (parsed.traced or parsed.onnx):
        parser.error('either --traced or --onnx should be specified.')
    return parsed
def main():
    """Entry point: load a DDSP checkpoint and export it for deployment.

    Only TorchScript tracing (``--traced``) is implemented; requesting
    ``--onnx`` raises ``NotImplementedError``.
    """
    device = 'cpu'

    # Read the command line first so bad invocations fail fast.
    cmd = parse_args()

    # Restore the model and its training configuration from the checkpoint.
    model, args = load_model(cmd.model_path, device=device)
    #model = DDSPWrapper(model, device)

    # Output directory and base name are derived from the checkpoint path.
    directory = os.path.dirname(os.path.abspath(cmd.model_path))
    name = os.path.basename(cmd.model_path).rsplit('.', maxsplit=1)[0]

    # Dummy tracing inputs: a random mel batch plus a constant 440 Hz f0
    # track, unsqueezed to (1, n_frames, 1).
    n_mel_channels = args.data.n_mels
    n_frames = 10
    mel = torch.randn((1, n_frames, n_mel_channels), dtype=torch.float32, device=device)
    f0 = torch.FloatTensor([[440.] * n_frames]).to(device)
    f0 = f0[..., None]

    with torch.no_grad():
        if cmd.traced:
            # Tag the artifact with the torch version, dropping any local
            # build suffix such as "+cu118".
            torch_version = torch.version.__version__.rsplit('+', maxsplit=1)[0]
            export_path = os.path.join(directory, f'{name}-traced-torch{torch_version}.jit')
            print(f' [Tracing] {cmd.model_path} => {export_path}')
            traced = torch.jit.trace(model, (mel, f0), check_trace=False)
            torch.jit.save(traced, export_path)
        if cmd.onnx:
            raise NotImplementedError('Exporting to ONNX format is not supported yet.')
# Allow running this module directly as a script.
if __name__ == '__main__':
    main()