add joblib parallel feature computation
timonmerk committed Jan 18, 2024
1 parent 75feba8 commit 460b986
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions py_neuromodulation/nm_stream_offline.py
@@ -1,8 +1,7 @@
 """Module for offline data streams."""
 import math
 import os
-
-import multiprocessing as mp
+from joblib import Parallel, delayed
 from itertools import count
 
 import numpy as np
@@ -92,14 +91,14 @@ def _process_batch(self, data_batch, cnt_samples):
         )
         feature_series = self._add_timestamp(feature_series, cnt_samples)
         return feature_series
 
     def _run_offline(
         self,
         data: np.ndarray,
         out_path_root: _PathLike | None = None,
         folder_name: str = "sub",
-        parallel: bool = True,
-        num_threads = None
+        parallel: bool = False,
+        n_jobs: int = -2,
     ) -> pd.DataFrame:
         generator = nm_generator.raw_data_generator(
             data=data,
@@ -114,21 +113,30 @@ def _run_offline(
         offset_start = offset_time / 1000 * self.sfreq
 
         if parallel:
-            try: mp.set_start_method('fork')  # Set process start method; 'spawn' and 'forkserver' do not work
-            except RuntimeError: pass  # mp.set_start_method() will crash the program if called more than once
-            pool = mp.Pool(processes=num_threads)  # Create sub-process pool; faster than concurrent.futures.ProcessPoolExecutor()
-            # Assign tasks to sub-processes; starmap is like map, but for 2+ arguments that must be zipped
-            feature_df = pd.DataFrame(pool.starmap(self._process_batch, zip(generator, count(offset_start, sample_add))))
-            # Prevent memory leaks by releasing process pool resources
-            pool.close()
-            pool.join()
+            l_features = Parallel(n_jobs=n_jobs, verbose=10)(
+                delayed(self._process_batch)(data_batch, cnt_samples)
+                for data_batch, cnt_samples in zip(
+                    generator, count(offset_start, sample_add)
+                )
+            )
 
         else:
-            # If no parallelization is required, it is faster not to use a process pool at all
-            feature_df = pd.DataFrame(map(self._process_batch, generator, count(offset_start, sample_add)))
-
-            # I don't know what this does :(
-            # if self.model is not None:
-            #     prediction = self.model.predict(features[-1])
+            l_features = []
+            cnt_samples = offset_start
+            while True:
+                data_batch = next(generator, None)
+                if data_batch is None:
+                    break
+                feature_series = self.run_analysis.process(
+                    data_batch.astype(np.float64)
+                )
+                feature_series = self._add_timestamp(
+                    feature_series, cnt_samples
+                )
+                l_features.append(feature_series)
+
+                cnt_samples += sample_add
+        feature_df = pd.DataFrame(l_features)
 
         feature_df = self._add_labels(features=feature_df, data=data)
 
@@ -264,8 +272,8 @@ def run(
         data: np.ndarray | pd.DataFrame = None,
         out_path_root: _PathLike | None = None,
         folder_name: str = "sub",
-        parallel: bool = True,
-        num_threads = None
+        parallel: bool = False,
+        n_jobs: int = -2,
     ) -> pd.DataFrame:
         """Call run function for offline stream.
@@ -294,4 +302,10 @@
         elif self.data is None and data is None:
             raise ValueError("No data passed to run function.")
 
-        return self._run_offline(data, out_path_root, folder_name, parallel=parallel, num_threads=num_threads)
+        return self._run_offline(
+            data,
+            out_path_root,
+            folder_name,
+            parallel=parallel,
+            n_jobs=n_jobs,
+        )
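
The heart of the change is the Parallel(...)(delayed(...)) call: with its default process-based backend, joblib pickles each (data_batch, cnt_samples) task, dispatches it to a worker process, and collects the returned feature series in dispatch order. A minimal self-contained sketch of the same pattern follows; process_batch, batch_generator, and the numbers are illustrative stand-ins, not part of the py_neuromodulation API.

    from itertools import count

    import numpy as np
    import pandas as pd
    from joblib import Parallel, delayed


    def process_batch(data_batch: np.ndarray, cnt_samples: float) -> pd.Series:
        # Stand-in for Stream._process_batch: compute a feature vector for
        # one batch and attach the running sample count as a timestamp.
        return pd.Series(
            {"mean": data_batch.mean(), "std": data_batch.std(), "time": cnt_samples}
        )


    def batch_generator(data: np.ndarray, batch_size: int):
        # Stand-in for nm_generator.raw_data_generator: yield consecutive batches.
        for start in range(0, len(data) - batch_size + 1, batch_size):
            yield data[start : start + batch_size]


    data = np.random.randn(10_000)
    offset_start, sample_add = 0, 100  # first-batch offset and samples per step

    # n_jobs=-2 uses all CPUs but one; verbose=10 prints per-task progress.
    l_features = Parallel(n_jobs=-2, verbose=10)(
        delayed(process_batch)(data_batch, cnt_samples)
        for data_batch, cnt_samples in zip(
            batch_generator(data, 100), count(offset_start, sample_add)
        )
    )
    feature_df = pd.DataFrame(l_features)
    print(feature_df.head())

In joblib, n_jobs=-2 means "all CPUs but one", so the new default leaves one core free. Parallelization also becomes opt-in: the default flips from parallel=True to parallel=False, and callers enable it via run(..., parallel=True, n_jobs=...).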
