-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli.py
142 lines (109 loc) · 3.94 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import datetime
from pathlib import Path
import click
from chorus import train as c_train
from chorus import pipelines
from chorus import infer
from chorus.config import SAMPLE_RATE
@click.group()
def cli():
    """Chorus command-line tools: download data, train models, and run them on audio."""
    # Root click group. The docstring above is shown by `--help` because no
    # explicit help= is given. Subgroups (data, train, run) attach below.
@cli.group(help="subcommands to download data")
def data():
    """Container group for the data download/preparation subcommands.

    The group does no work itself; click shows the explicit ``help=`` text,
    so this docstring is for readers of the source only.
    """
@data.command(name="xc-meta", help="download xeno-canto meta data")
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help="Show progress bar.",
)
def xeno_meta(verbose: bool):
    """Fetch xeno-canto metadata via ``pipelines.save_all_xeno_canto_meta``.

    ``verbose`` (the -v flag) is forwarded unchanged.
    """
    show_progress = verbose
    pipelines.save_all_xeno_canto_meta(show_progress)
@data.command(name="xc-audio", help="download xeno-canto audio data")
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help="Show progress bar.",
)
@click.option(
    "--redownload",
    is_flag=True,
    help="Redownload all files.",
)
def xeno_audio(verbose: bool, redownload: bool):
    """Fetch xeno-canto audio via ``pipelines.save_all_xeno_canto_audio``.

    The pipeline takes a ``skip_existing`` flag, which is the inverse of the
    user-facing --redownload switch.
    """
    skip_existing = not redownload
    pipelines.save_all_xeno_canto_audio(verbose, skip_existing=skip_existing)
@data.command(name="xc-to-npy", help="convert audio to .npy data at [SAMPLERATE]")
@click.option(
    "--samplerate",
    type=int,
    default=SAMPLE_RATE,
    help="Edit chorus/config.py if you change this value!",
)
@click.option(
    "--reprocess",
    is_flag=True,
    help="Reprocess all audio.",
)
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help="Show progress bar.",
)
def xeno_to_numpy(samplerate: int, reprocess: bool, verbose: bool):
    """Convert xeno-canto audio to .npy via ``pipelines.convert_to_numpy``.

    ``samplerate`` defaults to ``SAMPLE_RATE`` from chorus.config.
    """
    # The third positional argument is ``not reprocess``. It is presumably a
    # "skip existing outputs" flag, mirroring xc-audio's skip_existing;
    # confirm against pipelines.convert_to_numpy.
    skip_existing = not reprocess
    pipelines.convert_to_numpy(samplerate, verbose, skip_existing)
@data.command(name="range-meta", help="download range map meta data")
def range_meta():
    """Fetch range-map metadata via ``pipelines.save_range_map_meta`` (no options)."""
    pipelines.save_range_map_meta()
@data.command(name="range-map", help="download range map data")
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help="Show progress bar.",
)
def range_maps(verbose: bool):
    """Fetch range-map data via ``pipelines.save_range_maps``.

    Note that the CLI name is ``range-map`` (singular) although the function
    is ``range_maps``.
    """
    show_progress = verbose
    pipelines.save_range_maps(show_progress)
@data.command(name="background", help="download background audio files")
@click.option(
    "--samplerate",
    type=int,
    default=SAMPLE_RATE,
    help="Edit chorus/config.py if you change this value!",
)
def background(samplerate: int):
    """Fetch background audio via ``pipelines.save_background_sounds``.

    ``samplerate`` defaults to ``SAMPLE_RATE`` from chorus.config.
    """
    target_rate = samplerate
    pipelines.save_background_sounds(target_rate)
@cli.group(help="subcommands to train models")
def train():
    """Container group for the model training/export subcommands.

    This click group takes the name ``train``. That is why the
    ``chorus.train`` module is imported under the alias ``c_train``.
    """
@train.command(
    name="classifier",
    help="train the classifier model",
)
@click.argument("name", type=str)
def train_classifier(name: str):
    """Forward NAME to ``chorus.train.train_classifier``."""
    run_name = name
    c_train.train_classifier(run_name)
@train.command(
    name="isolator",
    help="train the isolator model",
)
@click.argument("name", type=str)
@click.argument("classifier_filepath", type=str)
def train_isolator(name: str, classifier_filepath: str):
    """Forward NAME and CLASSIFIER_FILEPATH to ``chorus.train.train_isolator``.

    CLASSIFIER_FILEPATH is passed as a plain string, not a Path.
    """
    run_name, classifier_path = name, classifier_filepath
    c_train.train_isolator(run_name, classifier_path)
@train.command(
    name="export-classifier",
    help="export a classifier model as an optimized torchscript module",
)
@click.argument("model_in_path", type=Path)
@click.argument("model_out_path", type=Path)
def export_classifier(model_in_path: Path, model_out_path: Path):
    """Forward both paths to ``chorus.train.export_jitted_classifier``.

    Both arguments are converted to ``pathlib.Path`` by click before the call.
    """
    source, destination = model_in_path, model_out_path
    c_train.export_jitted_classifier(source, destination)
@cli.group(help="run models on audio file")
def run():
    """Container group for subcommands that run trained models on audio.

    The group does no work itself; click shows the explicit ``help=`` text.
    """
def _parse_latlng(ctx, param, value):
    """Click callback: turn ``--latlng`` text into a ``[lat, lng]`` float list.

    Returns None when the option was not given. Raises click.BadParameter,
    which click reports as a usage error naming --latlng, in three cases:
    the text is not numeric, it does not contain exactly two comma-separated
    values, or the values fall outside latitude [-90, 90] / longitude
    [-180, 180].
    """
    if value is None:
        return None
    try:
        coords = [float(part) for part in value.split(",")]
    except ValueError:
        raise click.BadParameter(
            f"expected two comma-separated numbers 'lat,lng', got {value!r}"
        ) from None
    if len(coords) != 2:
        raise click.BadParameter(
            f"expected exactly two values 'lat,lng', got {len(coords)}"
        )
    lat, lng = coords
    # Chained comparisons are False for NaN, so "nan" is rejected here too.
    if not (-90.0 <= lat <= 90.0 and -180.0 <= lng <= 180.0):
        raise click.BadParameter(
            "coordinates out of range "
            f"(lat in [-90, 90], lng in [-180, 180]): {value!r}"
        )
    return coords


def _parse_date(ctx, param, value):
    """Click callback: turn ``--date`` text into a ``datetime.datetime``.

    Returns None when the option was not given. Parses with
    ``datetime.datetime.fromisoformat``. Unparseable text raises
    click.BadParameter, so the user gets a usage error instead of a traceback.
    """
    if value is None:
        return None
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        raise click.BadParameter(
            f"expected a date as YYYY-MM-DD, got {value!r}"
        ) from None


@run.command(help="run classifier on audio file", name="classifier")
@click.argument("modelpath", type=Path)
@click.argument("audiofile", type=Path)
@click.option(
    "--latlng",
    default=None,
    callback=_parse_latlng,
    help="comma-separated lat,lng coordinates",
)
@click.option(
    "--date",
    default=None,
    callback=_parse_date,
    help="date of recording as YYYY-MM-DD",
)
@click.option(
    "--top-n",
    # A negative n would slice off the *end* of the ranking instead of
    # limiting it, so only non-negative values are accepted.
    type=click.IntRange(min=0),
    default=5,
    help="Show top n predictions",
    show_default=True,
)
@click.option(
    "--scientific", is_flag=True, help="Display results using scientific names"
)
def run_classifier(
    modelpath: Path,
    audiofile: Path,
    latlng,
    date,
    top_n: int,
    scientific: bool,
):
    """Run classifier located at MODELPATH on AUDIOFILE.

    By the time this runs, the option callbacks have converted the inputs:
    ``latlng`` is ``[lat, lng]`` floats or None, and ``date`` is a
    ``datetime.datetime`` or None. Both are passed to
    ``infer.run_classifier``. Prints at most ``top_n`` labels, highest score
    first, one per line.
    """
    preds = infer.run_classifier(
        modelpath, audiofile, latlng, date, scientific=scientific
    )
    # preds is indexed by label and its values are formatted with :.3f.
    # Rank the labels by their score.
    ranked = sorted(preds, key=preds.__getitem__, reverse=True)
    for label in ranked[:top_n]:
        click.echo(f"{label: >30}: {preds[label]:.3f}")
if __name__ == "__main__":
    # Allow invoking the CLI by executing this module directly; click parses
    # sys.argv and dispatches to the matching subcommand.
    cli()