From ae7a83baa077fc7aec7ffae145211e02ca111b23 Mon Sep 17 00:00:00 2001
From: Valentin Churavy <v.churavy@gmail.com>
Date: Thu, 2 Dec 2021 16:45:49 -0500
Subject: [PATCH 1/2] Add error handler to endpoint

---
 src/UCX.jl | 70 ++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 57 insertions(+), 13 deletions(-)

diff --git a/src/UCX.jl b/src/UCX.jl
index 5b89c83..5ee540f 100644
--- a/src/UCX.jl
+++ b/src/UCX.jl
@@ -532,9 +532,20 @@ mutable struct UCXEndpoint
 end
 Base.unsafe_convert(::Type{API.ucp_ep_h}, ep::UCXEndpoint) = ep.handle
 
-function UCXEndpoint(worker::UCXWorker, ip::IPv4, port)
+function ucp_err_handler(arg::Ptr{Cvoid}, ep::API.ucp_ep_h, status::API.ucs_status_t)
+    @error "Endpoint error" exception=UCXException(status)
+    # TODO should we throw here and close the endpoint?
+    return nothing
+end
+
+function UCXEndpoint(worker::UCXWorker, ip::IPv4, port;
+                     error_handling=true)
     field_mask = API.UCP_EP_PARAM_FIELD_FLAGS |
                  API.UCP_EP_PARAM_FIELD_SOCK_ADDR
+    if error_handling
+        field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
+                      API.UCP_EP_PARAM_FIELD_ERR_HANDLER
+    end
     flags      = API.UCP_EP_PARAMS_FLAGS_CLIENT_SERVER
     sockaddr   = Ref(IP.sockaddr_in(InetAddr(ip, port)))
 
@@ -549,8 +560,15 @@ function UCXEndpoint(worker::UCXWorker, ip::IPv4, port)
         set!(params, :field_mask,   field_mask)
         set!(params, :sockaddr,     ucs_sockaddr)
         set!(params, :flags,        flags)
-
-        # TODO: Error callback
+        if error_handling
+            err_handler = API.ucp_err_handler(
+                @cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
+                C_NULL
+            )
+
+            set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
+            set!(params, :err_handler, err_handler)
+        end
     
         @check API.ucp_ep_create(worker, params, r_handle)
     end
@@ -558,9 +576,15 @@ function UCXEndpoint(worker::UCXWorker, ip::IPv4, port)
     UCXEndpoint(worker, r_handle[])
 end
 
-function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest)
+function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest;
+                     error_handling=true)
     field_mask = API.UCP_EP_PARAM_FIELD_FLAGS |
                  API.UCP_EP_PARAM_FIELD_CONN_REQUEST
+    if error_handling
+        field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
+                      API.UCP_EP_PARAM_FIELD_ERR_HANDLER
+    end
+
     flags      = API.UCP_EP_PARAMS_FLAGS_NO_LOOPBACK
 
     params = Ref{API.ucp_ep_params}()
@@ -568,8 +592,15 @@ function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest)
     set!(params, :field_mask,   field_mask)
     set!(params, :conn_request, conn_request.handle)
     set!(params, :flags,        flags)
-
-    # TODO: Error callback
+    if error_handling
+        err_handler = API.ucp_err_handler(
+            @cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
+            C_NULL
+        )
+
+        set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
+        set!(params, :err_handler, err_handler)
+    end
 
     r_handle = Ref{API.ucp_ep_h}()
     @check API.ucp_ep_create(worker, params, r_handle)
@@ -577,29 +608,42 @@ function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest)
     UCXEndpoint(worker, r_handle[])
 end
 
-function UCXEndpoint(worker::UCXWorker, addr::UCXAddress)
+function UCXEndpoint(worker::UCXWorker, addr::UCXAddress;  # Create an endpoint on `worker` from a UCXAddress.
+                     error_handling=true)  # forwarded to _UCXEndpoint, which sets up err_mode/err_handler only when true
     GC.@preserve addr begin
-        _UCXEndpoint(worker, addr.handle)
+        _UCXEndpoint(worker, addr.handle, error_handling)  # raw handle is only used while `addr` is preserved
     end
 end
 
-function UCXEndpoint(worker::UCXWorker, addr_buf::Vector{UInt8})
+function UCXEndpoint(worker::UCXWorker, addr_buf::Vector{UInt8};  # Create an endpoint on `worker` from an address held in a byte buffer.
+                     error_handling=true)  # forwarded to _UCXEndpoint, which sets up err_mode/err_handler only when true
     GC.@preserve addr_buf begin
         addr = Base.unsafe_convert(Ptr{API.ucp_address_t}, pointer(addr_buf))
-        _UCXEndpoint(worker, addr)
+        _UCXEndpoint(worker, addr, error_handling)  # `addr` points into addr_buf, hence the surrounding GC.@preserve
     end
 end
 
-function _UCXEndpoint(worker::UCXWorker, addr::Ptr{API.ucp_address_t})
+function _UCXEndpoint(worker::UCXWorker, addr::Ptr{API.ucp_address_t}, error_handling)
     field_mask = API.UCP_EP_PARAM_FIELD_REMOTE_ADDRESS
+    if error_handling
+        field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
+                      API.UCP_EP_PARAM_FIELD_ERR_HANDLER
+    end
 
     r_handle = Ref{API.ucp_ep_h}()
     params = Ref{API.ucp_ep_params}()
     memzero!(params)
     set!(params, :field_mask,   field_mask)
     set!(params, :address,      addr)
-
-    # TODO: Error callback
+    if error_handling
+        err_handler = API.ucp_err_handler(
+            @cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
+            C_NULL
+        )
+
+        set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
+        set!(params, :err_handler, err_handler)
+    end
 
     @check API.ucp_ep_create(worker, params, r_handle)
 

From 728c64d1109b483a202e6deb28976f665214d963 Mon Sep 17 00:00:00 2001
From: Valentin Churavy <v.churavy@gmail.com>
Date: Sun, 12 Dec 2021 20:45:26 -0500
Subject: [PATCH 2/2] add test

---
 src/UCX.jl       | 30 ++++++++++++++++--------------
 test/runtests.jl | 39 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 14 deletions(-)

diff --git a/src/UCX.jl b/src/UCX.jl
index 5ee540f..9337800 100644
--- a/src/UCX.jl
+++ b/src/UCX.jl
@@ -271,11 +271,7 @@ mutable struct UCXWorker
         end
 
         worker = new(handle, fd, context, IdDict{Any,Nothing}(), Dict{UInt16, Any}(), fill(false, Base.Threads.nthreads()), true, progress_mode)
-        finalizer(worker) do worker
-            worker.open = false
-            @assert isempty(worker.inflight)
-            API.ucp_worker_destroy(worker)
-        end
+        finalizer(destroy, worker)
         return worker
     end
 end
@@ -313,6 +309,15 @@ function fence(worker::UCXWorker)
     @check API.ucp_worker_fence(worker)
 end
 
+function destroy(worker::UCXWorker)  # Close `worker` and destroy its UCP worker handle; does nothing if the handle is already NULL.
+    if worker.handle != C_NULL  # a NULL handle marks a worker that was already destroyed
+        close(worker)  # clears `open` and notifies the worker if it was still open
+        @assert isempty(worker.inflight)  # NOTE(review): destroy is also the finalizer, so a failure here is raised inside a finalizer — confirm acceptable
+        API.ucp_worker_destroy(worker)
+        worker.handle = C_NULL  # makes any later call (explicit or from the finalizer) a no-op
+    end
+end
+
 function lock_am(worker::UCXWorker)
     tid = Base.Threads.threadid()
     if worker.in_amhandler[tid]
@@ -387,18 +392,16 @@ function Base.notify(worker::UCXWorker)
 end
 
 function Base.isopen(worker::UCXWorker)
-    worker.open
+    worker.open && worker.handle != C_NULL  # a worker whose handle was reset to C_NULL by destroy is reported as closed too
 end
 
 function Base.close(worker::UCXWorker)
-    @debug "Close worker"
-    worker.open = false
-    notify(worker)
+    if isopen(worker)  # closing an already closed or destroyed worker is a no-op
+        worker.open = false
+        notify(worker)  # presumably wakes tasks blocked in wait(worker) so they re-check isopen — TODO confirm
+    end
 end
 
-
-
-
 """
     AMHandler(func)
 
@@ -533,8 +536,7 @@ end
 Base.unsafe_convert(::Type{API.ucp_ep_h}, ep::UCXEndpoint) = ep.handle
 
 function ucp_err_handler(arg::Ptr{Cvoid}, ep::API.ucp_ep_h, status::API.ucs_status_t)
-    @error "Endpoint error" exception=UCXException(status)
-    # TODO should we throw here and close the endpoint?
+    throw(UCXException(status))  # NOTE(review): this function is handed to UCX via @cfunction, so the throw unwinds through C frames — confirm that is safe; it also makes the `return` below unreachable
     return nothing
 end
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 70bfc14..f4956e4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -38,6 +38,45 @@ end
     @test addr.len > 0
 end
 
+@testset "Error handler" begin
+    ctx = UCX.UCXContext()
+    server = UCX.UCXWorker(ctx)
+
+    UCX.@spawn_showerr begin # background task: keep calling wait(server) for as long as the worker is open
+        while isopen(server)
+            wait(server)
+        end
+        close(server)
+    end
+
+    barrier = Base.Event()
+
+    am_called = Ref{Int}(0) # counts invocations of am_test
+    AM_TEST = 1
+    function am_test(worker, header, header_length, data, length, _param)
+        am_called[] += 1
+        notify(barrier)
+        return UCX.API.UCS_OK
+    end
+    UCX.AMHandler(server, am_test, AM_TEST) 
+
+    server_addr = UCX.UCXAddress(server)
+    client = UCX.UCXWorker(ctx)
+    ep = UCX.UCXEndpoint(client, server_addr) # error_handling defaults to true
+
+    req = UCX.am_send(ep, AM_TEST, Int[])
+    wait(req) # wait for the request to be sent before suspending in `wait(barrier)`
+    wait(barrier)
+
+    @test am_called[] == 1
+
+    UCX.destroy(server) # destroy the server worker so the next send targets a dead peer
+    barrier = Base.Event() # NOTE(review): this fresh event is never waited on below
+    req = UCX.am_send(ep, AM_TEST, Int[])
+    @test_throws UCX.UCXException wait(req) # the send to the destroyed server is expected to fail with a UCXException
+    @test am_called[] == 1 # the handler must not have run a second time
+end
+
 @testset "Active Messages" begin
     cmd = Base.julia_cmd()
     if Base.JLOptions().project != C_NULL