Skip to content

Commit

Permalink
Update infer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNeodev authored Dec 18, 2024
1 parent 4d726a1 commit b337c53
Showing 1 changed file with 16 additions and 97 deletions.
113 changes: 16 additions & 97 deletions rvc_inferpy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
adjust_audio_lengths,
combine_silence_nonsilent,
)
from rvc_inferpy.config_loader import *
import torch
from pathlib import Path
import requests
Expand All @@ -21,6 +22,17 @@
models_dir = "models"



validate_config_and_files()

BaseLoader(hubert_path=hubert_model_path, rmvpe_path=rmvpe_model_path)
rvcbasdl = lambda: print("RVC-based loader initialized.") # Replace with the actual function
rvcbasdl()





class Configs:
def __init__(self, device, is_half):
self.device = device
Expand Down Expand Up @@ -85,101 +97,9 @@ def get_model(voice_model):
)


BASE_DIR = Path(".")
files_to_check = ["hubert_base.pt", "rmvpe.pt", "fcpe.pt"]

missing_files = [file for file in files_to_check if not (BASE_DIR / file).exists()]


def dl_model(link, model_name, dir_name):
url = f"{link}/{model_name}"
target_path = dir_name / model_name
target_path.parent.mkdir(parents=True, exist_ok=True) # Ensure directory exists

torch.hub.download_url_to_file(url, str(target_path))
print(f"{model_name} downloaded successfully!")


if missing_files:
RVC_DOWNLOAD_LINK = "https://huggingface.co/theNeofr/rvc-base/resolve/main" # Replace with the actual download link

for model in missing_files:
print(f"Downloading {model}...")
dl_model(RVC_DOWNLOAD_LINK, model, BASE_DIR)

print("All missing models have been downloaded!")
else:
pass


def extract_zip(extraction_folder, zip_name):
os.makedirs(extraction_folder, exist_ok=True)
with zipfile.ZipFile(zip_name, "r") as zip_ref:
zip_ref.extractall(extraction_folder)
os.remove(zip_name)

index_filepath, model_filepath = None, None
for root, dirs, files in os.walk(extraction_folder):
for name in files:
if (
name.endswith(".index")
and os.stat(os.path.join(root, name)).st_size > 1024 * 100
):
index_filepath = os.path.join(root, name)

if (
name.endswith(".pth")
and os.stat(os.path.join(root, name)).st_size > 1024 * 1024 * 40
):
model_filepath = os.path.join(root, name)

if not model_filepath:
raise Exception(
f"No .pth model file was found in the extracted zip. Please check {extraction_folder}."
)

os.rename(
model_filepath,
os.path.join(extraction_folder, os.path.basename(model_filepath)),
)
if index_filepath:
os.rename(
index_filepath,
os.path.join(extraction_folder, os.path.basename(index_filepath)),
)

# Remove unnecessary nested folders
for filepath in os.listdir(extraction_folder):
if os.path.isdir(os.path.join(extraction_folder, filepath)):
shutil.rmtree(os.path.join(extraction_folder, filepath))


def download_rvc_model(url, dir_name):
try:
print(f"[~] Downloading voice model with name {dir_name}...")
zip_name = url.split("/")[-1]
extraction_folder = os.path.join(models_dir, dir_name)
if os.path.exists(extraction_folder):
raise Exception(
f"Voice model directory {dir_name} already exists! Choose a different name for your voice model."
)

if "pixeldrain.com" in url:
url = f"https://pixeldrain.com/api/file/{zip_name}"
if "drive.google.com" in url:
zip_name = dir_name + ".zip"
gdown.download(
url, output=zip_name, use_cookies=True, quiet=True, fuzzy=True
)
else:
urllib.request.urlretrieve(url, zip_name)

print(f"[~] Extracting zip file...")
extract_zip(extraction_folder, zip_name)
print(f"[+] {dir_name} Model successfully downloaded!")

except Exception as e:
raise Exception(str(e))


def infer_audio(
Expand All @@ -205,12 +125,11 @@ def infer_audio(
f0_autotune=False,
audio_format="wav",
resample_sr=0,
hubert_model_path="hubert_base.pt",
rmvpe_model_path="rmvpe.pt",
fcpe_model_path="fcpe.pt",
hubert_model_path=hubert_model_path,
rmvpe_model_path=rmvpe_model_path,
fcpe_model_path=fcpe_model_path,
):
os.environ["rmvpe_model_path"] = rmvpe_model_path
os.environ["fcpe_model_path"] = fcpe_model_path

configs = Configs("cuda:0", True)
vc = VC(configs)
pth_path, index_path = get_model(model_name)
Expand Down

0 comments on commit b337c53

Please sign in to comment.