From fa864af945c19816dd073425d62b62e851d52ca2 Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 30 May 2024 09:47:29 +0200 Subject: [PATCH] update shared state n_threads in parallel region --- ggml.c | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/ggml.c b/ggml.c index 0b27712a338fe..1ec3e144c378f 100644 --- a/ggml.c +++ b/ggml.c @@ -1751,7 +1751,7 @@ struct ggml_compute_state_shared { int64_t perf_node_start_cycles; int64_t perf_node_start_time_us; - const int n_threads; + int n_threads; // synchronization primitives atomic_int n_active; // num active threads @@ -19486,12 +19486,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa if (n_threads <= 0) { n_threads = GGML_DEFAULT_N_THREADS; } -#if defined(GGML_USE_OPENMP) - // Limit the number of threads used to avoid deadlock - // ref: https://github.com/ggerganov/llama.cpp/pull/7606 - n_threads = MIN(n_threads, omp_get_max_threads()); - n_threads = MIN(n_threads, omp_get_thread_limit()); -#endif size_t work_size = 0; @@ -19676,9 +19670,20 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * enum ggml_status compute_status = GGML_STATUS_SUCCESS; #ifdef GGML_USE_OPENMP -#pragma omp parallel num_threads(n_threads) - { - ggml_graph_compute_thread(&workers[omp_get_thread_num()]); + if (n_threads > 1) { + #pragma omp parallel num_threads(n_threads) + { + #pragma omp single + { + // update the number of threads from the actual number of threads that we got from OpenMP + n_threads = omp_get_num_threads(); + workers[0].shared->n_threads = n_threads; + workers[0].shared->n_active = n_threads; + } + ggml_graph_compute_thread(&workers[omp_get_thread_num()]); + } + } else { + ggml_graph_compute_thread(&workers[0]); } #else // create thread pool @@ -19724,7 +19729,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl } } - const int n_threads = cplan->n_threads; + int n_threads = cplan->n_threads; + +#if defined(GGML_USE_OPENMP) + n_threads = MIN(n_threads, omp_get_max_threads()); + n_threads = MIN(n_threads, omp_get_thread_limit()); +#endif struct ggml_compute_state_shared state_shared = { /*.cgraph =*/ cgraph,