diff --git a/src/quicksort.jl b/src/quicksort.jl index f30ef29d..7f2660a6 100644 --- a/src/quicksort.jl +++ b/src/quicksort.jl @@ -28,30 +28,18 @@ function Base.sort!( return v end -function choose_pivot(xs, order) - return _median( - order, - ( - xs[1], - xs[end÷8], - xs[end÷4], - xs[3*(end÷8)], - xs[end÷2], - xs[5*(end÷8)], - xs[3*(end÷4)], - xs[7*(end÷8)], - xs[end], - ), - ) -end - function _quicksort!(ys, xs, alg, order, givenpivot = nothing) @check length(ys) == length(xs) if length(ys) <= max(8, alg.basesize) return _quicksort_serial!(ys, xs, alg, order) end + isrefined = false pivot = if givenpivot === nothing - choose_pivot(ys, order) + let pivot, ishomogenous + pivot, ishomogenous, isrefined = choose_pivot(ys, alg.basesize, order) + ishomogenous && return ys + pivot + end else something(givenpivot) end @@ -95,6 +83,7 @@ function _quicksort!(ys, xs, alg, order, givenpivot = nothing) total_nbelows = above_offsets[1] if total_nbelows == 0 @assert givenpivot === nothing + @assert !isrefined betterpivot, ishomogenous = refine_pivot(ys, pivot, alg.basesize, order) ishomogenous && return ys return _quicksort!(ys, xs, alg, order, Some(betterpivot)) @@ -124,7 +113,7 @@ function _quicksort_serial!(ys, xs, alg, order) if length(ys) <= max(8, alg.smallsize) return sort!(ys, alg.smallsort, order) end - pivot = choose_pivot(ys, order) + _, pivot = samples_and_pivot(ys, order) nbelows, naboves = quicksort_partition!(xs, ys, pivot, order) @DBG @check nbelows + naboves == length(xs) @@ -161,6 +150,48 @@ function quicksort_copyback!(ys, xs_chunk, nbelows, below_offset, above_offset) end end +@inline function samples_and_pivot(xs, order) + samples = ( + xs[1], + xs[end÷8], + xs[end÷4], + xs[3*(end÷8)], + xs[end÷2], + xs[5*(end÷8)], + xs[3*(end÷4)], + xs[7*(end÷8)], + xs[end], + ) + pivot = _median(order, samples) + return samples, pivot +end + +""" + choose_pivot(xs, basesize, order) -> (pivot, ishomogenous::Bool, isrefined::Bool) +""" +function choose_pivot(xs, basesize, order) + samples, pivot = samples_and_pivot(xs, order) + if ( + eq(order, samples[1], pivot) && + eq(order, samples[1], samples[2]) && + eq(order, samples[2], samples[3]) && + eq(order, samples[3], samples[4]) && + eq(order, samples[4], samples[5]) && + eq(order, samples[5], samples[6]) && + eq(order, samples[6], samples[7]) && + eq(order, samples[7], samples[8]) && + eq(order, samples[8], samples[9]) + ) + pivot, ishomogenous = refine_pivot_serial(@view(xs[1:min(end, 128)]), pivot, order) + if ishomogenous + length(xs) <= 128 && return (pivot, true, true) + pivot, ishomogenous = refine_pivot(@view(xs[129:end]), pivot, basesize, order) + return (pivot, ishomogenous, true) + end + end + return (pivot, false, false) +end + """ refine_pivot(ys, badpivot::T, basesize, order) -> (pivot::T, ishomogenous::Bool) @@ -224,5 +255,3 @@ end # TODO: Check if the homogeneity check can be done in `quicksort_partition!` # without overall performance degradation? Use it to determine the pivot # for the next recursion. -# TODO: Do this right after `choose_pivot` if it finds out that all samples are -# equivalent? diff --git a/src/utils.jl b/src/utils.jl index 4d951ba0..fbf6123b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -37,9 +37,9 @@ function elsizeof(::Type{T}) where {T} end end -eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a)) +@inline eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a)) -function _median(order, (a, b, c)::NTuple{3,Any}) +@inline function _median(order, (a, b, c)::NTuple{3,Any}) # Sort `(a, b, c)`: if Base.lt(order, b, a) a, b = b, a @@ -53,7 +53,7 @@ function _median(order, (a, b, c)::NTuple{3,Any}) return b end -_median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any}) = _median( +@inline _median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any}) = _median( order, (_median(order, (a, b, c)), _median(order, (d, e, f)), _median(order, (g, h, i))), ) diff --git a/test/test_sort.jl b/test/test_sort.jl index 365bb806..063bc884 100644 --- a/test/test_sort.jl +++ b/test/test_sort.jl @@ -19,10 +19,12 @@ using ThreadsX.Implementations: refine_pivot end end +divby(x) = Base.Fix2(÷, x) + @testset "stable sort" begin @testset for alg in [ThreadsX.MergeSort, ThreadsX.StableQuickSort] - @test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = _ -> 1) == 1:45 - @test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = _ -> 1) == 1:1000 + @test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = divby(2)) == 1:45 + @test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = divby(2)) == 1:1000 end end