forked from qiuqiao/SOFA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer.py
117 lines (104 loc) · 3.08 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pathlib
import click
import lightning as pl
import torch
import modules.AP_detector
import modules.g2p
from modules.utils.export_tool import Exporter
from modules.utils.post_processing import post_processing
from train import LitForcedAlignmentTask
@click.command()
@click.option(
"--ckpt",
"-c",
default=None,
required=True,
type=str,
help="path to the checkpoint",
)
@click.option(
"--folder", "-f", default="segments", type=str, help="path to the input folder"
)
@click.option(
"--mode", "-m", default="force", type=click.Choice(["force", "match"])
) # TODO: add asr mode
@click.option(
"--g2p", "-g", default="Dictionary", type=str, help="name of the g2p class"
)
@click.option(
"--ap_detector",
"-a",
default="LoudnessSpectralcentroidAPDetector", # "NoneAPDetector",
type=str,
help="name of the AP detector class",
)
@click.option(
"--in_format",
"-if",
default="lab",
required=False,
type=str,
help="File extension of input transcriptions. Default: lab",
)
@click.option(
"--out_formats",
"-of",
default="textgrid,htk,trans",
required=False,
type=str,
help="Types of output file, separated by comma. Supported types:"
"textgrid(praat),"
" htk(lab,nnsvs,sinsy),"
" transcriptions.csv(diffsinger,trans,transcription,transcriptions)",
)
@click.option(
"--save_confidence",
"-sc",
is_flag=True,
default=False,
show_default=True,
help="save confidence.csv",
)
@click.option(
"--dictionary",
"-d",
default="dictionary/opencpop-extension.txt",
type=str,
help="(only used when --g2p=='Dictionary') path to the dictionary",
)
def main(
ckpt,
folder,
mode,
g2p,
ap_detector,
in_format,
out_formats,
save_confidence,
**kwargs,
):
if not g2p.endswith("G2P"):
g2p += "G2P"
g2p_class = getattr(modules.g2p, g2p)
grapheme_to_phoneme = g2p_class(**kwargs)
out_formats = [i.strip().lower() for i in out_formats.split(",")]
if not ap_detector.endswith("APDetector"):
ap_detector += "APDetector"
AP_detector_class = getattr(modules.AP_detector, ap_detector)
get_AP = AP_detector_class(**kwargs)
grapheme_to_phoneme.set_in_format(in_format)
dataset = grapheme_to_phoneme.get_dataset(pathlib.Path(folder).rglob("*.wav"))
torch.set_grad_enabled(False)
model = LitForcedAlignmentTask.load_from_checkpoint(ckpt)
model.set_inference_mode(mode)
trainer = pl.Trainer(logger=False)
predictions = trainer.predict(model, dataloaders=dataset, return_predictions=True)
predictions = get_AP.process(predictions)
predictions, log = post_processing(predictions)
exporter = Exporter(predictions, log)
if save_confidence:
out_formats.append('confidence')
exporter.export(out_formats)
print("Output files are saved to the same folder as the input wav files.")
if __name__ == "__main__":
main()