Skip to content

Commit

Permalink
Add final benchmarked optimization for solver omp_parallel_for
Browse files (browse the repository at this point in the history)
  • Loading branch information
JamesYang007 committed Nov 27, 2024
1 parent 51c9a6e commit ef2bf49
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions adelie/src/include/adelie_core/solver/solver_base.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,6 +24,7 @@ inline void update_abs_grad(
)
{
using state_t = std::decay_t<StateType>;
using value_t = typename state_t::value_t;
using vec_value_t = typename state_t::vec_value_t;
using rowmat_uint64_t = util::rowmat_type<uint64_t>;

Expand Down Expand Up @@ -94,7 +95,12 @@ inline void update_abs_grad(
try_failed = true;
}
};
util::omp_parallel_for(routine, 0, groups.size(), n_threads);
const bool is_not_all_none = util::rowvec_type<bool>::NullaryExpr(
constraints.size(),
[&](auto i) { return constraints[i] != nullptr; }
).any();
const size_t n_bytes = sizeof(value_t) * abs_grad.size();
util::omp_parallel_for(routine, 0, groups.size(), n_threads * (is_not_all_none || (n_bytes > Configs::min_bytes)));
if (try_failed) {
throw util::adelie_core_solver_error(
"exception raised in constraint->solve_zero(). "
Expand Down Expand Up @@ -157,6 +163,8 @@ inline auto sparsify_dual(
VecValueType& values
)
{
using index_t = typename StateType::index_t;
using value_t = typename StateType::value_t;
using vec_index_t = typename StateType::vec_index_t;
using vec_value_t = typename StateType::vec_value_t;
using sp_vec_value_t = typename StateType::sp_vec_value_t;
Expand Down Expand Up @@ -192,7 +200,12 @@ inline auto sparsify_dual(
constraint->dual(indices_v, values_v);
indices_v += dual_groups[i];
};
util::omp_parallel_for(routine, 0, n_constraints, n_threads);
const bool is_not_all_none = util::rowvec_type<bool>::NullaryExpr(
constraints.size(),
[&](auto i) { return constraints[i] != nullptr; }
).any();
const size_t n_bytes = (sizeof(index_t) + sizeof(value_t)) * indices.size();
util::omp_parallel_for(routine, 0, n_constraints, n_threads * (is_not_all_none || (n_bytes > Configs::min_bytes)));
}

const auto last_constraint = constraints[n_constraints-1];
Expand Down

0 comments on commit ef2bf49

Please sign in to comment.