excessive compile times and failures with GPU #2283

Open
ExpandingMan opened this issue Jan 28, 2025 · 0 comments

@ExpandingMan (Contributor)

Minimal Lux examples on the GPU are still impractical because of excessive compile times and occasional failures.

Here is a minimal example:

```julia
using LinearAlgebra, Random, Statistics, Optimisers
using CUDA
using Lux, LuxCUDA
import ProgressMeter as PM
import Enzyme, Zygote

using Lux.MLDataDevices: AbstractDevice

Enzyme.Compiler.VERBOSE_ERRORS[] = true

function makedata(rng::AbstractRNG)
    X = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(X, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    (X, y)
end

function main(dev::AbstractDevice, rng=Random.Xoshiro(999),
              model=Chain(Dense(1=>16, gelu), Dense(16=>1)),
              (X, y)=makedata(rng) |> dev;
              nepochs=300,
              opt=Adam(0.01f0),
              backend=AutoEnzyme(),
             )
    (θ, ψ) = Lux.setup(rng, model) |> dev

    s = Lux.Training.TrainState(model, θ, ψ, opt)
    pm = PM.Progress(nepochs)
    for j ∈ 1:nepochs
        (∂s, ℓ, stats, s) = Lux.Training.single_train_step!(
            backend, MSELoss(),
            (X, y), s,
        )
        PM.next!(pm)
    end
    PM.finish!(pm)

    # Predict with the trained parameters and states held by the TrainState.
    (yhat, _) = Lux.apply(model, X, s.parameters, s.states)

    (yhat, y)
end
```
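For readers checking the data setup: `evalpoly` takes coefficients in ascending order, so `evalpoly(x, (0, -2, 1))` is x^2 - 2x, and `makedata` samples that quadratic at 128 points on [-2, 2] before adding Gaussian noise scaled by 0.1. The sketch below checks only the noise-free part and needs nothing beyond Base Julia; the variable name `clean` is illustrative.

```julia
# Noise-free part of `makedata`: evalpoly uses ascending coefficients,
# so (0, -2, 1) means 0 + (-2)x + 1x^2.
X = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
clean = evalpoly.(X, ((0, -2, 1),))

# Integer coefficients do not widen the Float32 inputs.
println(size(clean), " ", eltype(clean))
println(clean ≈ X .^ 2 .- 2 .* X)
```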

On the CPU I get:

◖◗ @time main(cpu_device());
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
 43.469745 seconds (481.86 M allocations: 20.523 GiB, 5.22% gc time, 99.98% compilation time)

This is worryingly slow for such a small example, since nearly all of the 43 seconds is compilation, but depending on how it scales it is not necessarily prohibitive.
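To put "such a small example" in numbers: assuming each `Dense(in => out)` layer carries a bias (the usual default), it has `in * out + out` trainable parameters, so the model in the reproducer has only 49. The helper `dense_params` below is hypothetical and exists only for this arithmetic.

```julia
# Hypothetical helper: weight matrix (nout × nin) plus bias vector (nout).
dense_params(nin::Int, nout::Int) = nin * nout + nout

# Chain(Dense(1 => 16, gelu), Dense(16 => 1))
nparams = dense_params(1, 16) + dense_params(16, 1)
println(nparams)  # 49
```

With Lux loaded, `LuxCore.parameterlength(model)` should report the same count.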

On the GPU, however, it compiles for about 10 to 15 minutes and then fails with the following error:

◖◗ @time main(gpu_device());
ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Current scope:
define internal fastcc void @julia_nonblocking_synchronize_187572({} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0) unnamed_addr #197 !dbg !8001 {
top:
  %1 = alloca [3 x [2 x {} addrspace(10)*]], align 8
  %pgcstack = call {}*** @julia.get_pgcstack()
  %ptls_field6 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %2 = bitcast {}*** %ptls_field6 to i64***
  %ptls_load78 = load i64**, i64*** %2, align 8, !tbaa !381
  %3 = getelementptr inbounds i64*, i64** %ptls_load78, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !385
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint), !dbg !8002
  fence syncscope("singlethread") seq_cst
  %4 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* addrspacecast ({}* inttoptr (i64 124712895855264 to {}*) to {} addrspace(11)*)) #413, !dbg !8003
  %ptr.i = bitcast {}* %4 to i32*, !dbg !8007
  %rv.i = atomicrmw add i32* %ptr.i, i32 1 acq_rel, align 4, !dbg !8007
  %5 = and i32 %rv.i, 3, !dbg !8010
  %.not = icmp eq i32 %5, 0, !dbg !8018
  %narrow = select i1 %.not, i32 4, i32 %5, !dbg !8020
  %6 = zext i32 %narrow to i64, !dbg !8020
  %7 = load i64, i64* inttoptr (i64 124712895855408 to i64*), align 16, !dbg !8022, !tbaa !633, !alias.scope !636, !noalias !637
  %8 = add nsw i64 %6, -1, !dbg !8035
  %.not9 = icmp ult i64 %8, %7, !dbg !8038
  br i1 %.not9, label %L40, label %L49, !dbg !8032

L40:                                              ; preds = %top
  %9 = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637
  %10 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637, !dereferenceable_or_null !494, !align !503
  %11 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %10, {} addrspace(10)** %9), !dbg !8043
  %12 = bitcast {} addrspace(10)* addrspace(13)* %11 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8043
  %13 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %12, i64 %8, i64 0, i64 0, !dbg !8043
  %14 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %13, align 8, !dbg !8043, !tbaa !4485, !alias.scope !429, !noalias !430
  %.not24 = icmp eq {} addrspace(10)* %14, null, !dbg !8043
  br i1 %.not24, label %L49, label %pass3, !dbg !8034

L49:                                              ; preds = %L40, %top
  call fastcc void @julia_create_synchronization_worker_189388(i64 signext %6), !dbg !8045
  %.pre = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
  %.pre25 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
  %.pre26 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %.pre25, {} addrspace(10)** %.pre), !dbg !8046
  %.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8046
  %.unpack.elt.phi.trans.insert = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre27, i64 %8, i64 0, i64 0
  %.unpack.unpack.pre = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt.phi.trans.insert, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.not23 = icmp eq {} addrspace(10)* %.unpack.unpack.pre, null, !dbg !8046
  br i1 %.not23, label %fail2, label %pass3, !dbg !8046

L75:                                              ; preds = %pass3
  call fastcc void @julia_throw_api_error_187593(i32 zeroext %18) #414, !dbg !8049
  unreachable, !dbg !8049

L77:                                              ; preds = %pass3
  ret void, !dbg !8050

fail2:                                            ; preds = %L49
  %15 = load {}*, {}** @jl_undefref_exception, align 8, !dbg !8046, !tbaa !385, !alias.scope !437, !noalias !438, !nonnull !380
  %16 = addrspacecast {}* %15 to {} addrspace(12)*, !dbg !8046
  call void @ijl_throw({} addrspace(12)* %16), !dbg !8046
  unreachable, !dbg !8046

pass3:                                            ; preds = %L40, %L49
  %nodecayed..pre-phi2834 = phi {} addrspace(10)*
  %nodecayedoff..pre-phi2834 = phi i64
  %.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
  %.unpack.unpack33 = phi {} addrspace(10)* [ %.unpack.unpack.pre, %L49 ], [ %14, %L40 ]
  %.unpack.elt14 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 0, i64 1, !dbg !8046
  %.unpack.unpack15 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt14, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack11.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 0, !dbg !8046
  %.unpack11.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack11.elt17 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 1, !dbg !8046
  %.unpack11.unpack18 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt17, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack13.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 0, !dbg !8046
  %.unpack13.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.unpack13.elt20 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 1, !dbg !8046
  %.unpack13.unpack21 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt20, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
  %.fca.0.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 0, !dbg !8051
  store {} addrspace(10)* %.unpack.unpack33, {} addrspace(10)** %.fca.0.0.gep, align 8, !dbg !8051, !noalias !566
  %.fca.0.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 1, !dbg !8051
  store {} addrspace(10)* %.unpack.unpack15, {} addrspace(10)** %.fca.0.1.gep, align 8, !dbg !8051, !noalias !566
  %.fca.1.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 0, !dbg !8051
  store {} addrspace(10)* %.unpack11.unpack, {} addrspace(10)** %.fca.1.0.gep, align 8, !dbg !8051, !noalias !566
  %.fca.1.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 1, !dbg !8051
  store {} addrspace(10)* %.unpack11.unpack18, {} addrspace(10)** %.fca.1.1.gep, align 8, !dbg !8051, !noalias !566
  %.fca.2.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 0, !dbg !8051
  store {} addrspace(10)* %.unpack13.unpack, {} addrspace(10)** %.fca.2.0.gep, align 8, !dbg !8051, !noalias !566
  %.fca.2.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 1, !dbg !8051
  store {} addrspace(10)* %.unpack13.unpack21, {} addrspace(10)** %.fca.2.1.gep, align 8, !dbg !8051, !noalias !566
  %17 = addrspacecast [3 x [2 x {} addrspace(10)*]]* %1 to [3 x [2 x {} addrspace(10)*]] addrspace(11)*, !dbg !8051
  %18 = call fastcc i32 @julia_put__189354([3 x [2 x {} addrspace(10)*]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(48) %17, {} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0), !dbg !8051
  %19 = icmp eq i32 %18, 0, !dbg !8052
  br i1 %19, label %L77, label %L75, !dbg !8057
}

Could not analyze garbage collection behavior of
 inst:   %.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
 v0:   %.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !461
 v: {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***)
 offset: i64 0
 hasload: true


Stacktrace:
 [1] #synchronize#1003
   @ ~/.julia/packages/CUDA/1kIOw/lib/cudadrv/synchronization.jl:200
 [2] multiple call sites
   @ unknown:0

Stacktrace:
  [1] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:969
  [2] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:653
  [3] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:682
  [4] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:818
  [5] nodecayed_phis!(mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:976
  [6] optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler/optimize.jl:582
  [7] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:401
  [8] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:960
  [9] enzyme_custom_augfwd
    @ ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:1503 [inlined]
 [10] enzyme_custom_augfwd_cfunc(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, normalR::Ptr{…}, shadowR::Ptr{…}, tapeR::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/llvmrules.jl:18
 [11] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/R6sE8/src/api.jl:268
 [12] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:1706
 [13] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4550
 [14] codegen
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:3353 [inlined]
 [15] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410
 [16] _thunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410 [inlined]
 [17] cached_compilation
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5462 [inlined]
 [18] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5573
 [19] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5758
 [20] autodiff
    @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:485 [inlined]
 [21] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxEnzymeExt ~/.julia/packages/Lux/DHtyL/ext/LuxEnzymeExt/training.jl:8
 [22] compute_gradients
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200 [inlined]
 [23] single_train_step_impl!
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:320 [inlined]
 [24] #single_train_step!#6
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:288 [inlined]
 [25] single_train_step!(backend::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:284
 [26] main(dev::CUDADevice{…}, rng::Xoshiro, model::Chain{…}, ::Tuple{…}; nepochs::Int64, opt::Adam, backend::AutoEnzyme{…})
    @ Main ~/src/lux_enzyme_test.jl:29
 [27] main(dev::CUDADevice{Nothing}, rng::Xoshiro, model::Chain{@NamedTuple{…}, Nothing}, ::Tuple{CuArray{…}, CuArray{…}})
    @ Main ~/src/lux_enzyme_test.jl:17
 [28] macro expansion
    @ ./timing.jl:581 [inlined]
 [29] top-level scope
    @ ./REPL[2]:1
Some type information was truncated. Use `show(err)` to see complete types.