Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Drum Transcription: add num_threads and parallel_workers for madmom beat extraction #111

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions omnizart/drum/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self):
super().__init__(DrumSettings)
self.custom_objects = {"ConvSN2D": ConvSN2D}

def transcribe(self, input_audio, model_path=None, output="./"):
def transcribe(self, input_audio, model_path=None, output="./", beat_tracker_num_threads=3, beat_tracker_parallel_workers=3):
"""Transcribe drum in the audio.

This function transcribes drum activations in the music. Currently the model
Expand Down Expand Up @@ -62,7 +62,11 @@ def transcribe(self, input_audio, model_path=None, output="./"):

# Extract feature according to model configuration
logger.info("Extracting feature...")
patch_cqt_feature, mini_beat_arr = extract_patch_cqt(input_audio)
# extract_patch_cqt accepts the beat-tracker settings only under the
# beat_tracker_* keyword names; any other keyword raises TypeError.
patch_cqt_feature, mini_beat_arr = extract_patch_cqt(
    input_audio,
    beat_tracker_num_threads=beat_tracker_num_threads,
    beat_tracker_parallel_workers=beat_tracker_parallel_workers,
)

# Load model configurations
logger.info("Loading model...")
Expand Down
58 changes: 36 additions & 22 deletions omnizart/feature/beat_for_drum.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ class MadmomBeatTracking:
Three different beat tracking methods are used together for producing a more
stable beat tracking result.
"""
def __init__(self, num_threads=3):
def __init__(self, num_threads=3, parallel_workers=3):
self.num_threads = num_threads
self.parallel_workers=parallel_workers

def _get_dbn_down_beat(self, audio_data_in1, min_bpm_in=50, max_bpm_in=230):
proccesor = DBNDownBeatTrackingProcessor(
Expand All @@ -51,23 +52,30 @@ def _get_beat(self, audio_data_in3):

def process(self, audio_data):
"""Generate beat tracking results with multiple approaches."""
with ProcessPoolExecutor(max_workers=3) as executor:
logger.debug("Submitting and executing parallel beat tracking jobs")
future_1 = executor.submit(self._get_dbn_down_beat, audio_data, min_bpm_in=50, max_bpm_in=230)
future_2 = executor.submit(self._get_dbn_beat, audio_data)
future_3 = executor.submit(self._get_beat, audio_data)

queue = {future_1: "dbn_down_beat", future_2: "dbn_beat", future_3: "beat"}

results = {}
for future in concurrent.futures.as_completed(queue, timeout=600):
func_name = queue[future]
results[func_name] = future.result()
logger.debug("Job %s finished.", func_name)

pred_beats1 = results["dbn_down_beat"]
pred_beats2 = results["dbn_beat"]
pred_beats3 = results["beat"]
if self.parallel_workers == 0:
# Run sequentially
logger.debug("Running beat tracking sequentially...")
pred_beats1 = self._get_dbn_down_beat(audio_data, min_bpm_in=50, max_bpm_in=230)
pred_beats2 = self._get_dbn_beat(audio_data)
pred_beats3 = self._get_beat(audio_data)
else:
with ProcessPoolExecutor(max_workers=self.parallel_workers) as executor:
logger.debug("Submitting and executing parallel beat tracking jobs")
future_1 = executor.submit(self._get_dbn_down_beat, audio_data, min_bpm_in=50, max_bpm_in=230)
future_2 = executor.submit(self._get_dbn_beat, audio_data)
future_3 = executor.submit(self._get_beat, audio_data)

queue = {future_1: "dbn_down_beat", future_2: "dbn_beat", future_3: "beat"}

results = {}
for future in concurrent.futures.as_completed(queue, timeout=600):
func_name = queue[future]
results[func_name] = future.result()
logger.debug("Job %s finished.", func_name)

pred_beats1 = results["dbn_down_beat"]
pred_beats2 = results["dbn_beat"]
pred_beats3 = results["beat"]

pred_beat_len1 = np.mean(
np.sort(pred_beats1[1:] - pred_beats1[:-1])[int(len(pred_beats1) * 0.2):int(len(pred_beats1) * 0.8)]
Expand All @@ -89,7 +97,7 @@ def process(self, audio_data):
return self._get_dbn_down_beat(audio_data, min_bpm_in=pred_bpm_avg / 1.38, max_bpm_in=pred_bpm_avg * 1.38)


def extract_beat_with_madmom(audio_path, sampling_rate=44100):
def extract_beat_with_madmom(audio_path, sampling_rate=44100, parallel_workers=3, num_threads=3):
"""Extract beat position (in seconds) of the audio.

Extract beat with mixture of beat tracking techniques using madmom.
Expand All @@ -111,7 +119,8 @@ def extract_beat_with_madmom(audio_path, sampling_rate=44100):
logger.debug("Loading audio: %s", audio_path)
audio_data, _ = load_audio(audio_path, sampling_rate=sampling_rate)
logger.debug("Runnig beat tracking...")
return MadmomBeatTracking().process(audio_data), len(audio_data) / sampling_rate
mbt = MadmomBeatTracking(num_threads=num_threads, parallel_workers=parallel_workers)
return mbt.process(audio_data), len(audio_data) / sampling_rate


def extract_mini_beat_from_beat_arr(beat_arr, audio_len_sec, mini_beat_div_n=32):
Expand Down Expand Up @@ -152,10 +161,15 @@ def extract_mini_beat_from_beat_arr(beat_arr, audio_len_sec, mini_beat_div_n=32)
return mini_beat_pos_t


def extract_mini_beat_from_audio_path(audio_path, sampling_rate=44100, mini_beat_div_n=32):
def extract_mini_beat_from_audio_path(audio_path, sampling_rate=44100, mini_beat_div_n=32, parallel_workers=3, num_threads=3):
""" Wrapper of extracting mini beats from audio path. """
logger.debug("Extracting beat with madmom")
beat_arr, audio_len_sec = extract_beat_with_madmom(audio_path, sampling_rate=sampling_rate)
beat_arr, audio_len_sec = extract_beat_with_madmom(
audio_path,
sampling_rate=sampling_rate,
parallel_workers=parallel_workers,
num_threads=num_threads
)
logger.debug("Extracting mini beat")
return extract_mini_beat_from_beat_arr(beat_arr, audio_len_sec, mini_beat_div_n=mini_beat_div_n)

Expand Down
9 changes: 7 additions & 2 deletions omnizart/feature/wrapper_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_frame_by_time(time_sec, sampling_rate=44100, hop_size=256):
return int(round(time_sec * sampling_rate / hop_size))


def extract_patch_cqt(audio_path, sampling_rate=44100, hop_size=256):
def extract_patch_cqt(audio_path, sampling_rate=44100, hop_size=256, beat_tracker_num_threads=3, beat_tracker_parallel_workers=3):
"""Extract patched CQT feature.

Leverages mini-beat information to determine the bound of each
Expand All @@ -51,7 +51,12 @@ def extract_patch_cqt(audio_path, sampling_rate=44100, hop_size=256):
omnizart.feature.beat_for_drum.extract_mini_beat_from_audio_path: Function for extracting mini-beat.
"""
cqt_ext = cqt.extract_cqt(audio_path, sampling_rate=sampling_rate, a_hop=hop_size)
mini_beat_arr = b4d.extract_mini_beat_from_audio_path(audio_path, sampling_rate=sampling_rate)
mini_beat_arr = b4d.extract_mini_beat_from_audio_path(
audio_path,
sampling_rate=sampling_rate,
num_threads=beat_tracker_num_threads,
parallel_workers=beat_tracker_parallel_workers
)

m_beat_cqt_patch_list = []
for m_beat_t_cur in mini_beat_arr:
Expand Down
Loading