-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transcriber tool
- Loading branch information
Showing
3 changed files
with
236 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import json | ||
import subprocess | ||
import time | ||
import requests | ||
import urllib.request | ||
import zipfile | ||
import srt | ||
import datetime | ||
import os | ||
import re | ||
import logging | ||
|
||
from vosk import KaldiRecognizer, Model | ||
from pathlib import Path | ||
|
||
|
||
# Number of recognized words rendered per SRT subtitle line.
WORDS_PER_LINE = 7
# Base URL of the official Vosk model archive; MODEL_LIST_URL is its JSON
# index of all downloadable models (name, lang, type, obsolete flags).
MODEL_PRE_PATH = 'https://alphacephei.com/vosk/models/'
MODEL_LIST_URL = MODEL_PRE_PATH + 'model-list.json'
|
||
class Transcriber:
    """Transcribe audio files with a Vosk model and render txt or srt output.

    Also resolves and downloads models into ``~/.cache/vosk`` either by
    exact name (``-model_name``) or by language code (``-lang``).
    """

    def get_result_and_tot_samples(self, rec, process):
        """Stream raw PCM from *process* stdout into recognizer *rec*.

        Returns ``(results, tot_samples)`` where *results* is the list of
        per-utterance result dicts (plus the final result) and *tot_samples*
        is the byte count accepted at utterance boundaries.
        NOTE(review): *tot_samples* counts bytes, and only for chunks where
        ``AcceptWaveform`` returned True — kept as-is because the caller's
        xRT estimate is calibrated against this value; confirm before reuse.
        """
        tot_samples = 0
        result = []
        # Read fixed-size chunks until ffmpeg closes its stdout.
        while True:
            data = process.stdout.read(4000)
            if not data:
                break
            if rec.AcceptWaveform(data):
                tot_samples += len(data)
                result.append(json.loads(rec.Result()))
        result.append(json.loads(rec.FinalResult()))
        return result, tot_samples

    def transcribe(self, model, process, args):
        """Run recognition over *process* audio and format per args.outputtype.

        Returns ``(final_result, tot_samples)``; *final_result* is SRT text
        when ``args.outputtype == 'srt'``, plain text for ``'txt'``, and the
        empty string for any other value.
        """
        rec = KaldiRecognizer(model, 16000)
        rec.SetWords(True)  # needed so 'srt' mode gets per-word timestamps
        result, tot_samples = self.get_result_and_tot_samples(rec, process)
        final_result = ''
        if args.outputtype == 'srt':
            subs = []
            for res in result:
                # FinalResult (and empty utterances) may lack word timings.
                if 'result' not in res:
                    continue
                words = res['result']
                # Group recognized words into subtitle cues of
                # WORDS_PER_LINE words each.
                for j in range(0, len(words), WORDS_PER_LINE):
                    line = words[j:j + WORDS_PER_LINE]
                    subs.append(srt.Subtitle(
                        index=len(subs),
                        content=' '.join(w['word'] for w in line),
                        start=datetime.timedelta(seconds=line[0]['start']),
                        end=datetime.timedelta(seconds=line[-1]['end'])))
            final_result = srt.compose(subs)
        elif args.outputtype == 'txt':
            for part in result:
                final_result += part['text'] + ' '
        return final_result, tot_samples

    def resample_ffmpeg(self, infile):
        """Spawn ffmpeg converting *infile* to 16 kHz mono s16le PCM on stdout."""
        stream = subprocess.Popen(
            ['ffmpeg', '-nostdin', '-loglevel', 'quiet', '-i',
             infile,
             '-ar', '16000', '-ac', '1', '-f', 's16le', '-'],
            stdout=subprocess.PIPE)
        return stream

    def get_task_list(self, args):
        """Pair every file in args.input with its output path in args.output.

        The output filename is the input stem with '.' + args.outputtype.
        """
        task_list = [
            (Path(args.input, fn),
             Path(args.output, Path(fn).stem).with_suffix('.' + args.outputtype))
            for fn in os.listdir(args.input)]
        return task_list

    def list_models(self):
        """Print every available model name from the online index and exit."""
        response = requests.get(MODEL_LIST_URL)
        # Fix: was a side-effect list comprehension; a plain loop is the idiom.
        for model in response.json():
            print(model['name'])
        exit(1)  # exit code kept from original CLI behavior

    def list_languages(self):
        """Print the set of available language codes and exit."""
        response = requests.get(MODEL_LIST_URL)
        list_languages = {language['lang'] for language in response.json()}
        print(*list_languages, sep='\n')
        exit(1)  # exit code kept from original CLI behavior

    def check_args(self, args):
        """Honor the list-only flags; exits the process if either was given."""
        if args.list_models:
            self.list_models()
        elif args.list_languages:
            self.list_languages()

    def get_model_by_name(self, args, models_path):
        """Return args.model_name if cached locally or present in the index.

        Exits with status 1 when the name is unknown.
        """
        if not Path(models_path, args.model_name).is_dir():
            response = requests.get(MODEL_LIST_URL)
            result = [model['name'] for model in response.json()
                      if model['name'] == args.model_name]
            if not result:
                logging.info('model name "%s" does not exist, request -list_models to see available models' % (args.model_name))
                exit(1)
            result = result[0]
        else:
            result = args.model_name
        return result

    def get_model_by_lang(self, args, models_path):
        """Return a model name for args.lang: cached match first, else the
        small non-obsolete model from the online index.

        Exits with status 1 when the language is unknown.
        """
        model_file_list = os.listdir(models_path)
        model_file = [model for model in model_file_list
                      if re.match(f"vosk-model(-small)?-{args.lang}", model)]
        if not model_file:
            response = requests.get(MODEL_LIST_URL)
            # NOTE: the index stores 'obsolete' as the string 'false'.
            result = [model['name'] for model in response.json()
                      if model['lang'] == args.lang
                      and model['type'] == 'small'
                      and model['obsolete'] == 'false']
            if not result:
                logging.info('language "%s" does not exist, request -list_languages to see available languages' % (args.lang))
                exit(1)
            result = result[0]
        else:
            result = model_file[0]
        return result

    def get_model(self, args):
        """Resolve, download if needed, and load the requested Vosk model."""
        models_path = Path.home() / '.cache' / 'vosk'
        # Fix: create parent dirs too — plain mkdir failed when ~/.cache
        # did not exist yet.
        models_path.mkdir(parents=True, exist_ok=True)
        if args.lang is None:
            result = self.get_model_by_name(args, models_path)
        else:
            result = self.get_model_by_lang(args, models_path)
        model_location = Path(models_path, result)
        if not model_location.exists():
            # Fix: Path objects support neither '+' nor slicing; the original
            # raised TypeError here.  Build the archive path from str() and
            # take the model's own name for the download URL.
            model_zip = str(model_location) + '.zip'
            urllib.request.urlretrieve(
                MODEL_PRE_PATH + model_location.name + '.zip', model_zip)
            with zipfile.ZipFile(model_zip, 'r') as model_ref:
                model_ref.extractall(models_path)
            Path(model_zip).unlink()
        model = Model(str(model_location))
        return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import logging | ||
import argparse | ||
|
||
from datetime import datetime as dt | ||
from transcriber import Transcriber | ||
from multiprocessing.dummy import Pool | ||
from pathlib import Path | ||
|
||
|
||
# Command-line interface.  A model is selected either by exact name
# (-model_name, defaulting to the small US-English model) or by language
# code (-lang); -list_models / -list_languages only print the index and exit.
parser = argparse.ArgumentParser(
    description = 'Transcribe audio file and save result in selected format')
parser.add_argument(
    '-model', type=str,
    help='model path')
parser.add_argument(
    '-list_models', default=False, action='store_true',
    help='list available models')
parser.add_argument(
    '-list_languages', default=False, action='store_true',
    help='list available languages')
parser.add_argument(
    '-model_name', default='vosk-model-small-en-us-0.15', type=str,
    help='select model by name')
parser.add_argument(
    '-lang', type=str,
    help='select model by language')
parser.add_argument(
    '-input', type=str,
    help='audiofile')
parser.add_argument(
    '-output', default='', type=str,
    help='optional output filename path')
parser.add_argument(
    '-otype', '--outputtype', default='txt', type=str,
    help='optional arg output data type')
parser.add_argument(
    '--log', default='INFO',
    help='logging level')

# Parsed at import time so the functions below can close over `args`.
args = parser.parse_args()
# --log names a logging level (e.g. INFO, DEBUG); applied to the root logger.
log_level = args.log.upper()
logging.getLogger().setLevel(log_level)
logging.info('checking args')
|
||
def get_results(inputdata):
    """Transcribe one file.

    *inputdata* is an ``(input_path, output_path)`` pair.  The result is
    written to the output path when ``-output`` was given, printed otherwise;
    returns ``(final_result, tot_samples)``.
    """
    infile, outfile = inputdata
    logging.info('converting audiofile to 16K sampled wav')
    stream = transcriber.resample_ffmpeg(infile)
    logging.info('starting transcription')
    final_result, tot_samples = transcriber.transcribe(model, stream, args)
    logging.info('complete')
    if not args.output:
        print(final_result)
    else:
        with open(outfile, 'w', encoding='utf-8') as fh:
            fh.write(final_result)
        logging.info('output written to %s' % (outfile))
    return final_result, tot_samples
|
||
def main(args):
    """Entry point: resolve a model, then transcribe one file or a directory.

    Returns ``(final_result, tot_samples)`` for the processed file (the
    first file in directory mode, matching the original behavior) or None
    when the batch is empty.  Exits with status 1 on missing input or on
    non-existent paths.
    """
    global model
    global transcriber
    transcriber = Transcriber()
    transcriber.check_args(args)
    # Guard clause: without -input there is nothing to do.
    if not args.input:
        logging.info('Please specify input file or directory')
        exit(1)
    model = transcriber.get_model(args)
    if Path(args.input).is_dir() and Path(args.output).is_dir():
        # Batch mode: one task per file in the input directory.
        task_list = transcriber.get_task_list(args)
        with Pool() as pool:
            # Fix: the original iterated over the undefined name `file_list`
            # (NameError); the tasks live in `task_list`.
            results = pool.map(get_results, task_list)
        # pool.map is eager, so every file has been processed; return the
        # first result as the original's return-inside-loop did.
        if results:
            return results[0]
        return None
    if Path(args.input).is_file():
        task = (args.input, args.output)
        final_result, tot_samples = get_results(task)
        return final_result, tot_samples
    if not Path(args.input).exists():
        logging.info('File %s does not exist, please select an existing file' % (args.input))
        exit(1)
    logging.info('Folder %s does not exist, please select an existing folder' % (args.output))
    exit(1)
|
||
def get_start_time():
    """Return the wall-clock timestamp marking the start of the run."""
    return dt.now()
|
||
def get_end_time(start_time):
    """Split the elapsed time since *start_time* into display pieces.

    Returns ``(elapsed, seconds, microseconds)`` as strings, carved out of
    ``str(timedelta)`` (shaped like ``H:MM:SS.ffffff``) by fixed slicing:
    chars [5:8] hold the seconds, [8:] the fractional part, and the full
    string is trimmed of leading/trailing ':' and '0' characters.
    """
    elapsed = str(dt.now() - start_time)
    second_part = elapsed[5:8].strip('0')
    micro_part = elapsed[8:].strip('0')
    return elapsed.strip(':0'), second_part.rstrip('.'), micro_part
|
||
if __name__ == '__main__':
    # Time the whole run so the speed relative to audio length (xRT) can be
    # reported alongside the wall-clock execution time.
    start_time = get_start_time()
    tot_samples = main(args)[1]
    diff_end_start, sec, mcsec = get_end_time(start_time)
    # NOTE(review): tot_samples counts bytes of 16 kHz 16-bit mono PCM, so
    # dividing by 16000 presumably overstates audio duration by 2x — confirm
    # against Transcriber.get_result_and_tot_samples before relying on xRT.
    logging.info(f'''Execution time: {sec} sec, {mcsec} mcsec; xRT: {format(tot_samples / 16000.0 / float(diff_end_start), '.3f')}''')