diff --git a/src/BSeries.jl b/src/BSeries.jl
index fb4f7e2e..764c128f 100644
--- a/src/BSeries.jl
+++ b/src/BSeries.jl
@@ -989,10 +989,11 @@ function _evaluate(f, u, dt, series, ::EagerEvaluation, reduce_order_by)
 end
 
 """
-    modified_equation(series_integrator)
+    modified_equation(series_integrator, thread::Bool = Threads.nthreads() > 1)
 
 Compute the B-series of the modified equation of the time integration method
-with B-series `series_integrator`.
+with B-series `series_integrator` using multiple threads if Julia is started
+with multiple threads and `thread` is set to `true`.
 
 Given an ordinary differential equation (ODE) ``u'(t) = f(u(t))`` and a
 Runge-Kutta method, the idea is to interpret the numerical solution with
@@ -1014,11 +1015,41 @@ Section 3.2 of
   Foundations of Computational Mathematics
   [DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
 """
-function modified_equation(series_integrator)
-    _modified_equation(series_integrator, evaluation_type(series_integrator))
+function modified_equation(series_integrator, thread::Bool = Threads.nthreads() > 1)
+    if thread
+        _modified_equation_thread(series_integrator,
+                                  evaluation_type(series_integrator))
+    else
+        _modified_equation_serial(series_integrator,
+                                  evaluation_type(series_integrator))
+    end
+end
+
+function _modified_equation_serial(series_integrator, ::EagerEvaluation)
+    # Setup shared between the serial and threaded versions
+    series, series_keys, series_ex, iter = _modified_equation_shared(series_integrator)
+
+    # Recursively solve
+    #   substitute(series, series_ex, t) == series_integrator[t]
+    # This works because
+    #   substitute(series, series_ex, t) = series[t] + lower order terms
+
+    # Since the `keys` are ordered, we don't need to use nested loops of the form
+    #   for o in 2:order
+    #       for _t in RootedTreeIterator(o)
+    #           t = copy(_t)
+    # which are slightly less efficient due to additional computations and
+    # allocations.
+    while iter !== nothing
+        t, t_state = iter
+        series[t] += series_integrator[t] - substitute(series, series_ex, t)
+        iter = iterate(series_keys, t_state)
+    end
+
+    return series
 end
 
-function _modified_equation(series_integrator, ::EagerEvaluation)
+@inline function _modified_equation_shared(series_integrator)
     V = valtype(series_integrator)
 
     # B-series of the exact solution
@@ -1050,10 +1081,24 @@ function _modified_equation(series_integrator, ::EagerEvaluation)
         iter = iterate(series_keys, t_state)
     end
 
+    return series, series_keys, series_ex, iter
+end
+
+function _modified_equation_thread(series_integrator, ::EagerEvaluation)
+    # Setup shared between the serial and threaded versions
+    series, series_keys, series_ex, iter = _modified_equation_shared(series_integrator)
+
     # Recursively solve
     #   substitute(series, series_ex, t) == series_integrator[t]
     # This works because
     #   substitute(series, series_ex, t) = series[t] + lower order terms
+
+    # Here, we use the serial version up to a specified `cutoff_order`, i.e.,
+    # for low-order trees, since it avoids the parallel overhead. We only use
+    # the parallel (threaded) version for trees of an order of at least
+    # `cutoff_order`.
+    cutoff_order = 5
+
     # Since the `keys` are ordered, we don't need to use nested loops of the form
     #   for o in 2:order
     #       for _t in RootedTreeIterator(o)
@@ -1062,10 +1107,54 @@ function _modified_equation(series_integrator, ::EagerEvaluation)
     # allocations.
     while iter !== nothing
         t, t_state = iter
+        order(t) >= cutoff_order && break
         series[t] += series_integrator[t] - substitute(series, series_ex, t)
         iter = iterate(series_keys, t_state)
     end
 
+    # The algorithm has a data dependency: It is assumed that the coefficients
+    # of the new `series` are already computed for all trees with a lower order
+    # than the current tree. Thus, we can use threaded parallelism only over a
+    # set of trees of the same order.
+    # for o in cutoff_order:order(series_integrator)
+    #     # We need to collect the trees we will iterate over in a vector for
+    #     # threaded parallelism.
+    #     # TODO: This should be the iterator type specified by the keys
+    #     #       of the series_integrator
+    #     trees = map(copy, RootedTreeIterator(o))
+    #     Threads.@threads for t in trees
+    #         series[t] += series_integrator[t] - substitute(series, series_ex, t)
+    #     end
+    # end
+
+    idx_stop = findfirst(==(cutoff_order) ∘ order, series_integrator.coef.keys)
+    if idx_stop === nothing
+        return series
+    else
+        idx_stop = idx_stop - 1
+    end
+    for o in cutoff_order:order(series_integrator)
+        # TODO: This uses internal implementation details...
+        idx_start = findnext(==(o) ∘ order, series_integrator.coef.keys, idx_stop)
+        idx_next = findnext(==(o + 1) ∘ order, series_integrator.coef.keys, idx_start)
+        # `idx_next` is the first tree of order `o + 1`. It must not be part of
+        # the batch of order `o` since its coefficient depends on coefficients
+        # that are still being computed by that batch.
+        if idx_next === nothing
+            idx_stop = lastindex(series_integrator.coef.keys)
+        else
+            idx_stop = idx_next - 1
+        end
+        # We iterate over the indices instead of the trees in the threaded
+        # loop since that is slightly more efficient at the time of writing
+        # due to fewer allocations.
+        indices = idx_start:idx_stop
+        Threads.@threads for i in indices
+            t = @inbounds series_integrator.coef.keys[i]
+            series[t] += series_integrator[t] - substitute(series, series_ex, t)
+        end
+    end
+
     return series
 end
 