diff --git a/src/JET.jl b/src/JET.jl index 02072f70a..3ca6be441 100644 --- a/src/JET.jl +++ b/src/JET.jl @@ -842,30 +842,32 @@ function report_text(text::AbstractString, return JETToplevelResult(analyzer′, res, source; analyzer, jetconfigs...) end +# we have to rely on hacks here; see `transform_abstract_global_symbols!` and `resolve_toplevel_symbols` function analyze_toplevel!(analyzer::AbstractAnalyzer, src::CodeInfo) # construct toplevel `MethodInstance` mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ()); - mi.uninferred = src mi.specTypes = Tuple{} - transform_abstract_global_symbols!(analyzer, src) - mi.def = get_toplevelmod(analyzer) + mi.def = mod = get_toplevelmod(analyzer) + src = transform_abstract_global_symbols!(analyzer, src) + src = resolve_toplevel_symbols(mod, src) + mi.uninferred = src result = InferenceResult(mi); - # toplevel frame doesn't need to be cached (and so it won't be optimized), nor should - # go through JET's code generation error check - frame = InferenceState(result, src, #=cached=# false, analyzer); + # toplevel frames don't really need to be cached, but it's still better to optimize them + # in order to get reasonable `LocalUndefVarErrorReport` and `UncaughtExceptionReport` + frame = InferenceState(result, src, #=cached=# true, analyzer); return analyze_frame!(analyzer, frame) end -# HACK this is an native hack to re-use `AbstractInterpreter`'s approximated slot types for +# HACK this is a very naive hack to re-use `AbstractInterpreter`'s slot type approximation for # assignments of abstract global variables, which are represented as toplevel symbols at this point; -# the idea is just to transform them into slots from symbols and use their approximated type -# on their assignment. +# the idea is just to transform them from symbols into slots and use their approximated types +# on their assignment (see `finish(::InferenceState, ::AbstractAnalyzer)`). 
# NOTE that `transform_abstract_global_symbols!` will produce really invalid code for # actual interpretation or execution, but all the statements won't be interpreted anymore -# by `ConcreteInterpreter` nor executed anyway since toplevel frames aren't cached +# by `ConcreteInterpreter` nor executed by the native compilation pipeline anyway function transform_abstract_global_symbols!(analyzer::AbstractAnalyzer, src::CodeInfo) nslots = length(src.slotnames) abstrct_global_variables = Dict{Symbol,Int}() @@ -901,6 +903,36 @@ function transform_abstract_global_symbols!(analyzer::AbstractAnalyzer, src::Cod end set_global_slots!(analyzer, Dict(idx => slotname for (slotname, idx) in abstrct_global_variables)) + + return src +end + +# resolve toplevel symbols (and other expressions like `:foreigncall`) +# so that the returned `CodeInfo` is eligible for abstract interpretation and optimization +@static if VERSION ≥ v"1.8.0-DEV.420" + function resolve_toplevel_symbols(mod::Module, src::CodeInfo) + newsrc = copy(src) + @ccall jl_resolve_globals_in_ir(newsrc.code::Any, mod::Any, svec()::Any, 1::Any)::Cvoid # NOTE(review): the last argument is presumably a C `int` flag; `1::Any` passes a boxed object instead — confirm against the C signature + return newsrc + end +else + # HACK before https://github.com/JuliaLang/julia/pull/42013, we need to go through + # the method definition pipeline to get the effect of `jl_resolve_globals_in_ir` + function resolve_toplevel_symbols(mod::Module, src::CodeInfo) + sig = Core.svec( + svec(typeof(__toplevelf__)), + svec(), + QuoteNode(LineNumberNode(@__LINE__, @__FILE__))) + # branching on https://github.com/JuliaLang/julia/pull/41137 + method = (@static if isdefined(Core.Compiler, :OverlayMethodTable) + ccall(:jl_method_def, Any, (Any, Ptr{Cvoid}, Any, Any), sig, C_NULL, src, mod) + else + ccall(:jl_method_def, Cvoid, (Any, Any, Any), sig, src, mod) + only(methods(__toplevelf__)) + end)::Method + return CC.uncompressed_ir(method) + end + function __toplevelf__ end end # TODO `analyze_builtin!` ? 
diff --git a/src/abstractinterpretation.jl b/src/abstractinterpretation.jl index 880f5b034..ef24a4841 100644 --- a/src/abstractinterpretation.jl +++ b/src/abstractinterpretation.jl @@ -385,9 +385,6 @@ function CC.abstract_eval_special_value(analyzer::AbstractAnalyzer, @nospecializ # if it's really not defined, the error will be generated later anyway e = GlobalRef(get_toplevelmod(analyzer), get_slotname(sv, e)) end - elseif isa(e, Symbol) - # (already concretized) toplevel global symbols - e = GlobalRef(get_toplevelmod(analyzer), e) end end @@ -749,7 +746,7 @@ function is_constant_declared(name::Symbol, sv::InferenceState) return any(sv.src.code) do @nospecialize(x) if @isexpr(x, :const) arg = first(x.args) - # `transform_global_symbols!` replaces all the global symbols in this toplevel frame with `Slot`s + # `transform_abstract_global_symbols!` replaces all the global symbols in this toplevel frame with `Slot`s if isa(arg, Slot) return get_slotname(sv, arg) === name end diff --git a/src/analyzer.jl b/src/analyzer.jl index 6f386d288..e55eded0d 100644 --- a/src/analyzer.jl +++ b/src/analyzer.jl @@ -482,8 +482,7 @@ function maybe_initialize_caches!(analyzer::AbstractAnalyzer) end # check if we're in a toplevel module -@inline istoplevel(sv::InferenceState) = istoplevel(sv.linfo) -@inline istoplevel(::OptimizationState) = false # optimization never happen for top-level code +@inline istoplevel(sv::State) = istoplevel(sv.linfo) @inline istoplevel(linfo::MethodInstance) = isa(linfo.def, Module) is_global_slot(analyzer::AbstractAnalyzer, slot::Int) = slot in keys(get_global_slots(analyzer)) diff --git a/src/locinfo.jl b/src/locinfo.jl index 23748ada2..b60b5d89e 100644 --- a/src/locinfo.jl +++ b/src/locinfo.jl @@ -160,15 +160,7 @@ function _get_sig_type((sv, _)::StateAtPC, arg::Argument) return Any[sig, typ], typ end _get_sig_type(_::StateAtPC, gr::GlobalRef) = Any[string(gr.mod, '.', gr.name)], nothing -function _get_sig_type(s::StateAtPC, name::Symbol) - sv = first(s) - 
if istoplevel(sv) - # this is concrete global variable, form the global reference - return _get_sig_type(s, GlobalRef(sv.linfo.def, name)) - else - return Any[repr(name; context = :compact => true)], nothing - end -end +_get_sig_type(_::StateAtPC, name::Symbol) = Any[repr(name; context = :compact => true)], nothing function _get_sig_type(s::StateAtPC, gotoifnot::GotoIfNot) sig = Any[string("goto %", gotoifnot.dest, " if not "), _get_sig(s, gotoifnot.cond)...] return sig, nothing diff --git a/src/typeinfer.jl b/src/typeinfer.jl index 42670b01b..c7eae0c37 100644 --- a/src/typeinfer.jl +++ b/src/typeinfer.jl @@ -287,7 +287,8 @@ function (::SoundBasicPass)(::Type{UncaughtExceptionReport}, analyzer::AbstractA throw_locs = get_throw_locs(analyzer) throw_calls = Tuple{Int,Expr}[] for (pc, stmt) in enumerate(stmts) - is_throw_call_expr(analyzer, frame, stmt) || continue + isa(stmt, Expr) || continue + is_throw_call(stmt) || continue # if this `throw` is already reported, don't duplciate linetable[codelocs[pc]]::LineInfoNode in throw_locs && continue push!(throw_calls, (pc, stmt)) @@ -303,22 +304,3 @@ function (::SoundBasicPass)(::Type{UncaughtExceptionReport}, analyzer::AbstractA empty!(get_uncaught_exceptions(analyzer)) end end - -# basically same as `is_throw_call`, but also toplevel module handling added -function is_throw_call_expr(analyzer::AbstractAnalyzer, frame::InferenceState, @nospecialize(e)) - if isa(e, Expr) - if e.head === :call - f = e.args[1] - if istoplevel(frame) && isa(f, Symbol) - f = GlobalRef(get_toplevelmod(analyzer), f) - end - if isa(f, GlobalRef) - ff = CC.abstract_eval_global(f.mod, f.name) - if isa(ff, Const) && ff.val === Core.throw - return true - end - end - end - end - return false -end diff --git a/test/test_abstractinterpretation.jl b/test/test_abstractinterpretation.jl index c1994d2cd..dd4d0189d 100644 --- a/test/test_abstractinterpretation.jl +++ b/test/test_abstractinterpretation.jl @@ -134,9 +134,7 @@ end @test 
isempty(get_reports(analyzer)) end - # with the current approach, local undefined variables in toplevel frame can't be found - # since we don't cache toplevel frame and thus it won't be optimized - let + let # should work for top-level analysis res = @analyze_toplevel begin foo = let if rand(Bool) @@ -146,7 +144,9 @@ end end end end - @test_broken !isempty(res.inference_error_reports) + @test length(res.inference_error_reports) === 1 && + first(res.inference_error_reports) isa LocalUndefVarErrorReport && + first(res.inference_error_reports).name === :bar end end