Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to refine pivot in choose_pivot #188

Draft
wants to merge 2 commits into
base: copy-twice
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 50 additions & 21 deletions src/quicksort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,18 @@
return v
end

function choose_pivot(xs, order)
return _median(
order,
(
xs[1],
xs[end÷8],
xs[end÷4],
xs[3*(end÷8)],
xs[end÷2],
xs[5*(end÷8)],
xs[3*(end÷4)],
xs[7*(end÷8)],
xs[end],
),
)
end

function _quicksort!(ys, xs, alg, order, givenpivot = nothing)
@check length(ys) == length(xs)
if length(ys) <= max(8, alg.basesize)
return _quicksort_serial!(ys, xs, alg, order)
end
isrefined = false
pivot = if givenpivot === nothing
choose_pivot(ys, order)
let pivot, ishomogenous

Check warning on line 38 in src/quicksort.jl

View check run for this annotation

Codecov / codecov/patch

src/quicksort.jl#L38

Added line #L38 was not covered by tests
pivot, ishomogenous, isrefined = choose_pivot(ys, alg.basesize, order)
ishomogenous && return ys
pivot
end
else
something(givenpivot)
end
Expand Down Expand Up @@ -95,6 +83,7 @@
total_nbelows = above_offsets[1]
if total_nbelows == 0
@assert givenpivot === nothing
@assert !isrefined
betterpivot, ishomogenous = refine_pivot(ys, pivot, alg.basesize, order)
ishomogenous && return ys
return _quicksort!(ys, xs, alg, order, Some(betterpivot))
Expand Down Expand Up @@ -124,7 +113,7 @@
if length(ys) <= max(8, alg.smallsize)
return sort!(ys, alg.smallsort, order)
end
pivot = choose_pivot(ys, order)
_, pivot = samples_and_pivot(ys, order)

Check warning on line 116 in src/quicksort.jl

View check run for this annotation

Codecov / codecov/patch

src/quicksort.jl#L116

Added line #L116 was not covered by tests

nbelows, naboves = quicksort_partition!(xs, ys, pivot, order)
@DBG @check nbelows + naboves == length(xs)
Expand Down Expand Up @@ -161,6 +150,48 @@
end
end

@inline function samples_and_pivot(xs, order)
samples = (
xs[1],
xs[end÷8],
xs[end÷4],
xs[3*(end÷8)],
xs[end÷2],
xs[5*(end÷8)],
xs[3*(end÷4)],
xs[7*(end÷8)],
xs[end],
)
pivot = _median(order, samples)
return samples, pivot
end

"""
choose_pivot(xs, basesize, order) -> (pivot, ishomogenous::Bool, isrefined::Bool)
"""
function choose_pivot(xs, basesize, order)
samples, pivot = samples_and_pivot(xs, order)
if (
eq(order, samples[1], pivot) &&
eq(order, samples[1], samples[2]) &&
eq(order, samples[2], samples[3]) &&

Check warning on line 177 in src/quicksort.jl

View check run for this annotation

Codecov / codecov/patch

src/quicksort.jl#L176-L177

Added lines #L176 - L177 were not covered by tests
eq(order, samples[3], samples[4]) &&
eq(order, samples[4], samples[5]) &&
eq(order, samples[5], samples[6]) &&
eq(order, samples[6], samples[7]) &&
eq(order, samples[7], samples[8]) &&
eq(order, samples[8], samples[9])
)
pivot, ishomogenous = refine_pivot_serial(@view(xs[1:min(end, 128)]), pivot, order)
if ishomogenous
length(xs) <= 128 && return (pivot, true, true)
pivot, ishomogenous = refine_pivot(@view(xs[129:end]), pivot, basesize, order)
return (pivot, ishomogenous, true)

Check warning on line 189 in src/quicksort.jl

View check run for this annotation

Codecov / codecov/patch

src/quicksort.jl#L188-L189

Added lines #L188 - L189 were not covered by tests
end
end
return (pivot, false, false)
end

"""
refine_pivot(ys, badpivot::T, basesize, order) -> (pivot::T, ishomogenous::Bool)

Expand Down Expand Up @@ -224,5 +255,3 @@
# TODO: Check if the homogeneity check can be done in `quicksort_partition!`
# without overall performance degradation? Use it to determine the pivot
# for the next recursion.
# TODO: Do this right after `choose_pivot` if it finds out that all samples are
# equivalent?
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ function elsizeof(::Type{T}) where {T}
end
end

eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a))
@inline eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a))

function _median(order, (a, b, c)::NTuple{3,Any})
@inline function _median(order, (a, b, c)::NTuple{3,Any})
# Sort `(a, b, c)`:
if Base.lt(order, b, a)
a, b = b, a
Expand All @@ -53,7 +53,7 @@ function _median(order, (a, b, c)::NTuple{3,Any})
return b
end

_median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any}) = _median(
@inline _median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any}) = _median(
order,
(_median(order, (a, b, c)), _median(order, (d, e, f)), _median(order, (g, h, i))),
)
Expand Down
6 changes: 4 additions & 2 deletions test/test_sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ using ThreadsX.Implementations: refine_pivot
end
end

divby(x) = Base.Fix2(÷, x)

@testset "stable sort" begin
@testset for alg in [ThreadsX.MergeSort, ThreadsX.StableQuickSort]
@test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = _ -> 1) == 1:45
@test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = _ -> 1) == 1:1000
@test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = divby(2)) == 1:45
@test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = divby(2)) == 1:1000
end
end

Expand Down