-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathsetup.py
90 lines (75 loc) · 2.47 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 Phonexia
# Author: Jan Profant <[email protected]>
# All Rights Reserved
from distutils.core import setup
import glob
import os
from setuptools.command.install import install
from setuptools.command.develop import develop
from setuptools import find_packages
import tempfile
import zipfile
# Root of the bundled model tree. It is a relative path, so it resolves
# against the current working directory when install_scripts runs.
MODELS_DIR = "diarizer/models"

# Model sub-directories of MODELS_DIR whose split ``nnet`` archives are
# reassembled at install time.
MODELS = [
    "ResNet101_16kHz",
]
def install_scripts(directory):
"""Call cmd commands to install extra software/repositories.
Args:
directory (str): path
"""
# unpack multiple zip files into .pth file
for model in MODELS:
temp_zip = tempfile.NamedTemporaryFile(delete=False)
nnet_dir = os.path.join(MODELS_DIR, model, "nnet")
assert os.path.isdir(nnet_dir), f"{nnet_dir} does not exist."
for zip_part in sorted(
glob.glob(f'{os.path.join(nnet_dir, "*.pth.zip.part*")}')
):
with open(zip_part, "rb") as f:
temp_zip.write(f.read())
with zipfile.ZipFile(temp_zip, "r") as fzip:
fzip.printdir()
fzip.extractall(path=nnet_dir)
class PostDevelopCommand(develop):
    """``develop`` command that also unpacks the model archives.

    The regular development-mode install runs first. Afterwards
    ``install_scripts`` is invoked through ``self.execute`` with the egg path
    as its only argument.
    """

    def run(self):
        # Regular setuptools development install.
        super().run()
        # Post-install step; execute() prints the message before running it.
        egg_path = self.egg_path
        self.execute(
            install_scripts,
            (egg_path,),
            msg="Running post install scripts",
        )
class PostInstallCommand(install):
    """``install`` command that also unpacks the model archives.

    The regular install runs first. Afterwards ``install_scripts`` is
    invoked through ``self.execute`` with the library install directory as
    its only argument.
    """

    def run(self):
        # Regular setuptools install.
        super().run()
        # Post-install step; execute() prints the message before running it.
        target_dir = self.install_lib
        self.execute(
            install_scripts,
            (target_dir,),
            msg="Running post install scripts",
        )
# Bind setuptools' own ``setup`` for this call.
# - It natively understands ``install_requires``, ``dependency_links`` and
#   ``cmdclass``.
# - The ``distutils.core.setup`` imported at the top of the file only honors
#   those keywords because importing setuptools monkey-patches distutils.
# - distutils is deprecated (PEP 632) and no longer in the Python 3.12
#   standard library.
from setuptools import setup  # noqa: E402  (intentionally shadows distutils' setup)

setup(
    name="diarizer",
    version="0.1.1",
    packages=find_packages(),
    url="https://github.com/desh2608/diarizer",
    install_requires=[
        "numpy>=1.19.5",
        "scipy==1.4.1",
        "h5py==2.9.0",
        "fastcluster==1.2.4",
        "onnxruntime==1.4.0",
        "soundfile==0.10.2",
        "torch==1.10.0",
        "numba==0.53.0",
        "kaldi_io",
        "tabulate>=0.8.6",
        "intervaltree",
        "spy-der==0.2.0",
        "scikit-learn",
        "pyannote.audio @ git+https://github.com/desh2608/pyannote-audio.git@develop",
    ],
    # NOTE: pip no longer processes ``dependency_links`` (removed in pip 19.0).
    # This prebuilt scikit-learn wheel is therefore not fetched automatically;
    # the unpinned "scikit-learn" requirement above is resolved from the
    # package index instead. Install the wheel manually if the fork is needed.
    dependency_links=[
        "https://github.com/desh2608/scikit-learn/releases/download/v0.24.0-dev-overlap/scikit_learn-0.24.dev0-cp38-cp38-linux_x86_64.whl",
    ],
    license="Apache License, Version 2.0",
    # Custom commands unpack the split model archives after install/develop.
    cmdclass={"install": PostInstallCommand, "develop": PostDevelopCommand},
)