Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU backend #647

Merged
Merged 20 commits into the base branch from the pull-request branch
Jan 30, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.4"
Reactant_jll = "0.0.52"
Reactant_jll = "0.0.58"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
4 changes: 3 additions & 1 deletion ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
module ReactantCUDAExt

using CUDA
using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
using ReactantCore: @trace
using KernelAbstractions: KernelAbstractions
using Libdl

using Adapt

# Route KernelAbstractions kernels launched on Reactant arrays — both traced
# (inside a compiled region) and concrete (materialized) — to the CUDA backend
# while this CUDA extension is loaded.
KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend()
KernelAbstractions.get_backend(::AnyConcreteRArray) = CUDABackend()

struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
ptr::Core.LLVMPtr{T,A}
43 changes: 24 additions & 19 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
@@ -429,7 +429,9 @@ end

const DEBUG_KERNEL = Ref{Bool}(false)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
function compile_mlir!(
mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu"
)
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

@@ -456,7 +458,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
if isdefined(Reactant_jll, :ptxas_path)
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
if DEBUG_KERNEL[]

if backend == "cpu"
kern = "lower-kernel{openmp=false backend=cpu},symbol-dce"
elseif DEBUG_KERNEL[]
curesulthandler = XLA.Libdl.dlsym(
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
)
@@ -604,7 +609,9 @@ end
@code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)
"""
macro code_hlo(args...)
default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false)
default_options = Dict{Symbol,Any}(
:optimize => true, :no_nan => false, :backend => "gpu"
)
compile_expr, (; compiled) = compile_call_expr(
__module__, compile_mlir, default_options, args...
)
@@ -975,16 +982,26 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

if client !== nothing
backend = XLA.ClientGetPlatformName(client)
else
backend = XLA.ClientGetPlatformName(XLA.default_backend[])
end
if backend == "CUDA"
backend = "GPU"
elseif backend == "CPU"
backend = "cpu"
end

MLIR.IR.activate!(ctx)
results = try
# compile function to MLIR module
mod = MLIR.IR.Module(MLIR.IR.Location())
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
mod, f, args; optimize, no_nan
mod, f, args; optimize, no_nan, backend
)

# Resolve client and device
device_ordinal = -1
if device === nothing
if length(linear_args) > 0
devices_list = [
@@ -1002,32 +1019,20 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
client = XLA.client(device)
else
client = XLA.default_backend[]
device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
device_ordinal = XLA.default_device_idx[]
end
else
if device !== nothing
@assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same"
else
device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
device_ordinal = XLA.default_device_idx[]
end
end

if device_ordinal < 0
device_ordinal = XLA.DeviceToClientDeviceOrdinal(device)
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# compile MLIR module to XLA executable
exec = XLA.Compile(
client,
mod;
device_ordinal,
num_replicas=1,
num_partitions=1,
use_shardy_partitioner=false,
mod
)
Comment on lines 1031 to 1034
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
exec = XLA.Compile(
client,
mod;
device_ordinal,
num_replicas=1,
num_partitions=1,
use_shardy_partitioner=false,
mod
)
exec = XLA.Compile(client, mod)

return (
(
exec,
linear_args,
linear_results,
98 changes: 74 additions & 24 deletions src/XLA.jl
Original file line number Diff line number Diff line change
@@ -15,17 +15,51 @@ end

"""
    Client(client::Ptr{Cvoid})

Wrapper around a PJRT client handle.

`global_ordinals` maps a device's local id (1-based index) to its global
ordinal, with `-1` filling any gaps, following the device-ordering scheme of
PyTorch/XLA's PJRT computation client (see link in the constructor).

Fix: the merged text contained a stale `return new(client)` left over from the
previous constructor, which made all of the initialization below unreachable
and returned a `Client` whose `global_ordinals` field was never assigned.
That dead early return is removed here.
"""
mutable struct Client
    client::Ptr{Cvoid}
    global_ordinals::Vector{Cint}

    function Client(client::Ptr{Cvoid})
        # A NULL handle means client creation failed on the C++ side.
        @assert client != C_NULL
        global_ordinals = Cint[]
        client = new(client, global_ordinals)

        # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127
        devices = [
            ClientGetAddressableDevice(client, i - 1) for
            i in 1:ClientNumAddressableDevices(client)
        ]
        # Order devices by their local id so ordinals are assigned deterministically.
        sort!(devices; lt=(a, b) -> DeviceGetLocalDeviceId(a) < DeviceGetLocalDeviceId(b))

        local_ids = [DeviceGetLocalDeviceId(device) + 1 for device in devices]
        max_local_id = maximum(local_ids)
        resize!(global_ordinals, max_local_id)
        global_ordinals .= -1
        for (i, device) in enumerate(devices)
            global_ordinals[local_ids[i]] = i - 1
        end
        return client
    end
end

"""
    ==(a::Client, b::Client)

Two `Client`s compare equal exactly when they wrap the same underlying PJRT
client handle (pointer identity of the C-side object).
"""
function Base.:(==)(a::Client, b::Client)
    return a.client == b.client
end

function Base.show(io::IO, ::MIME"text/plain", client::Client)
print(io, "Client($(client.client), platform_name=$(ClientGetPlatformName(client)))")
# Thin immutable wrapper around a PJRT device handle; the pointer is owned by
# the C++ runtime and is only ever passed back through the C API.
struct Device
    device::Ptr{Cvoid}
end

"""
    device_ordinal(client::Client, device::Device)

Return the global ordinal of `device` within `client`, looked up from the
`client.global_ordinals` table via the device's (0-based) local id.
"""
function device_ordinal(client::Client, device::Device)
    local_idx = DeviceGetLocalDeviceId(device) + 1
    return client.global_ordinals[local_idx]
end

"""
    DeviceToString(device::Device)

Render `device` as `"PLATFORM:ordinal"`, e.g. `"CUDA:0"`, using the upper-cased
platform name of the owning client and the device's global ordinal.
"""
function DeviceToString(device::Device)
    owner = client(device)
    platform = uppercase(ClientGetPlatformName(owner))
    ordinal = device_ordinal(owner, device)
    return string(platform, ":", ordinal)
end

# Multi-line REPL display for a Device: shows the raw handle plus the
# platform name of the client that owns it.
function Base.show(io::IO, ::MIME"text/plain", device::Device)
    owner = client(device)
    name = ClientGetPlatformName(owner)
    print(io, "Device(", device.device, ", platform_name=", name, ")")
    return nothing
end

@@ -178,6 +212,7 @@ function __init__()
end
end

@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid

# This wasn't properly exported on macos, we'll remove the try once macOS JLL
@@ -227,17 +262,6 @@ mutable struct Buffer
end
end

struct Device
device::Ptr{Cvoid}
end

function Base.show(io::IO, ::MIME"text/plain", device::Device)
pjrtclient = client(device)
platform_name = ClientGetPlatformName(pjrtclient)
print(io, "Device($(device.device), platform_name=$(platform_name))")
return nothing
end

function DeviceToClientDeviceOrdinal(device::Device)
pjrtclient = client(device)
naddressable_devices = ClientNumAddressableDevices(pjrtclient)
@@ -336,7 +360,7 @@ Return an [`AllocatorStats`](@ref) instance with information about the device sp
This method is currently not implemented for the CPU device.
"""
function allocatorstats(
device::Device=ClientGetDevice(default_backend[], default_device_idx[])
device::Device=ClientGetAddressableDevice(default_backend[], default_device_idx[])
)
ref = Ref{JLAllocatorStats}()
@ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats(
@@ -539,21 +563,16 @@ end

function Compile(
client::Client,
mod::MLIR.IR.Module;
device_ordinal::Int=-1,
num_replicas::Int=1,
num_partitions::Int=1,
use_shardy_partitioner::Bool=false,
mod::MLIR.IR.Module
)
Comment on lines 564 to 567
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function Compile(
client::Client,
mod::MLIR.IR.Module;
device_ordinal::Int=-1,
num_replicas::Int=1,
num_partitions::Int=1,
use_shardy_partitioner::Bool=false,
mod::MLIR.IR.Module
)
function Compile(client::Client, mod::MLIR.IR.Module)

max_local_id = length(client.global_ordinals)
GC.@preserve client mod begin
executable = LoadedExecutable(
@ccall MLIR.API.mlir_c.ClientCompile(
client.client::Ptr{Cvoid},
mod.module_::MLIR.API.MlirModule,
device_ordinal::Cint,
num_replicas::Cint,
num_partitions::Cint,
use_shardy_partitioner::Bool,
client.global_ordinals::Ptr{Cint},
max_local_id::Cint,
)::Ptr{Cvoid}
)
end
@@ -608,6 +627,37 @@ function ClientGetPlatformName(client::Client)
return unsafe_string(str)
end

# Return the local device id of `device` as reported by the PJRT C API.
# `GC.@preserve` keeps the wrapper object alive for the duration of the
# foreign call so the underlying handle cannot be collected mid-call.
function DeviceGetLocalDeviceId(device::Device)
    GC.@preserve device begin
        return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalDeviceId(
            device.device::Ptr{Cvoid}
        )::Cint
    end
end

# Fetch the PJRT client that owns `exec`.
# NOTE(review): this wraps the returned pointer in a fresh `Client`, which
# re-runs the `Client` constructor (re-querying addressable devices and
# rebuilding `global_ordinals`) rather than returning the original Julia
# `Client` object — confirm that is intended at the call sites.
function PjRtLoadedExecutableGetClient(exec::LoadedExecutable)
    GC.@preserve exec begin
        return Client(
            @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetClient(
                exec.exec::Ptr{Cvoid}
            )::Ptr{Cvoid}
        )
    end
end

"""
    replicate_buffer_on_all_addressable_devices(buffer::Buffer)

Return a vector with one buffer per addressable device of `buffer`'s client:
the original `buffer` for the device it already lives on, and a copy (via
`CopyBufferToDevice`) for every other device.
"""
function replicate_buffer_on_all_addressable_devices(buffer::Buffer)
    pjrt = client(buffer)
    ndevices = ClientNumAddressableDevices(pjrt)
    all_devices = [ClientGetAddressableDevice(pjrt, idx) for idx in 0:(ndevices - 1)]
    home = device(buffer)
    return map(all_devices) do dev
        dev == home ? buffer : CopyBufferToDevice(buffer, dev)
    end
end

function is_ready(future::Future)
GC.@preserve future begin
return (@ccall MLIR.API.mlir_c.FutureIsReady(future.future::Ptr{Cvoid})::UInt8) != 0