Minimal Lux examples on GPU are still impractical due to excessive compile times and occasional failures.
Here is a minimal example:
    using LinearAlgebra, Random, Statistics, Optimisers
    using CUDA
    using Lux, LuxCUDA
    import ProgressMeter as PM
    import Enzyme, Zygote

    using Lux.MLDataDevices: AbstractDevice

    Enzyme.Compiler.VERBOSE_ERRORS[] = true

    function makedata(rng::AbstractRNG)
        X = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
        y = evalpoly.(X, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
        (X, y)
    end

    function main(dev::AbstractDevice, rng=Random.Xoshiro(999),
                  model=Chain(Dense(1=>16, gelu), Dense(16=>1)),
                  (X, y)=makedata(rng) |> dev;
                  nepochs=300, opt=Adam(0.01f0), backend=AutoEnzyme(),
                 )
        (θ, ψ) = Lux.setup(rng, model) |> dev
        s = Lux.Training.TrainState(model, θ, ψ, opt)
        pm = PM.Progress(nepochs)
        for j ∈ 1:nepochs
            (∂s, ℓ, stats, s) = Lux.Training.single_train_step!(
                backend, MSELoss(), (X, y), s,
            )
            PM.next!(pm)
        end
        PM.finish!(pm)
        (yhat, _) = Lux.apply(model, X, θ, ψ)
        (yhat, y)
    end
On the CPU I get:

    ◖◗ @time main(cpu_device());
    Progress: 100%|████████████████████████████████████████| Time: 0:00:00
     43.469745 seconds (481.86 M allocations: 20.523 GiB, 5.22% gc time, 99.98% compilation time)
This is worryingly slow for such a small example, but depending on how it scales it isn't necessarily prohibitive.
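One way to check how much of a measurement like this is one-off compilation is to time a second call in the same session. The sketch below is only an illustration: `work` is a hypothetical stand-in function, not part of the script above, and it uses nothing beyond Base Julia.

```julia
# Hypothetical stand-in for an expensive entry point such as `main`.
work(n) = sum(abs2, randn(Float32, n))

# The first call includes the time to compile `work` for an Int argument.
t_first = @elapsed work(128)

# The second call reuses the compiled method, so it measures run time only.
t_second = @elapsed work(128)

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

For the script above, the analogous check would be to run `main(cpu_device(); nepochs=1)` once and then time the full run in the same session.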
However, on the GPU it compiles for about 10 to 15 minutes and then fails with the following error:
    ◖◗ @time main(gpu_device());
    ERROR: Enzyme compilation failed due to an internal error.
    Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
    To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

    Current scope:
    define internal fastcc void @julia_nonblocking_synchronize_187572({} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0) unnamed_addr #197 !dbg !8001 {
    top:
      %1 = alloca [3 x [2 x {} addrspace(10)*]], align 8
      %pgcstack = call {}*** @julia.get_pgcstack()
      %ptls_field6 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
      %2 = bitcast {}*** %ptls_field6 to i64***
      %ptls_load78 = load i64**, i64*** %2, align 8, !tbaa !381
      %3 = getelementptr inbounds i64*, i64** %ptls_load78, i64 2
      %safepoint = load i64*, i64** %3, align 8, !tbaa !385
      fence syncscope("singlethread") seq_cst
      call void @julia.safepoint(i64* %safepoint), !dbg !8002
      fence syncscope("singlethread") seq_cst
      %4 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* addrspacecast ({}* inttoptr (i64 124712895855264 to {}*) to {} addrspace(11)*)) #413, !dbg !8003
      %ptr.i = bitcast {}* %4 to i32*, !dbg !8007
      %rv.i = atomicrmw add i32* %ptr.i, i32 1 acq_rel, align 4, !dbg !8007
      %5 = and i32 %rv.i, 3, !dbg !8010
      %.not = icmp eq i32 %5, 0, !dbg !8018
      %narrow = select i1 %.not, i32 4, i32 %5, !dbg !8020
      %6 = zext i32 %narrow to i64, !dbg !8020
      %7 = load i64, i64* inttoptr (i64 124712895855408 to i64*), align 16, !dbg !8022, !tbaa !633, !alias.scope !636, !noalias !637
      %8 = add nsw i64 %6, -1, !dbg !8035
      %.not9 = icmp ult i64 %8, %7, !dbg !8038
      br i1 %.not9, label %L40, label %L49, !dbg !8032

    L40:                                              ; preds = %top
      %9 = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637
      %10 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8040, !tbaa !642, !alias.scope !636, !noalias !637, !dereferenceable_or_null !494, !align !503
      %11 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %10, {} addrspace(10)** %9), !dbg !8043
      %12 = bitcast {} addrspace(10)* addrspace(13)* %11 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8043
      %13 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %12, i64 %8, i64 0, i64 0, !dbg !8043
      %14 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %13, align 8, !dbg !8043, !tbaa !4485, !alias.scope !429, !noalias !430
      %.not24 = icmp eq {} addrspace(10)* %14, null, !dbg !8043
      br i1 %.not24, label %L49, label %pass3, !dbg !8034

    L49:                                              ; preds = %L40, %top
      call fastcc void @julia_create_synchronization_worker_189388(i64 signext %6), !dbg !8045
      %.pre = load {} addrspace(10)**, {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***), align 32, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
      %.pre25 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 124712895855400 to {} addrspace(10)**), align 8, !dbg !8046, !tbaa !642, !alias.scope !636, !noalias !637
      %.pre26 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %.pre25, {} addrspace(10)** %.pre), !dbg !8046
      %.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !8046
      %.unpack.elt.phi.trans.insert = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre27, i64 %8, i64 0, i64 0
      %.unpack.unpack.pre = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt.phi.trans.insert, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
      %.not23 = icmp eq {} addrspace(10)* %.unpack.unpack.pre, null, !dbg !8046
      br i1 %.not23, label %fail2, label %pass3, !dbg !8046

    L75:                                              ; preds = %pass3
      call fastcc void @julia_throw_api_error_187593(i32 zeroext %18) #414, !dbg !8049
      unreachable, !dbg !8049

    L77:                                              ; preds = %pass3
      ret void, !dbg !8050

    fail2:                                            ; preds = %L49
      %15 = load {}*, {}** @jl_undefref_exception, align 8, !dbg !8046, !tbaa !385, !alias.scope !437, !noalias !438, !nonnull !380
      %16 = addrspacecast {}* %15 to {} addrspace(12)*, !dbg !8046
      call void @ijl_throw({} addrspace(12)* %16), !dbg !8046
      unreachable, !dbg !8046

    pass3:                                            ; preds = %L40, %L49
      %nodecayed..pre-phi2834 = phi {} addrspace(10)*
      %nodecayedoff..pre-phi2834 = phi i64
      %.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
      %.unpack.unpack33 = phi {} addrspace(10)* [ %.unpack.unpack.pre, %L49 ], [ %14, %L40 ]
      %.unpack.elt14 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 0, i64 1, !dbg !8046
      %.unpack.unpack15 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack.elt14, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
      %.unpack11.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 0, !dbg !8046
      %.unpack11.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
      %.unpack11.elt17 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 1, i64 1, !dbg !8046
      %.unpack11.unpack18 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack11.elt17, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
      %.unpack13.elt = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 0, !dbg !8046
      %.unpack13.unpack = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
      %.unpack13.elt20 = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]] addrspace(13)* %.pre-phi2834, i64 %8, i64 2, i64 1, !dbg !8046
      %.unpack13.unpack21 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %.unpack13.elt20, align 8, !dbg !8046, !tbaa !4485, !alias.scope !429, !noalias !430
      %.fca.0.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 0, !dbg !8051
      store {} addrspace(10)* %.unpack.unpack33, {} addrspace(10)** %.fca.0.0.gep, align 8, !dbg !8051, !noalias !566
      %.fca.0.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 0, i64 1, !dbg !8051
      store {} addrspace(10)* %.unpack.unpack15, {} addrspace(10)** %.fca.0.1.gep, align 8, !dbg !8051, !noalias !566
      %.fca.1.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 0, !dbg !8051
      store {} addrspace(10)* %.unpack11.unpack, {} addrspace(10)** %.fca.1.0.gep, align 8, !dbg !8051, !noalias !566
      %.fca.1.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 1, i64 1, !dbg !8051
      store {} addrspace(10)* %.unpack11.unpack18, {} addrspace(10)** %.fca.1.1.gep, align 8, !dbg !8051, !noalias !566
      %.fca.2.0.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 0, !dbg !8051
      store {} addrspace(10)* %.unpack13.unpack, {} addrspace(10)** %.fca.2.0.gep, align 8, !dbg !8051, !noalias !566
      %.fca.2.1.gep = getelementptr inbounds [3 x [2 x {} addrspace(10)*]], [3 x [2 x {} addrspace(10)*]]* %1, i64 0, i64 2, i64 1, !dbg !8051
      store {} addrspace(10)* %.unpack13.unpack21, {} addrspace(10)** %.fca.2.1.gep, align 8, !dbg !8051, !noalias !566
      %17 = addrspacecast [3 x [2 x {} addrspace(10)*]]* %1 to [3 x [2 x {} addrspace(10)*]] addrspace(11)*, !dbg !8051
      %18 = call fastcc i32 @julia_put__189354([3 x [2 x {} addrspace(10)*]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(48) %17, {} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %0), !dbg !8051
      %19 = icmp eq i32 %18, 0, !dbg !8052
      br i1 %19, label %L77, label %L75, !dbg !8057
    }

    Could not analyze garbage collection behavior of
     inst:   %.pre-phi2834 = phi [3 x [2 x {} addrspace(10)*]] addrspace(13)* [ %.pre27, %L49 ], [ %12, %L40 ]
     v0:   %.pre27 = bitcast {} addrspace(10)* addrspace(13)* %.pre26 to [3 x [2 x {} addrspace(10)*]] addrspace(13)*, !dbg !461
     v: {} addrspace(10)*** inttoptr (i64 124712895855392 to {} addrspace(10)***)
     offset: i64 0
     hasload: true

    Stacktrace:
     [1] #synchronize#1003
       @ ~/.julia/packages/CUDA/1kIOw/lib/cudadrv/synchronization.jl:200
     [2] multiple call sites
       @ unknown:0

    Stacktrace:
      [1] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:969
      [2] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:653
      [3] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:682
      [4] (::Enzyme.Compiler.var"#getparent#69"{…})(b::LLVM.IRBuilder, v::LLVM.Value, offset::LLVM.Value, hasload::Bool, phicache::Dict{…})
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:818
      [5] nodecayed_phis!(mod::LLVM.Module)
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/llvm/transforms.jl:976
      [6] optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler/optimize.jl:582
      [7] nested_codegen!(mode::Enzyme.API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world::UInt64)
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:401
      [8] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::Nothing)
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:960
      [9] enzyme_custom_augfwd
        @ ~/.julia/packages/Enzyme/R6sE8/src/rules/customrules.jl:1503 [inlined]
     [10] enzyme_custom_augfwd_cfunc(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, normalR::Ptr{…}, shadowR::Ptr{…}, tapeR::Ptr{…})
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/rules/llvmrules.jl:18
     [11] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
        @ Enzyme.API ~/.julia/packages/Enzyme/R6sE8/src/api.jl:268
     [12] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:1706
     [13] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4550
     [14] codegen
        @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:3353 [inlined]
     [15] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410
     [16] _thunk
        @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410 [inlined]
     [17] cached_compilation
        @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5462 [inlined]
     [18] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5573
     [19] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
        @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5758
     [20] autodiff
        @ ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:485 [inlined]
     [21] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
        @ LuxEnzymeExt ~/.julia/packages/Lux/DHtyL/ext/LuxEnzymeExt/training.jl:8
     [22] compute_gradients
        @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200 [inlined]
     [23] single_train_step_impl!
        @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:320 [inlined]
     [24] #single_train_step!#6
        @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:288 [inlined]
     [25] single_train_step!(backend::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
        @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:284
     [26] main(dev::CUDADevice{…}, rng::Xoshiro, model::Chain{…}, ::Tuple{…}; nepochs::Int64, opt::Adam, backend::AutoEnzyme{…})
        @ Main ~/src/lux_enzyme_test.jl:29
     [27] main(dev::CUDADevice{Nothing}, rng::Xoshiro, model::Chain{@NamedTuple{…}, Nothing}, ::Tuple{CuArray{…}, CuArray{…}})
        @ Main ~/src/lux_enzyme_test.jl:17
     [28] macro expansion
        @ ./timing.jl:581 [inlined]
     [29] top-level scope
        @ ./REPL[2]:1
    Some type information was truncated. Use `show(err)` to see complete types.