From 29c41a3798975621f6dd5d4f7e0816acd8718530 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Tue, 28 Jan 2025 19:33:50 -0500
Subject: [PATCH 01/18] CPU backend

---
 deps/ReactantExtra/WORKSPACE |  2 +-
 src/Compiler.jl              | 17 +++++++++++++----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index cce06d553..7916c7305 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "c38ca3f187ef11de6b2292f3cc55c5eb60530d15"
+ENZYMEXLA_COMMIT = "264115eb30f0b23f73040bd68ee6964b634330e9"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 0d6b3fcc2..81074beab 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -429,7 +429,7 @@ end
 
 const DEBUG_KERNEL = Ref{Bool}(false)
 
-function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
+function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu")
     # Explicitly don't use block! to avoid creating a closure, which creates
     # both compile-time and relocatability issues
 
@@ -456,7 +456,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
     if isdefined(Reactant_jll, :ptxas_path)
         toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
     end
-    if DEBUG_KERNEL[]
+
+    if backend == "cpu"
+        kern = "lower-kernel{openmp=false backend=cpu},symbol-dce"
+    elseif DEBUG_KERNEL[]
         curesulthandler = XLA.Libdl.dlsym(
             Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
         )
@@ -604,7 +607,7 @@ end
     @code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)
 """
 macro code_hlo(args...)
-    default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false)
+    default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false, :backend => "gpu")
     compile_expr, (; compiled) = compile_call_expr(
         __module__, compile_mlir, default_options, args...
     )
@@ -975,12 +978,18 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
     context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
     @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
 
+    if client !== nothing
+        backend = XLA.ClientGetPlatformName(backend)
+    else
+        backend = XLA.ClientGetPlatformName(XLA.default_backend[])
+    end
+
     MLIR.IR.activate!(ctx)
     results = try
         # compile function to MLIR module
         mod = MLIR.IR.Module(MLIR.IR.Location())
         linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
-            mod, f, args; optimize, no_nan
+            mod, f, args; optimize, no_nan, backend
         )
 
         # Resolve client and device

From 58bc103b3321b963327d7978616411fd7aa324b5 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Tue, 28 Jan 2025 19:50:30 -0500
Subject: [PATCH 02/18] add build

---
 deps/ReactantExtra/BUILD |  2 +-
 src/Compiler.jl          | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 5ebcf2aaa..10bc8e61f 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -555,6 +555,7 @@ cc_library(
         "@llvm-project//mlir:CAPILLVMObjects",
         "@jax//jaxlib/mosaic:tpu_dialect_capi_objects",
         "@jax//jaxlib/triton:triton_dialect_capi_objects",
+        "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
     ] + select({
         "@xla//xla/tsl:is_cuda_enabled_and_oss":[
             "@xla//xla/stream_executor/cuda:all_runtime",
@@ -566,7 +567,6 @@ cc_library(
             "@xla//xla/backends/profiler/gpu:device_tracer",
         ],
         "//conditions:default": [
-            "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
         ],
     }) + if_rocm([
         "@xla//xla/service/gpu:amdgpu_compiler",
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 81074beab..87619f3d4 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -979,10 +979,16 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
     @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
 
     if client !== nothing
-        backend = XLA.ClientGetPlatformName(backend)
+        backend = XLA.ClientGetPlatformName(client)
     else
         backend = XLA.ClientGetPlatformName(XLA.default_backend[])
     end
+    if backend == "CUDA"
+        backend = "GPU"
+    elseif backend == "CPU"
+        backend = "cpu"
+    end
+    @show backend
 
     MLIR.IR.activate!(ctx)
     results = try
@@ -1036,7 +1042,7 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
             num_partitions=1,
             use_shardy_partitioner=false,
         )
-        return (
+        (
             exec,
             linear_args,
             linear_results,

From e58c4b6333d9b8185c78bd73cf4a63c5157251b6 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Tue, 28 Jan 2025 21:29:39 -0500
Subject: [PATCH 03/18] cpu handle

---
 deps/ReactantExtra/BUILD     | 2 ++
 deps/ReactantExtra/WORKSPACE | 2 +-
 src/XLA.jl                   | 1 +
 3 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 10bc8e61f..65f269172 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -362,6 +362,7 @@ cc_library(
     ) + [
 	"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
 	"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
+	"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
         # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
         # "@xla//xla:xla.pb.cc",
         "@xla//xla:xla_data.pb.cc",
@@ -437,6 +438,7 @@ cc_library(
 "-Wl,-exported_symbol,_RegisterCustomCallTarget",
 "-Wl,-exported_symbol,_ConvertLLVMToMLIR",
 "-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
+"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
 "-Wl,-exported_symbol,_ReactantThrowError",
 "-Wl,-exported_symbol,_ReactantHandleCuResult",
 "-Wl,-exported_symbol,_CreateProfilerSession",
diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 7916c7305..3a4499a75 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "264115eb30f0b23f73040bd68ee6964b634330e9"
+ENZYMEXLA_COMMIT = "57c817748e620b9c688e9e6129af721d44c23f19"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(
diff --git a/src/XLA.jl b/src/XLA.jl
index 224a1deeb..435fdbcbc 100644
--- a/src/XLA.jl
+++ b/src/XLA.jl
@@ -178,6 +178,7 @@ function __init__()
         end
     end
 
+    @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
     @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid
 
     # This wasn't properly exported on macos, we'll remove the try once macOS JLL

From a2de1505b74e41f3f4f209bc0e79214555cfef8f Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Tue, 28 Jan 2025 21:31:19 -0500
Subject: [PATCH 04/18] KA backend

---
 ext/ReactantCUDAExt.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 6d863fc53..502b8cdc1 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -1,7 +1,7 @@
 module ReactantCUDAExt
 
 using CUDA
-using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
+using Reactant: Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
 using ReactantCore: @trace
 using KernelAbstractions: KernelAbstractions
 using Libdl
@@ -9,6 +9,7 @@ using Libdl
 using Adapt
 
 KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend()
+KernelAbstractions.get_backend(::AnyConcreteRArray) = CUDABackend()
 
 struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
     ptr::Core.LLVMPtr{T,A}

From 575fc9243d0d1f987ab82b2b27395c385705add3 Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 09:59:07 +0100
Subject: [PATCH 05/18] Update WORKSPACE

---
 deps/ReactantExtra/WORKSPACE | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 3a4499a75..6d571742b 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "57c817748e620b9c688e9e6129af721d44c23f19"
+ENZYMEXLA_COMMIT = "ab95bec0a69c92be7065f9eabc16e2d6b0da11b7"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(

From 80e4e66a20f647bf1616f15521bec47edb8f6f0d Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 12:05:54 +0100
Subject: [PATCH 06/18] Update WORKSPACE

---
 deps/ReactantExtra/WORKSPACE | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 6d571742b..b18f64beb 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "ab95bec0a69c92be7065f9eabc16e2d6b0da11b7"
+ENZYMEXLA_COMMIT = "b37119b68e20370c94736a957d34b424dc5825f6"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(

From ca7d5b681dc265aeda5ca6024cea11e8e4804ccb Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 12:15:54 +0100
Subject: [PATCH 07/18] Update WORKSPACE

---
 deps/ReactantExtra/WORKSPACE | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index b18f64beb..ddb1eca3f 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "b37119b68e20370c94736a957d34b424dc5825f6"
+ENZYMEXLA_COMMIT = "e079972eb38fdda8e004df08c764ac7e93a2d85c"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(

From 9818c8f976f1cff79d5a2fe1695aabaabfb970a3 Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 12:18:35 +0100
Subject: [PATCH 08/18] Update WORKSPACE

---
 deps/ReactantExtra/WORKSPACE | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index ddb1eca3f..44d4458a3 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "e079972eb38fdda8e004df08c764ac7e93a2d85c"
+ENZYMEXLA_COMMIT = "df1905af736d0cc7741e8753f18d3f536fbc2e19"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(

From e1b1e8f539f471240233abe7c4d5f5fed0e23afe Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 12:45:17 +0100
Subject: [PATCH 09/18] Update WORKSPACE

---
 deps/ReactantExtra/WORKSPACE | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 44d4458a3..935efdbf4 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "df1905af736d0cc7741e8753f18d3f536fbc2e19"
+ENZYMEXLA_COMMIT = "d89468ed883ca18c04346eec10f784bbe2b754fc"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(

From 4ea9f49d84757087ae00abc46427aaec4913dc9a Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 17:02:21 +0100
Subject: [PATCH 10/18] Update Project.toml

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index d6f5d06a7..1baceeae1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -72,7 +72,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.4"
-Reactant_jll = "0.0.52"
+Reactant_jll = "0.0.56"
 Scratch = "1.2"
 Sockets = "1.10"
 SpecialFunctions = "2.4"

From 7d906aeadecbd50836332f15c439124b55e3484a Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 23:25:18 +0100
Subject: [PATCH 11/18] Update Compiler.jl

---
 src/Compiler.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index 87619f3d4..2145555f7 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -988,7 +988,6 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
     elseif backend == "CPU"
         backend = "cpu"
     end
-    @show backend
 
     MLIR.IR.activate!(ctx)
     results = try

From 6014df7518d2ca966b543f6ad462ccc3667a94ac Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 23:26:03 +0100
Subject: [PATCH 12/18] Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/Compiler.jl | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index 2145555f7..dcb48ae39 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -429,7 +429,9 @@ end
 
 const DEBUG_KERNEL = Ref{Bool}(false)
 
-function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu")
+function compile_mlir!(
+    mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu"
+)
     # Explicitly don't use block! to avoid creating a closure, which creates
     # both compile-time and relocatability issues
 
@@ -607,7 +609,9 @@ end
     @code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)
 """
 macro code_hlo(args...)
-    default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false, :backend => "gpu")
+    default_options = Dict{Symbol,Any}(
+        :optimize => true, :no_nan => false, :backend => "gpu"
+    )
     compile_expr, (; compiled) = compile_call_expr(
         __module__, compile_mlir, default_options, args...
     )

From b51ad2dc47ac15a8e6e98255f07ece5b3ff556a0 Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Wed, 29 Jan 2025 23:29:02 +0100
Subject: [PATCH 13/18] Update BUILD

---
 deps/ReactantExtra/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 65f269172..82c0bef25 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -450,6 +450,7 @@ cc_library(
 "-Wl,-exported_symbol,_ProfilerActivityEnd",
 "-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
 "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
+"-Wl,-exported_symbol,_ClientGetPlatformName",
 "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions"
     ]}),
     deps = [

From 264ec4be75347db169737529259a6bfc10b8c02e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mos=C3=A8=20Giordano?=
 <765740+giordano@users.noreply.github.com>
Date: Wed, 29 Jan 2025 23:30:15 +0100
Subject: [PATCH 14/18] Update ext/ReactantCUDAExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 ext/ReactantCUDAExt.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 502b8cdc1..7cf2a61f3 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -1,7 +1,8 @@
 module ReactantCUDAExt
 
 using CUDA
-using Reactant: Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
+using Reactant:
+    Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
 using ReactantCore: @trace
 using KernelAbstractions: KernelAbstractions
 using Libdl

From b778228ff052e34a623c427652f254b7d293f564 Mon Sep 17 00:00:00 2001
From: Avik Pal <avikpal@mit.edu>
Date: Wed, 29 Jan 2025 23:18:28 -0500
Subject: [PATCH 15/18] Update Project.toml

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 1baceeae1..22982abc2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -72,7 +72,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.4"
-Reactant_jll = "0.0.56"
+Reactant_jll = "0.0.58"
 Scratch = "1.2"
 Sockets = "1.10"
 SpecialFunctions = "2.4"

From b6254523b9a6f5538c10c662bc99e0a8759f0703 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Thu, 30 Jan 2025 09:41:34 +0100
Subject: [PATCH 16/18] Adapt to compile abi change

---
 src/Compiler.jl | 15 +--------
 src/XLA.jl      | 82 +++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 70 insertions(+), 27 deletions(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index dcb48ae39..42f1c53f6 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1002,7 +1002,6 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
         )
 
         # Resolve client and device
-        device_ordinal = -1
         if device === nothing
             if length(linear_args) > 0
                 devices_list = [
@@ -1020,30 +1019,18 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
                 client = XLA.client(device)
             else
                 client = XLA.default_backend[]
-                device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
-                device_ordinal = XLA.default_device_idx[]
             end
         else
             if device !== nothing
                 @assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same"
-            else
-                device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
-                device_ordinal = XLA.default_device_idx[]
             end
         end
 
-        if device_ordinal < 0
-            device_ordinal = XLA.DeviceToClientDeviceOrdinal(device)
-        end
 
         # compile MLIR module to XLA executable
         exec = XLA.Compile(
             client,
-            mod;
-            device_ordinal,
-            num_replicas=1,
-            num_partitions=1,
-            use_shardy_partitioner=false,
+            mod
         )
         (
             exec,
diff --git a/src/XLA.jl b/src/XLA.jl
index 435fdbcbc..f3fb34407 100644
--- a/src/XLA.jl
+++ b/src/XLA.jl
@@ -15,17 +15,47 @@ end
 
 mutable struct Client
     client::Ptr{Cvoid}
+    global_ordinals::Vector{Cint}
 
     function Client(client::Ptr{Cvoid})
         @assert client != C_NULL
-        return new(client)
+        global_ordinals = Cint[]
+        client = new(client, global_ordinals)
+
+        # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127
+        devices = [
+            ClientGetAddressableDevice(client, i - 1) for
+            i in 1:ClientNumAddressableDevices(client)
+        ]
+        sort!(devices; lt=(a, b) -> DeviceGetLocalDeviceId(a) < DeviceGetLocalDeviceId(b))
+
+        local_ids = [DeviceGetLocalDeviceId(device) + 1 for device in devices]
+        max_local_id = maximum(local_ids)
+        resize!(global_ordinals, max_local_id)
+        global_ordinals .= -1
+        for (i, device) in enumerate(devices)
+            global_ordinals[local_ids[i]] = i - 1
+        end
+        return client
     end
 end
 
 Base.:(==)(a::Client, b::Client) = a.client == b.client
 
-function Base.show(io::IO, ::MIME"text/plain", client::Client)
-    print(io, "Client($(client.client), platform_name=$(ClientGetPlatformName(client)))")
+function device_ordinal(client::Client, device::Device)
+    return client.global_ordinals[DeviceGetLocalDeviceId(device) + 1]
+end
+
+function DeviceToString(device::Device)
+    pjrtclient = client(device)
+    platform_name = ClientGetPlatformName(pjrtclient)
+    return "$(uppercase(platform_name)):$(device_ordinal(pjrtclient, device))"
+end
+
+function Base.show(io::IO, ::MIME"text/plain", device::Device)
+    pjrtclient = client(device)
+    platform_name = ClientGetPlatformName(pjrtclient)
+    print(io, "Device($(device.device), platform_name=$(platform_name))")
     return nothing
 end
 
@@ -337,7 +367,7 @@ Return an [`AllocatorStats`](@ref) instance with information about the device sp
     This method is currently not implemented for the CPU device.
 """
 function allocatorstats(
-    device::Device=ClientGetDevice(default_backend[], default_device_idx[])
+    device::Device=ClientGetAddressableDevice(default_backend[], default_device_idx[])
 )
     ref = Ref{JLAllocatorStats}()
     @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats(
@@ -540,21 +570,16 @@ end
 
 function Compile(
     client::Client,
-    mod::MLIR.IR.Module;
-    device_ordinal::Int=-1,
-    num_replicas::Int=1,
-    num_partitions::Int=1,
-    use_shardy_partitioner::Bool=false,
+    mod::MLIR.IR.Module
 )
+    max_local_id = length(client.global_ordinals)
     GC.@preserve client mod begin
         executable = LoadedExecutable(
             @ccall MLIR.API.mlir_c.ClientCompile(
                 client.client::Ptr{Cvoid},
                 mod.module_::MLIR.API.MlirModule,
-                device_ordinal::Cint,
-                num_replicas::Cint,
-                num_partitions::Cint,
-                use_shardy_partitioner::Bool,
+                client.global_ordinals::Ptr{Cint},
+                max_local_id::Cint,
             )::Ptr{Cvoid}
         )
     end
@@ -609,6 +634,37 @@ function ClientGetPlatformName(client::Client)
     return unsafe_string(str)
 end
 
+function DeviceGetLocalDeviceId(device::Device)
+    GC.@preserve device begin
+        return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalDeviceId(
+            device.device::Ptr{Cvoid}
+        )::Cint
+    end
+end
+
+function PjRtLoadedExecutableGetClient(exec::LoadedExecutable)
+    GC.@preserve exec begin
+        return Client(
+            @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetClient(
+                exec.exec::Ptr{Cvoid}
+            )::Ptr{Cvoid}
+        )
+    end
+end
+
+function replicate_buffer_on_all_addressable_devices(buffer::Buffer)
+    pjrtclient = client(buffer)
+    devices = [
+        ClientGetAddressableDevice(pjrtclient, i - 1) for
+        i in 1:ClientNumAddressableDevices(pjrtclient)
+    ]
+    orig_device = device(buffer)
+    return [
+        device == orig_device ? buffer : CopyBufferToDevice(buffer, device) for
+        device in devices
+    ]
+end
+
 function is_ready(future::Future)
     GC.@preserve future begin
         return (@ccall MLIR.API.mlir_c.FutureIsReady(future.future::Ptr{Cvoid})::UInt8) != 0

From 53a1b52ecf3464adfcd2bc9e58be1b8fe69a1fb0 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Thu, 30 Jan 2025 09:46:35 +0100
Subject: [PATCH 17/18] move device

---
 src/XLA.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/XLA.jl b/src/XLA.jl
index f3fb34407..37f57e4c7 100644
--- a/src/XLA.jl
+++ b/src/XLA.jl
@@ -42,6 +42,10 @@ end
 
 Base.:(==)(a::Client, b::Client) = a.client == b.client
 
+struct Device
+    device::Ptr{Cvoid}
+end
+
 function device_ordinal(client::Client, device::Device)
     return client.global_ordinals[DeviceGetLocalDeviceId(device) + 1]
 end
@@ -258,10 +262,6 @@ mutable struct Buffer
     end
 end
 
-struct Device
-    device::Ptr{Cvoid}
-end
-
 function Base.show(io::IO, ::MIME"text/plain", device::Device)
     pjrtclient = client(device)
     platform_name = ClientGetPlatformName(pjrtclient)

From 3c258562f559d29686f6836531066c676be04b42 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh@wsmoses.com>
Date: Thu, 30 Jan 2025 09:47:35 +0100
Subject: [PATCH 18/18] rm redundant

---
 src/XLA.jl | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/src/XLA.jl b/src/XLA.jl
index 37f57e4c7..ddae554fa 100644
--- a/src/XLA.jl
+++ b/src/XLA.jl
@@ -262,13 +262,6 @@ mutable struct Buffer
     end
 end
 
-function Base.show(io::IO, ::MIME"text/plain", device::Device)
-    pjrtclient = client(device)
-    platform_name = ClientGetPlatformName(pjrtclient)
-    print(io, "Device($(device.device), platform_name=$(platform_name))")
-    return nothing
-end
-
 function DeviceToClientDeviceOrdinal(device::Device)
     pjrtclient = client(device)
     naddressable_devices = ClientNumAddressableDevices(pjrtclient)