Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU backend #647

Merged
merged 20 commits into from
Jan 30, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.4"
Reactant_jll = "0.0.52"
Reactant_jll = "0.0.56"
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Expand Down
4 changes: 3 additions & 1 deletion ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
module ReactantCUDAExt

using CUDA
using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
using ReactantCore: @trace
using KernelAbstractions: KernelAbstractions
using Libdl

using Adapt

KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend()
KernelAbstractions.get_backend(::AnyConcreteRArray) = CUDABackend()

struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
ptr::Core.LLVMPtr{T,A}
Expand Down
28 changes: 23 additions & 5 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ end

const DEBUG_KERNEL = Ref{Bool}(false)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
function compile_mlir!(
mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu"
)
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

Expand All @@ -456,7 +458,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
if isdefined(Reactant_jll, :ptxas_path)
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
if DEBUG_KERNEL[]

if backend == "cpu"
kern = "lower-kernel{openmp=false backend=cpu},symbol-dce"
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
elseif DEBUG_KERNEL[]
curesulthandler = XLA.Libdl.dlsym(
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
)
Expand Down Expand Up @@ -604,7 +609,9 @@ end
@code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)
"""
macro code_hlo(args...)
default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false)
default_options = Dict{Symbol,Any}(
:optimize => true, :no_nan => false, :backend => "gpu"
)
compile_expr, (; compiled) = compile_call_expr(
__module__, compile_mlir, default_options, args...
)
Expand Down Expand Up @@ -975,12 +982,23 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

if client !== nothing
backend = XLA.ClientGetPlatformName(client)
else
backend = XLA.ClientGetPlatformName(XLA.default_backend[])
end
if backend == "CUDA"
backend = "GPU"
elseif backend == "CPU"
backend = "cpu"
end

MLIR.IR.activate!(ctx)
results = try
# compile function to MLIR module
mod = MLIR.IR.Module(MLIR.IR.Location())
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
mod, f, args; optimize, no_nan
mod, f, args; optimize, no_nan, backend
)

# Resolve client and device
Expand Down Expand Up @@ -1027,7 +1045,7 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
num_partitions=1,
use_shardy_partitioner=false,
)
return (
(
exec,
linear_args,
linear_results,
Expand Down
1 change: 1 addition & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ function __init__()
end
end

@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid

# This wasn't properly exported on macos, we'll remove the try once macOS JLL
Expand Down
Loading