From 29c41a3798975621f6dd5d4f7e0816acd8718530 Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Tue, 28 Jan 2025 19:33:50 -0500 Subject: [PATCH 01/18] CPU backend --- deps/ReactantExtra/WORKSPACE | 2 +- src/Compiler.jl | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index cce06d553..7916c7305 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "c38ca3f187ef11de6b2292f3cc55c5eb60530d15" +ENZYMEXLA_COMMIT = "264115eb30f0b23f73040bd68ee6964b634330e9" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/src/Compiler.jl b/src/Compiler.jl index 0d6b3fcc2..81074beab 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -429,7 +429,7 @@ end const DEBUG_KERNEL = Ref{Bool}(false) -function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) +function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu") # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues @@ -456,7 +456,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: if isdefined(Reactant_jll, :ptxas_path) toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] end - if DEBUG_KERNEL[] + + if backend == "cpu" + kern = "lower-kernel{openmp=false backend=cpu},symbol-dce" + elseif DEBUG_KERNEL[] curesulthandler = XLA.Libdl.dlsym( Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" ) @@ -604,7 +607,7 @@ end @code_hlo [optimize = ...] [no_nan = <true/false>] f(args...) """ macro code_hlo(args...) - default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false) + default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false, :backend => "gpu") compile_expr, (; compiled) = compile_call_expr( __module__, compile_mlir, default_options, args... ) @@ -975,12 +978,18 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid + if client !== nothing + backend = XLA.ClientGetPlatformName(backend) + else + backend = XLA.ClientGetPlatformName(XLA.default_backend[]) + end + MLIR.IR.activate!(ctx) results = try # compile function to MLIR module mod = MLIR.IR.Module(MLIR.IR.Location()) linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!( - mod, f, args; optimize, no_nan + mod, f, args; optimize, no_nan, backend ) # Resolve client and device From 58bc103b3321b963327d7978616411fd7aa324b5 Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Tue, 28 Jan 2025 19:50:30 -0500 Subject: [PATCH 02/18] add build --- deps/ReactantExtra/BUILD | 2 +- src/Compiler.jl | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 5ebcf2aaa..10bc8e61f 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -555,6 +555,7 @@ cc_library( "@llvm-project//mlir:CAPILLVMObjects", "@jax//jaxlib/mosaic:tpu_dialect_capi_objects", "@jax//jaxlib/triton:triton_dialect_capi_objects", + "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", ] + select({ "@xla//xla/tsl:is_cuda_enabled_and_oss":[ "@xla//xla/stream_executor/cuda:all_runtime", @@ -566,7 +567,6 @@ cc_library( "@xla//xla/backends/profiler/gpu:device_tracer", ], "//conditions:default": [ - "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl", ], }) + if_rocm([ "@xla//xla/service/gpu:amdgpu_compiler", diff --git a/src/Compiler.jl b/src/Compiler.jl index 81074beab..87619f3d4 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -979,10 +979,16 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid if client !== nothing - backend = XLA.ClientGetPlatformName(backend) + backend = XLA.ClientGetPlatformName(client) else backend = XLA.ClientGetPlatformName(XLA.default_backend[]) end + if backend == "CUDA" + backend = "GPU" + elseif backend == "CPU" + backend = "cpu" + end + @show backend MLIR.IR.activate!(ctx) results = try @@ -1036,7 +1042,7 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic num_partitions=1, use_shardy_partitioner=false, ) - return ( + ( exec, linear_args, linear_results, From e58c4b6333d9b8185c78bd73cf4a63c5157251b6 Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Tue, 28 Jan 2025 21:29:39 -0500 Subject: [PATCH 03/18] cpu handle --- deps/ReactantExtra/BUILD | 2 ++ deps/ReactantExtra/WORKSPACE | 2 +- src/XLA.jl | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 10bc8e61f..65f269172 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -362,6 +362,7 @@ cc_library( ) + [ "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", "@enzyme_ad//src/enzyme_ad/jax:gpu.cc", + "@enzyme_ad//src/enzyme_ad/jax:cpu.cc", # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", # "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", @@ -437,6 +438,7 @@ cc_library( "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", "-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler", +"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler", "-Wl,-exported_symbol,_ReactantThrowError", "-Wl,-exported_symbol,_ReactantHandleCuResult", "-Wl,-exported_symbol,_CreateProfilerSession", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 7916c7305..3a4499a75 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "264115eb30f0b23f73040bd68ee6964b634330e9" +ENZYMEXLA_COMMIT = "57c817748e620b9c688e9e6129af721d44c23f19" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/src/XLA.jl b/src/XLA.jl index 224a1deeb..435fdbcbc 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -178,6 +178,7 @@ function __init__() end end + @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid # This wasn't properly exported on macos, we'll remove the try once macOS JLL From a2de1505b74e41f3f4f209bc0e79214555cfef8f Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Tue, 28 Jan 2025 21:31:19 -0500 Subject: [PATCH 04/18] KA backend --- ext/ReactantCUDAExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 6d863fc53..502b8cdc1 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1,7 +1,7 @@ module ReactantCUDAExt using CUDA -using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber +using Reactant: Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber using ReactantCore: @trace using KernelAbstractions: KernelAbstractions using Libdl @@ -9,6 +9,7 @@ using Libdl using Adapt KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend() +KernelAbstractions.get_backend(::AnyConcreteRArray) = CUDABackend() struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} ptr::Core.LLVMPtr{T,A} From 575fc9243d0d1f987ab82b2b27395c385705add3 Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 09:59:07 +0100 Subject: [PATCH 05/18] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 3a4499a75..6d571742b 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "57c817748e620b9c688e9e6129af721d44c23f19" +ENZYMEXLA_COMMIT = "ab95bec0a69c92be7065f9eabc16e2d6b0da11b7" ENZYMEXLA_SHA256 = "" http_archive( From 80e4e66a20f647bf1616f15521bec47edb8f6f0d Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 12:05:54 +0100 Subject: [PATCH 06/18] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 6d571742b..b18f64beb 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "ab95bec0a69c92be7065f9eabc16e2d6b0da11b7" +ENZYMEXLA_COMMIT = "b37119b68e20370c94736a957d34b424dc5825f6" ENZYMEXLA_SHA256 = "" http_archive( From ca7d5b681dc265aeda5ca6024cea11e8e4804ccb Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 12:15:54 +0100 Subject: [PATCH 07/18] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index b18f64beb..ddb1eca3f 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "b37119b68e20370c94736a957d34b424dc5825f6" +ENZYMEXLA_COMMIT = "e079972eb38fdda8e004df08c764ac7e93a2d85c" ENZYMEXLA_SHA256 = "" http_archive( From 9818c8f976f1cff79d5a2fe1695aabaabfb970a3 Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 12:18:35 +0100 Subject: [PATCH 08/18] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index ddb1eca3f..44d4458a3 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "e079972eb38fdda8e004df08c764ac7e93a2d85c" +ENZYMEXLA_COMMIT = "df1905af736d0cc7741e8753f18d3f536fbc2e19" ENZYMEXLA_SHA256 = "" http_archive( From e1b1e8f539f471240233abe7c4d5f5fed0e23afe Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 12:45:17 +0100 Subject: [PATCH 09/18] Update WORKSPACE --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 44d4458a3..935efdbf4 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "df1905af736d0cc7741e8753f18d3f536fbc2e19" +ENZYMEXLA_COMMIT = "d89468ed883ca18c04346eec10f784bbe2b754fc" ENZYMEXLA_SHA256 = "" http_archive( From 4ea9f49d84757087ae00abc46427aaec4913dc9a Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 17:02:21 +0100 Subject: [PATCH 10/18] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d6f5d06a7..1baceeae1 100644 --- a/Project.toml +++ b/Project.toml @@ -72,7 +72,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.4" -Reactant_jll = "0.0.52" +Reactant_jll = "0.0.56" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" From 7d906aeadecbd50836332f15c439124b55e3484a Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 23:25:18 +0100 Subject: [PATCH 11/18] Update Compiler.jl --- src/Compiler.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 87619f3d4..2145555f7 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -988,7 +988,6 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic elseif backend == "CPU" backend = "cpu" end - @show backend MLIR.IR.activate!(ctx) results = try From 6014df7518d2ca966b543f6ad462ccc3667a94ac Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 23:26:03 +0100 Subject: [PATCH 12/18] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Compiler.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 2145555f7..dcb48ae39 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -429,7 +429,9 @@ end const DEBUG_KERNEL = Ref{Bool}(false) -function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu") +function compile_mlir!( + mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu" +) # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues @@ -607,7 +609,9 @@ end @code_hlo [optimize = ...] [no_nan = <true/false>] f(args...) """ macro code_hlo(args...) - default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false, :backend => "gpu") + default_options = Dict{Symbol,Any}( + :optimize => true, :no_nan => false, :backend => "gpu" + ) compile_expr, (; compiled) = compile_call_expr( __module__, compile_mlir, default_options, args... ) From b51ad2dc47ac15a8e6e98255f07ece5b3ff556a0 Mon Sep 17 00:00:00 2001 From: William Moses <gh@wsmoses.com> Date: Wed, 29 Jan 2025 23:29:02 +0100 Subject: [PATCH 13/18] Update BUILD --- deps/ReactantExtra/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 65f269172..82c0bef25 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -450,6 +450,7 @@ cc_library( "-Wl,-exported_symbol,_ProfilerActivityEnd", "-Wl,-exported_symbol,_ReactantFuncSetArgAttr", "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", +"-Wl,-exported_symbol,_ClientGetPlatformName", "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions" ]}), deps = [ From 264ec4be75347db169737529259a6bfc10b8c02e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Wed, 29 Jan 2025 23:30:15 +0100 Subject: [PATCH 14/18] Update ext/ReactantCUDAExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantCUDAExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 502b8cdc1..7cf2a61f3 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1,7 +1,8 @@ module ReactantCUDAExt using CUDA -using Reactant: Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber +using Reactant: + Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber using ReactantCore: @trace using KernelAbstractions: KernelAbstractions using Libdl From b778228ff052e34a623c427652f254b7d293f564 Mon Sep 17 00:00:00 2001 From: Avik Pal <avikpal@mit.edu> Date: Wed, 29 Jan 2025 23:18:28 -0500 Subject: [PATCH 15/18] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1baceeae1..22982abc2 100644 --- a/Project.toml +++ b/Project.toml @@ -72,7 +72,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.4" -Reactant_jll = "0.0.56" +Reactant_jll = "0.0.58" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" From b6254523b9a6f5538c10c662bc99e0a8759f0703 Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Thu, 30 Jan 2025 09:41:34 +0100 Subject: [PATCH 16/18] Adapt to compile abi change --- src/Compiler.jl | 15 +-------- src/XLA.jl | 82 +++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 27 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index dcb48ae39..42f1c53f6 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1002,7 +1002,6 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic ) # Resolve client and device - device_ordinal = -1 if device === nothing if length(linear_args) > 0 devices_list = [ @@ -1020,30 +1019,18 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic client = XLA.client(device) else client = XLA.default_backend[] - device = XLA.ClientGetDevice(client, XLA.default_device_idx[]) - device_ordinal = XLA.default_device_idx[] end else if device !== nothing @assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same" - else - device = XLA.ClientGetDevice(client, XLA.default_device_idx[]) - device_ordinal = XLA.default_device_idx[] end end - if device_ordinal < 0 - device_ordinal = XLA.DeviceToClientDeviceOrdinal(device) - end # compile MLIR module to XLA executable exec = XLA.Compile( client, - mod; - device_ordinal, - num_replicas=1, - num_partitions=1, - use_shardy_partitioner=false, + mod ) ( exec, diff --git a/src/XLA.jl b/src/XLA.jl index 435fdbcbc..f3fb34407 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -15,17 +15,47 @@ end mutable struct Client client::Ptr{Cvoid} + global_ordinals::Vector{Cint} function Client(client::Ptr{Cvoid}) @assert client != C_NULL - return new(client) + global_ordinals = Cint[] + client = new(client, global_ordinals) + + # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127 + devices = [ + ClientGetAddressableDevice(client, i - 1) for + i in 1:ClientNumAddressableDevices(client) + ] + sort!(devices; lt=(a, b) -> DeviceGetLocalDeviceId(a) < DeviceGetLocalDeviceId(b)) + + local_ids = [DeviceGetLocalDeviceId(device) + 1 for device in devices] + max_local_id = maximum(local_ids) + resize!(global_ordinals, max_local_id) + global_ordinals .= -1 + for (i, device) in enumerate(devices) + global_ordinals[local_ids[i]] = i - 1 + end + return client end end Base.:(==)(a::Client, b::Client) = a.client == b.client -function Base.show(io::IO, ::MIME"text/plain", client::Client) - print(io, "Client($(client.client), platform_name=$(ClientGetPlatformName(client)))") +function device_ordinal(client::Client, device::Device) + return client.global_ordinals[DeviceGetLocalDeviceId(device) + 1] +end + +function DeviceToString(device::Device) + pjrtclient = client(device) + platform_name = ClientGetPlatformName(pjrtclient) + return "$(uppercase(platform_name)):$(device_ordinal(pjrtclient, device))" +end + +function Base.show(io::IO, ::MIME"text/plain", device::Device) + pjrtclient = client(device) + platform_name = ClientGetPlatformName(pjrtclient) + print(io, "Device($(device.device), platform_name=$(platform_name))") return nothing end @@ -337,7 +367,7 @@ Return an [`AllocatorStats`](@ref) instance with information about the device sp This method is currently not implemented for the CPU device. """ function allocatorstats( - device::Device=ClientGetDevice(default_backend[], default_device_idx[]) + device::Device=ClientGetAddressableDevice(default_backend[], default_device_idx[]) ) ref = Ref{JLAllocatorStats}() @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats( @@ -540,21 +570,16 @@ end function Compile( client::Client, - mod::MLIR.IR.Module; - device_ordinal::Int=-1, - num_replicas::Int=1, - num_partitions::Int=1, - use_shardy_partitioner::Bool=false, + mod::MLIR.IR.Module ) + max_local_id = length(client.global_ordinals) GC.@preserve client mod begin executable = LoadedExecutable( @ccall MLIR.API.mlir_c.ClientCompile( client.client::Ptr{Cvoid}, mod.module_::MLIR.API.MlirModule, - device_ordinal::Cint, - num_replicas::Cint, - num_partitions::Cint, - use_shardy_partitioner::Bool, + client.global_ordinals::Ptr{Cint}, + max_local_id::Cint, )::Ptr{Cvoid} ) end @@ -609,6 +634,37 @@ function ClientGetPlatformName(client::Client) return unsafe_string(str) end +function DeviceGetLocalDeviceId(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalDeviceId( + device.device::Ptr{Cvoid} + )::Cint + end +end + +function PjRtLoadedExecutableGetClient(exec::LoadedExecutable) + GC.@preserve exec begin + return Client( + @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetClient( + exec.exec::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function replicate_buffer_on_all_addressable_devices(buffer::Buffer) + pjrtclient = client(buffer) + devices = [ + ClientGetAddressableDevice(pjrtclient, i - 1) for + i in 1:ClientNumAddressableDevices(pjrtclient) + ] + orig_device = device(buffer) + return [ + device == orig_device ? buffer : CopyBufferToDevice(buffer, device) for + device in devices + ] +end + function is_ready(future::Future) GC.@preserve future begin return (@ccall MLIR.API.mlir_c.FutureIsReady(future.future::Ptr{Cvoid})::UInt8) != 0 From 53a1b52ecf3464adfcd2bc9e58be1b8fe69a1fb0 Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Thu, 30 Jan 2025 09:46:35 +0100 Subject: [PATCH 17/18] move device --- src/XLA.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/XLA.jl b/src/XLA.jl index f3fb34407..37f57e4c7 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -42,6 +42,10 @@ end Base.:(==)(a::Client, b::Client) = a.client == b.client +struct Device + device::Ptr{Cvoid} +end + function device_ordinal(client::Client, device::Device) return client.global_ordinals[DeviceGetLocalDeviceId(device) + 1] end @@ -258,10 +262,6 @@ mutable struct Buffer end end -struct Device - device::Ptr{Cvoid} -end - function Base.show(io::IO, ::MIME"text/plain", device::Device) pjrtclient = client(device) platform_name = ClientGetPlatformName(pjrtclient) From 3c258562f559d29686f6836531066c676be04b42 Mon Sep 17 00:00:00 2001 From: "William S. Moses" <gh@wsmoses.com> Date: Thu, 30 Jan 2025 09:47:35 +0100 Subject: [PATCH 18/18] rm redundant --- src/XLA.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/XLA.jl b/src/XLA.jl index 37f57e4c7..ddae554fa 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -262,13 +262,6 @@ mutable struct Buffer end end -function Base.show(io::IO, ::MIME"text/plain", device::Device) - pjrtclient = client(device) - platform_name = ClientGetPlatformName(pjrtclient) - print(io, "Device($(device.device), platform_name=$(platform_name))") - return nothing -end - function DeviceToClientDeviceOrdinal(device::Device) pjrtclient = client(device) naddressable_devices = ClientNumAddressableDevices(pjrtclient)