Skip to content

Commit

Permalink
update shared state n_threads in parallel region
Browse files (browse the repository at this point in the history)
  • Loading branch information
slaren committed May 30, 2024
1 parent 7918ed7 commit fa864af
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions ggml.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -1751,7 +1751,7 @@ struct ggml_compute_state_shared {
int64_t perf_node_start_cycles;
int64_t perf_node_start_time_us;

- const int n_threads;
+ int n_threads;

// synchronization primitives
atomic_int n_active; // num active threads
Expand Down Expand Up @@ -19486,12 +19486,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
}
- #if defined(GGML_USE_OPENMP)
- // Limit the number of threads used to avoid deadlock
- // ref: https://github.com/ggerganov/llama.cpp/pull/7606
- n_threads = MIN(n_threads, omp_get_max_threads());
- n_threads = MIN(n_threads, omp_get_thread_limit());
- #endif

size_t work_size = 0;

Expand Down Expand Up @@ -19676,9 +19670,20 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
enum ggml_status compute_status = GGML_STATUS_SUCCESS;

#ifdef GGML_USE_OPENMP
- #pragma omp parallel num_threads(n_threads)
- {
- ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
+ if (n_threads > 1) {
+ #pragma omp parallel num_threads(n_threads)
+ {
+ #pragma omp single
+ {
+ // update the number of threads from the actual number of threads that we got from OpenMP
+ n_threads = omp_get_num_threads();
+ workers[0].shared->n_threads = n_threads;
+ workers[0].shared->n_active = n_threads;
+ }
+ ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
+ }
+ } else {
+ ggml_graph_compute_thread(&workers[0]);
}
#else
// create thread pool
Expand Down Expand Up @@ -19724,7 +19729,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
}
}

- const int n_threads = cplan->n_threads;
+ int n_threads = cplan->n_threads;
+
+ #if defined(GGML_USE_OPENMP)
+ n_threads = MIN(n_threads, omp_get_max_threads());
+ n_threads = MIN(n_threads, omp_get_thread_limit());
+ #endif

struct ggml_compute_state_shared state_shared = {
/*.cgraph =*/ cgraph,
Expand Down

0 comments on commit fa864af

Please sign in to comment.