Skip to content

Commit

Permalink
Try to refine pivot in choose_pivot
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Dec 23, 2021
1 parent 5dd6143 commit 7a89465
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 26 deletions.
73 changes: 52 additions & 21 deletions src/quicksort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,18 @@ function Base.sort!(
return v
end

# Pick a pivot for `xs` under `order`: take nine elements (the first, the
# last, and seven at multiples of `end ÷ 8` / `end ÷ 4` / `end ÷ 2`) and hand
# them to the nine-element `_median`.
# NOTE(review): for fewer than 8 elements `end ÷ 8` is 0; the callers visible
# here only reach this with more than 8 elements -- confirm for other callers.
function choose_pivot(xs, order)
    n = lastindex(xs)
    eighth = n ÷ 8
    quarter = n ÷ 4
    half = n ÷ 2
    samples = (
        xs[1],
        xs[eighth],
        xs[quarter],
        xs[3 * eighth],
        xs[half],
        xs[5 * eighth],
        xs[3 * quarter],
        xs[7 * eighth],
        xs[n],
    )
    return _median(order, samples)
end

function _quicksort!(ys, xs, alg, order, givenpivot = nothing)
@check length(ys) == length(xs)
if length(ys) <= max(8, alg.basesize)
return _quicksort_serial!(ys, xs, alg, order)
end
isrefined = false
pivot = if givenpivot === nothing
choose_pivot(ys, order)
let pivot, ishomogenous
pivot, ishomogenous, isrefined = choose_pivot(ys, alg.basesize, order)
ishomogenous && return ys
pivot
end
else
something(givenpivot)
end
Expand Down Expand Up @@ -95,6 +83,7 @@ function _quicksort!(ys, xs, alg, order, givenpivot = nothing)
total_nbelows = above_offsets[1]
if total_nbelows == 0
@assert givenpivot === nothing
@assert !isrefined
betterpivot, ishomogenous = refine_pivot(ys, pivot, alg.basesize, order)
ishomogenous && return ys
return _quicksort!(ys, xs, alg, order, Some(betterpivot))
Expand Down Expand Up @@ -124,7 +113,7 @@ function _quicksort_serial!(ys, xs, alg, order)
if length(ys) <= max(8, alg.smallsize)
return sort!(ys, alg.smallsort, order)
end
pivot = choose_pivot(ys, order)
_, pivot = samples_and_pivot(ys, order)

nbelows, naboves = quicksort_partition!(xs, ys, pivot, order)
@DBG @check nbelows + naboves == length(xs)
Expand Down Expand Up @@ -161,6 +150,50 @@ function quicksort_copyback!(ys, xs_chunk, nbelows, below_offset, above_offset)
end
end

"""
    samples_and_pivot(xs, order) -> (samples, pivot)

Collect nine elements of `xs` (the first, the last, and seven at multiples of
`end ÷ 8`, `end ÷ 4` and `end ÷ 2`) and return them as a tuple together with
the value `_median(order, samples)`.

NOTE(review): for fewer than 8 elements `end ÷ 8` is 0; the callers visible in
this file only call this with more than 8 elements -- confirm elsewhere.
"""
@inline function samples_and_pivot(xs, order)
    n = lastindex(xs)
    eighth = n ÷ 8
    quarter = n ÷ 4
    half = n ÷ 2
    samples = (
        xs[1],
        xs[eighth],
        xs[quarter],
        xs[3 * eighth],
        xs[half],
        xs[5 * eighth],
        xs[3 * quarter],
        xs[7 * eighth],
        xs[n],
    )
    return samples, _median(order, samples)
end

"""
    choose_pivot(xs, basesize, order) -> (pivot, ishomogenous::Bool, isrefined::Bool)

Choose a pivot for `xs` from the nine samples taken by `samples_and_pivot`.

If any sample differs from its neighbour (or the first sample differs from the
pivot) under `order`, the sampled pivot is returned as `(pivot, false, false)`.

If all nine samples are equivalent, the first (at most) 128 elements are
passed to `refine_pivot_serial`:

- If it does not report `ishomogenous`, the pivot it returned is used and the
  result is `(pivot, false, false)`. Note that `isrefined` is `false` here.
- If it does and `length(xs) <= 128`, the result is `(pivot, true, true)`.
- Otherwise the remaining elements (`xs[begin+128:end]`) are passed to
  `refine_pivot` with `basesize`, and the result is
  `(pivot, ishomogenous, true)` using the values it returned.
"""
function choose_pivot(xs, basesize, order)
    samples, pivot = samples_and_pivot(xs, order)
    s1, s2, s3, s4, s5, s6, s7, s8, s9 = samples

    # Cheap check on the samples only; short-circuits at the first mismatch.
    samples_agree =
        eq(order, s1, pivot) &&
        eq(order, s1, s2) &&
        eq(order, s2, s3) &&
        eq(order, s3, s4) &&
        eq(order, s4, s5) &&
        eq(order, s5, s6) &&
        eq(order, s6, s7) &&
        eq(order, s7, s8) &&
        eq(order, s8, s9)
    samples_agree || return (pivot, false, false)

    # Samples look identical: inspect a short prefix serially first.
    leading = @view(xs[begin:min(end, begin + 127)])
    pivot, ishomogenous = refine_pivot_serial(leading, pivot, order)
    ishomogenous || return (pivot, false, false)
    length(xs) <= 128 && return (pivot, true, true)

    # Prefix was reported homogeneous: check everything after it.
    remainder = @view(xs[begin+128:end])
    pivot, ishomogenous = refine_pivot(remainder, pivot, basesize, order)
    return (pivot, ishomogenous, true)
end

"""
refine_pivot(ys, badpivot::T, basesize, order) -> (pivot::T, ishomogenous::Bool)
Expand Down Expand Up @@ -224,5 +257,3 @@ end
# TODO: Check if the homogeneity check can be done in `quicksort_partition!`
# without overall performance degradation? Use it to determine the pivot
# for the next recursion.
# TODO: Do this right after `choose_pivot` if it finds out that all samples are
# equivalent?
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ function elsizeof(::Type{T}) where {T}
end
end

eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a))
# `a` and `b` are equivalent under `order` when neither sorts before the other.
@inline function eq(order, a, b)
    Base.lt(order, a, b) && return false
    return !Base.lt(order, b, a)
end

function _median(order, (a, b, c)::NTuple{3,Any})
@inline function _median(order, (a, b, c)::NTuple{3,Any})
# Sort `(a, b, c)`:
if Base.lt(order, b, a)
a, b = b, a
Expand All @@ -53,7 +53,7 @@ function _median(order, (a, b, c)::NTuple{3,Any})
return b
end

_median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any}) = _median(
# Nine-element variant: take the median of each consecutive triple, then the
# median of those three medians (all via the 3-tuple `_median`).
@inline function _median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any})
    first_mid = _median(order, (a, b, c))
    second_mid = _median(order, (d, e, f))
    third_mid = _median(order, (g, h, i))
    return _median(order, (first_mid, second_mid, third_mid))
end
Expand Down
6 changes: 4 additions & 2 deletions test/test_sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ using ThreadsX.Implementations: refine_pivot
end
end

# Return a one-argument callable that integer-divides its input by `x`
# (`div` is the same function as `÷`).
function divby(x)
    return Base.Fix2(div, x)
end

@testset "stable sort" begin
@testset for alg in [ThreadsX.MergeSort, ThreadsX.StableQuickSort]
@test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = _ -> 1) == 1:45
@test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = _ -> 1) == 1:1000
@test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = divby(2)) == 1:45
@test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = divby(2)) == 1:1000
end
end

Expand Down

0 comments on commit 7a89465

Please sign in to comment.