diff --git a/rvc/train/extract/extract.py b/rvc/train/extract/extract.py
index 124ff0fb..9b7bd21a 100644
--- a/rvc/train/extract/extract.py
+++ b/rvc/train/extract/extract.py
@@ -23,13 +23,10 @@
 
 # Load config
 config = Config()
-
 mp.set_start_method("spawn", force=True)
 
 
 class FeatureInput:
-    """Class for F0 extraction."""
-
     def __init__(self, sample_rate=16000, hop_size=160, device="cpu"):
         self.fs = sample_rate
         self.hop = hop_size
@@ -41,17 +38,17 @@ def __init__(self, sample_rate=16000, hop_size=160, device="cpu"):
         self.device = device
         self.model_rmvpe = None
 
-    def compute_f0(self, np_arr, f0_method, hop_length):
-        """Extract F0 using the specified method."""
-        if f0_method == "crepe":
-            return self.get_crepe(np_arr, hop_length)
-        elif f0_method == "rmvpe":
-            return self.model_rmvpe.infer_from_audio(np_arr, thred=0.03)
-        else:
-            raise ValueError(f"Unknown F0 method: {f0_method}")
+    def compute_f0(self, audio_array, method, hop_length):
+        if method == "crepe":
+            return self._get_crepe(audio_array, hop_length, model="full")
+        elif method == "crepe-tiny":
+            return self._get_crepe(audio_array, hop_length, model="tiny")
+        elif method == "rmvpe":
+            return self.model_rmvpe.infer_from_audio(audio_array, thred=0.03)
+        else:
+            raise ValueError(f"Unknown F0 method: {method}")
 
-    def get_crepe(self, x, hop_length):
-        """Extract F0 using CREPE."""
+    def _get_crepe(self, x, hop_length, model):
         audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
         audio /= torch.quantile(torch.abs(audio), 0.999)
         audio = audio.unsqueeze(0)
@@ -61,24 +58,24 @@ def get_crepe(self, x, hop_length):
             hop_length,
             self.f0_min,
             self.f0_max,
-            "full",
+            model,
             batch_size=hop_length * 2,
             device=audio.device,
             pad=True,
         )
         source = pitch.squeeze(0).cpu().float().numpy()
         source[source < 0.001] = np.nan
-        target = np.interp(
-            np.arange(0, len(source) * (x.size // self.hop), len(source))
-            / (x.size // self.hop),
-            np.arange(0, len(source)),
-            source,
+        return np.nan_to_num(
+            np.interp(
+                np.arange(0, len(source) * (x.size // self.hop), len(source))
+                / (x.size // self.hop),
+                np.arange(0, len(source)),
+                source,
+            )
         )
-        return np.nan_to_num(target)
 
     def coarse_f0(self, f0):
-        """Convert F0 to coarse F0."""
-        f0_mel = 1127 * np.log(1 + f0 / 700)
+        f0_mel = 1127.0 * np.log(1.0 + f0 / 700.0)
         f0_mel = np.clip(
             (f0_mel - self.f0_mel_min)
             * (self.f0_bin - 2)
@@ -90,27 +87,22 @@ def coarse_f0(self, f0):
         return np.rint(f0_mel).astype(int)
 
     def process_file(self, file_info, f0_method, hop_length):
-        """Process a single audio file for F0 extraction."""
-        inp_path, opt_path1, opt_path2, _ = file_info
-
-        if os.path.exists(opt_path1) and os.path.exists(opt_path2):
+        inp_path, opt_path_coarse, opt_path_full, _ = file_info
+        if os.path.exists(opt_path_coarse) and os.path.exists(opt_path_full):
             return
 
         try:
-            np_arr = load_audio(inp_path, 16000)
+            np_arr = load_audio(inp_path, self.fs)
             feature_pit = self.compute_f0(np_arr, f0_method, hop_length)
-            np.save(opt_path2, feature_pit, allow_pickle=False)
+            np.save(opt_path_full, feature_pit, allow_pickle=False)
             coarse_pit = self.coarse_f0(feature_pit)
-            np.save(opt_path1, coarse_pit, allow_pickle=False)
+            np.save(opt_path_coarse, coarse_pit, allow_pickle=False)
         except Exception as error:
             print(
                 f"An error occurred extracting file {inp_path} on {self.device}: {error}"
             )
 
-    def process_files(
-        self, files, f0_method, hop_length, device_num, device, n_threads
-    ):
-        """Process multiple files."""
+    def process_files(self, files, f0_method, hop_length, device, threads):
         self.device = device
         if f0_method == "rmvpe":
             self.model_rmvpe = RMVPE0Predictor(
@@ -118,127 +110,106 @@ def process_files(
                 is_half=False,
                 device=device,
             )
-        else:
-            n_threads = 1
-
-        n_threads = 1 if n_threads == 0 else n_threads
+        threads = max(1, threads)  # ThreadPoolExecutor rejects max_workers=0
 
-        def process_file_wrapper(file_info):
+        def worker(file_info):
             self.process_file(file_info, f0_method, hop_length)
 
-        with tqdm.tqdm(total=len(files), leave=True, position=device_num) as pbar:
-            # using multi-threading
-            with concurrent.futures.ThreadPoolExecutor(
-                max_workers=n_threads
-            ) as executor:
-                futures = [
-                    executor.submit(process_file_wrapper, file_info)
-                    for file_info in files
-                ]
-                for future in concurrent.futures.as_completed(futures):
+        with tqdm.tqdm(total=len(files), leave=True) as pbar:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
+                futures = [executor.submit(worker, f) for f in files]
+                for _ in concurrent.futures.as_completed(futures):
                     pbar.update(1)
 
 
-def run_pitch_extraction(files, devices, f0_method, hop_length, num_processes):
+def run_pitch_extraction(files, devices, f0_method, hop_length, threads):
     devices_str = ", ".join(devices)
     print(
-        f"Starting pitch extraction with {num_processes} cores on {devices_str} using {f0_method}..."
+        f"Starting pitch extraction with {threads} cores on {devices_str} using {f0_method}..."
     )
     start_time = time.time()
     fe = FeatureInput()
 
-    # split the task between devices
-    ps = []
-    num_devices = len(devices)
-    for i, device in enumerate(devices):
-        p = mp.Process(
-            target=fe.process_files,
-            args=(
-                files[i::num_devices],
+    with concurrent.futures.ProcessPoolExecutor(max_workers=len(devices)) as executor:
+        tasks = [
+            executor.submit(
+                fe.process_files,
+                files[i :: len(devices)],
                 f0_method,
                 hop_length,
-                i,
-                device,
-                num_processes // num_devices,
-            ),
-        )
-        ps.append(p)
-        p.start()
-    for i, device in enumerate(devices):
-        ps[i].join()
+                devices[i],
+                threads // len(devices),
+            )
+            for i in range(len(devices))
+        ]
+        for task in concurrent.futures.as_completed(tasks):
+            error = task.exception()
+            if error is not None:
+                print(f"Pitch extraction worker failed: {error!r}")
 
-    elapsed_time = time.time() - start_time
-    print(f"Pitch extraction completed in {elapsed_time:.2f} seconds.")
+    print(f"Pitch extraction completed in {time.time() - start_time:.2f} seconds.")
 
 
 def process_file_embedding(
     files, version, embedder_model, embedder_model_custom, device_num, device, n_threads
 ):
-    dtype = torch.float16 if config.is_half and "cuda" in device else torch.float32
+    dtype = torch.float16 if (config.is_half and "cuda" in device) else torch.float32
     model = load_embedding(embedder_model, embedder_model_custom).to(dtype).to(device)
-    n_threads = 1 if n_threads == 0 else n_threads
+    n_threads = max(1, n_threads)
 
-    def process_file_embedding_wrapper(file_info):
+    def worker(file_info):
         wav_file_path, _, _, out_file_path = file_info
         if os.path.exists(out_file_path):
             return
         feats = torch.from_numpy(load_audio(wav_file_path, 16000)).to(dtype).to(device)
         feats = feats.view(1, -1)
         with torch.no_grad():
-            feats = model(feats)["last_hidden_state"]
-            feats = (
-                model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
-            )
-        feats = feats.squeeze(0).float().cpu().numpy()
-        if not np.isnan(feats).any():
-            np.save(out_file_path, feats, allow_pickle=False)
+            result = model(feats)["last_hidden_state"]
+            if version == "v1":
+                result = model.final_proj(result[0]).unsqueeze(0)
+        feats_out = result.squeeze(0).float().cpu().numpy()
+        if not np.isnan(feats_out).any():
+            np.save(out_file_path, feats_out, allow_pickle=False)
         else:
-            print(f"{file} contains NaN values and will be skipped.")
+            print(f"{wav_file_path} produced NaN values; skipping.")
 
     with tqdm.tqdm(total=len(files), leave=True, position=device_num) as pbar:
-        # using multi-threading
         with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
-            futures = [
-                executor.submit(process_file_embedding_wrapper, file_info)
-                for file_info in files
-            ]
-            for future in concurrent.futures.as_completed(futures):
+            futures = [executor.submit(worker, f) for f in files]
+            for _ in concurrent.futures.as_completed(futures):
                 pbar.update(1)
 
 
 def run_embedding_extraction(
-    files, devices, version, embedder_model, embedder_model_custom
+    files, devices, version, embedder_model, embedder_model_custom, threads
 ):
-    start_time = time.time()
     devices_str = ", ".join(devices)
     print(
-        f"Starting embedding extraction with {num_processes} cores on {devices_str}..."
+        f"Starting embedding extraction with {threads} cores on {devices_str}..."
     )
-    # split the task between devices
-    ps = []
-    num_devices = len(devices)
-    for i, device in enumerate(devices):
-        p = mp.Process(
-            target=process_file_embedding,
-            args=(
-                files[i::num_devices],
+    start_time = time.time()
+    with concurrent.futures.ProcessPoolExecutor(max_workers=len(devices)) as executor:
+        tasks = [
+            executor.submit(
+                process_file_embedding,
+                files[i :: len(devices)],
                 version,
                 embedder_model,
                 embedder_model_custom,
                 i,
-                device,
-                num_processes // num_devices,
-            ),
-        )
-        ps.append(p)
-        p.start()
-    for i, device in enumerate(devices):
-        ps[i].join()
-    elapsed_time = time.time() - start_time
-    print(f"Embedding extraction completed in {elapsed_time:.2f} seconds.")
+                devices[i],
+                threads // len(devices),
+            )
+            for i in range(len(devices))
+        ]
+        for task in concurrent.futures.as_completed(tasks):
+            error = task.exception()
+            if error is not None:
+                print(f"Embedding extraction worker failed: {error!r}")
+    print(f"Embedding extraction completed in {time.time() - start_time:.2f} seconds.")
 
 
-if __name__ == "__main__":
+if __name__ == "__main__":
     exp_dir = sys.argv[1]
     f0_method = sys.argv[2]
     hop_length = int(sys.argv[3])
@@ -250,27 +221,21 @@ def run_embedding_extraction(
     embedder_model_custom = sys.argv[9] if len(sys.argv) > 9 else None
     include_mutes = int(sys.argv[10]) if len(sys.argv) > 10 else 2
 
-    # prep
     wav_path = os.path.join(exp_dir, "sliced_audios_16k")
     os.makedirs(os.path.join(exp_dir, "f0"), exist_ok=True)
     os.makedirs(os.path.join(exp_dir, "f0_voiced"), exist_ok=True)
     os.makedirs(os.path.join(exp_dir, version + "_extracted"), exist_ok=True)
-    # write to model_info.json
+
     chosen_embedder_model = (
         embedder_model_custom if embedder_model == "custom" else embedder_model
     )
-
     file_path = os.path.join(exp_dir, "model_info.json")
     if os.path.exists(file_path):
         with open(file_path, "r") as f:
             data = json.load(f)
     else:
         data = {}
-    data.update(
-        {
-            "embedder_model": chosen_embedder_model,
-        }
-    )
+    data["embedder_model"] = chosen_embedder_model
     with open(file_path, "w") as f:
         json.dump(data, f, indent=4)
 
@@ -278,7 +243,7 @@ def run_embedding_extraction(
     for file in glob.glob(os.path.join(wav_path, "*.wav")):
         file_name = os.path.basename(file)
         file_info = [
-            file,  # full path to sliced 16k wav
+            file,
             os.path.join(exp_dir, "f0", file_name + ".npy"),
             os.path.join(exp_dir, "f0_voiced", file_name + ".npy"),
             os.path.join(
@@ -288,14 +253,12 @@ def run_embedding_extraction(
         files.append(file_info)
 
     devices = ["cpu"] if gpus == "-" else [f"cuda:{idx}" for idx in gpus.split("-")]
-    # Run Pitch Extraction
+
     run_pitch_extraction(files, devices, f0_method, hop_length, num_processes)
 
-    # Run Embedding Extraction
     run_embedding_extraction(
-        files, devices, version, embedder_model, embedder_model_custom
+        files, devices, version, embedder_model, embedder_model_custom, num_processes
     )
 
-    # Run Preparing Files
     generate_config(version, sample_rate, exp_dir)
     generate_filelist(exp_dir, version, sample_rate, include_mutes)