Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: LIBUV progress #46

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 77 additions & 99 deletions src/UCX.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,6 @@ function __init__()
ccall((:ucs_debug_disable_signals, API.libucs), Cvoid, ())

@assert version() >= VersionNumber(API.UCP_API_MAJOR, API.UCP_API_MINOR)
mode = get(ENV, "JLUCX_PROGRESS_MODE", "idling")
if mode == "busy"
PROGRESS_MODE[] = :busy
elseif mode == "idling"
PROGRESS_MODE[] = :idling
elseif mode == "polling"
PROGRESS_MODE[] = :polling
else
error("JLUCX_PROGRESS_MODE set to unkown progress mode: $mode")
end
@debug "UCX progress mode" mode
end

function memzero!(ref::Ref)
Expand All @@ -59,6 +48,8 @@ end
val
end

"""
    sync_send(data::Ptr{Cvoid})

Invoke the C function `uv_async_send` on `data` and return its `Cint` result.

`data` is forwarded untouched; presumably it must point at a live
`uv_async_t` handle — TODO confirm against the callers.
"""
function sync_send(data::Ptr{Cvoid})
    return ccall(:uv_async_send, Cint, (Ptr{Cvoid},), data)
end

# Exceptions/Status

"""
    uintptr_t(ptr::Ptr)

Return the address held by `ptr` as a `UInt` by reinterpreting its bits.
"""
function uintptr_t(ptr::Ptr)
    return reinterpret(UInt, ptr)
end
Expand Down Expand Up @@ -262,13 +253,13 @@ mutable struct UCXWorker
handle = r_handle[]

# TODO: Verify that UCXContext has been created with WAKEUP
if progress_mode === :polling
r_fd = Ref{API.Cint}()
@check API.ucp_worker_get_efd(handle, r_fd)
fd = Libc.RawFD(r_fd[])
else
# if progress_mode === :polling
# r_fd = Ref{API.Cint}()
# @check API.ucp_worker_get_efd(handle, r_fd)
# fd = Libc.RawFD(r_fd[])
# else
fd = RawFD(-1)
end
# end

worker = new(handle, fd, context, IdDict{Any,Nothing}(), Dict{UInt16, Any}(), fill(false, Base.Threads.nthreads()), true, progress_mode)
finalizer(worker) do worker
Expand All @@ -281,8 +272,6 @@ mutable struct UCXWorker
end
# Unwrap the raw handle stored in `worker` whenever a `ucp_worker_h` is
# requested, so a `UCXWorker` can be passed directly to foreign calls.
function Base.unsafe_convert(::Type{API.ucp_worker_h}, worker::UCXWorker)
    return worker.handle
end

# True when `worker.fd` differs from the `RawFD(-1)` placeholder.
function ispolling(worker::UCXWorker)
    return worker.fd != RawFD(-1)
end
# Accessor for the `mode` field stored on the worker.
function progress_mode(worker::UCXWorker)
    return worker.mode
end
# Accessor for the `context` field stored on the worker.
function context(worker::UCXWorker)
    return worker.context
end

"""
Expand All @@ -293,97 +282,59 @@ and call callbacks.

Returns `true` if progress was made, `false` if no work was waiting.
"""
function progress(worker::UCXWorker, allow_yield=true)
tid = Base.Threads.threadid()
if worker.in_amhandler[tid]
@debug """
UCXWorker is processing a Active Message on this thread
Calling `progress` is not permitted and leads to recursion.
""" tid exception=(UCXException(API.UCS_ERR_NO_PROGRESS), catch_backtrace())
if allow_yield
yield()
end
return false
else
return API.ucp_worker_progress(worker) != 0
end
function progress(worker::UCXWorker)
return API.ucp_worker_progress(worker) != 0
end

function fence(worker::UCXWorker)
@check API.ucp_worker_fence(worker)
function async_progress(worker::UCXWorker)
return Base.@threadcall((:ucp_worker_progress, API.libucp), Cuint, (ucp_worker_h,), worker) != 0
end

"""
    lock_am(worker::UCXWorker)

Set the `in_amhandler` flag of `worker` for the calling thread.

The flag is indexed by `Base.Threads.threadid()`. Raises an error if the flag
for this thread is already set; otherwise sets it and returns `true`.
"""
function lock_am(worker::UCXWorker)
    slot = Base.Threads.threadid()
    # Guard: the flag must not already be set for this thread.
    worker.in_amhandler[slot] &&
        error("UCXWorker already in AMHandler on this thread! Concurrency violation.")
    worker.in_amhandler[slot] = true
    return true
end

function unlock_am(worker::UCXWorker)
tid = Base.Threads.threadid()
if !worker.in_amhandler[tid]
error("UCXWorker is not in AMHandler on this thread! Concurrency violation.")
end
worker.in_amhandler[tid] = false
function fence(worker::UCXWorker)
@check API.ucp_worker_fence(worker)
end

include("idle.jl")

import FileWatching: poll_fd
function Base.wait(worker::UCXWorker)
if ispolling(worker)
@assert progress_mode(worker) === :polling
# Use fd_poll to suspend worker when no progress is being made
# if ispolling(worker)
# @assert progress_mode(worker) === :polling
# # Use fd_poll to suspend worker when no progress is being made
# while isopen(worker)
# if progress(worker)
# # progress was made
# yield()
# continue
# end

# # Wait for poll
# status = API.ucp_worker_arm(worker)
# if status == API.UCS_OK
# if !isopen(worker)
# error("UCXWorker already closed")
# end
# # wait for events
# poll_fd(worker.fd; writable=true, readable=true)
# progress(worker)
# break
# elseif status == API.UCS_ERR_BUSY
# # could not arm, need to progress more
# continue
# else
# @check status
# end
# end
# elseif progress_mode(worker) === :libuv
async_progress(worker)
while isopen(worker)
if progress(worker)
# progress was made
yield()
continue
end

# Wait for poll
status = API.ucp_worker_arm(worker)
if status == API.UCS_OK
if !isopen(worker)
error("UCXWorker already closed")
end
# wait for events
poll_fd(worker.fd; writable=true, readable=true)
progress(worker)
break
elseif status == API.UCS_ERR_BUSY
# could not arm, need to progress more
continue
else
@check status
end
end
elseif progress_mode(worker) === :idling
idler = UvWorkerIdle(worker)
wait(idler)
close(idler)
elseif progress_mode(worker) === :busy
progress(worker)
while isopen(worker)
# Temporary solution before we have gc transition support in codegen.
# XXX: `yield()` is supposed to be a safepoint, but without this we easily
# deadlock in a multithreaded test.
ccall(:jl_gc_safepoint, Cvoid, ())
yield()
progress(worker)
async_progress(worker)
end
else
throw(UCXException(API.UCS_ERR_UNREACHABLE))
end
end

# Signal `worker` through `API.ucp_worker_signal`, but only when
# `ispolling(worker)` is true; for any other worker the call does nothing.
function Base.notify(worker::UCXWorker)
# If we don't use polling, we can't signal the worker
if ispolling(worker)
# The status of the signal call is handed to `@check`
# (presumably raises on a non-OK status — confirm against its definition).
@check API.ucp_worker_signal(worker)
end
# NOTE(review): the commented-out lines below repeat the active check above.
# if ispolling(worker)
# @check API.ucp_worker_signal(worker)
# end
end

function Base.isopen(worker::UCXWorker)
Expand All @@ -396,9 +347,6 @@ function Base.close(worker::UCXWorker)
notify(worker)
end




"""
AMHandler(func)

Expand All @@ -416,9 +364,39 @@ or call `am_recv`.
"""
mutable struct AMHandler
# Wrapped callback returning `API.ucs_status_t`. Its argument tuple is a
# `UCXWorker`, two (`Ptr{Cvoid}`, `Csize_t`) pairs, and a
# `Ptr{API.ucp_am_recv_param_t}` (presumably header/length, data/length and
# the receive parameters — confirm against the receive callback).
func::FunctionWrapper{API.ucs_status_t, Tuple{UCXWorker, Ptr{Cvoid}, Csize_t, Ptr{Cvoid}, Csize_t, Ptr{API.ucp_am_recv_param_t}}}

# Worker stored alongside the callback.
worker::UCXWorker
end

"""
    Queue{T}

Bump-pointer view over a raw memory region holding items of type `T`.

# Fields
- `head`: atomic pointer to the next slot that `enqueue!` will claim.
- `tail`: atomic pointer that `enqueue!` treats as the upper bound for `head`.
- `size`: not read by `enqueue!`; presumably the capacity of the region —
  TODO confirm the unit (items or bytes).
"""
mutable struct Queue{T}
    @atomic head::Ptr{Cvoid}
    @atomic tail::Ptr{Cvoid}
    size::Csize_t
end

"""
    enqueue!(queue::Queue{T}, item::T) -> Bool

Claim the slot at `queue.head` and store `item` there.

`head` is advanced by `sizeof(T)` with a compare-and-swap, so competing callers
each claim a distinct slot. Returns `true` once `item` has been written, or
`false` (leaving the queue unchanged) when advancing `head` would reach or pass
`tail`.

`T` is assumed to be a plain bits type, since the item is written with
`unsafe_store!` into raw memory. The caller must keep the underlying memory
alive and valid.
"""
function enqueue!(queue::Queue{T}, item::T) where {T}
    head = @atomic :acquire queue.head
    tail = @atomic :acquire queue.tail

    while true
        new_head = head + sizeof(T)
        new_head >= tail && return false # queue is full
        # The failure ordering is given explicitly: a failed compare-and-swap is
        # only a load, for which `:acquire_release` is not a valid ordering.
        head, success = @atomicreplace :acquire_release :acquire queue.head head => new_head
        # On failure `head` now holds the value another caller published; retry
        # from there. On success it is the slot this call has just claimed.
        success && break
    end

    # `head` is typed `Ptr{Cvoid}`; storing through it would try to convert
    # `item` to `Nothing`, so retype the pointer to the element type first.
    unsafe_store!(Ptr{T}(head), item)
    return true
end


function am_recv_callback(arg::Ptr{Cvoid}, header::Ptr{Cvoid}, header_length::Csize_t, data::Ptr{Cvoid}, length::Csize_t, param::Ptr{API.ucp_am_recv_param_t})::API.ucs_status_t
handler = Base.unsafe_pointer_to_objref(arg)::AMHandler
try
Expand Down