Skip to content

Commit

Permalink
Merge branch 'master' into any_all_vectorized_tuple_bool
Browse files Browse the repository at this point in the history
  • Loading branch information
nsajko authored Jan 22, 2025
2 parents d2543e2 + f91436e commit 987b82e
Show file tree
Hide file tree
Showing 23 changed files with 184 additions and 92 deletions.
15 changes: 13 additions & 2 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
sig_n === Bottom && continue
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::MethodTable
Expand Down Expand Up @@ -614,7 +615,7 @@ function abstract_call_method(interp::AbstractInterpreter,
sigtuple = unwrap_unionall(sig)
sigtuple isa DataType ||
return Future(MethodCallResult(Any, Any, Effects(), nothing, false, false))
all(@nospecialize(x) -> valid_as_lattice(unwrapva(x), true), sigtuple.parameters) ||
all(@nospecialize(x) -> isvarargtype(x) || valid_as_lattice(x, true), sigtuple.parameters) ||
return Future(MethodCallResult(Union{}, Any, EFFECTS_THROWS, nothing, false, false)) # catch bad type intersections early

if is_nospecializeinfer(method)
Expand Down Expand Up @@ -2840,6 +2841,7 @@ function abstract_call_unknown(interp::AbstractInterpreter, @nospecialize(ft),
end
# non-constant function, but the number of arguments is known and the `f` is not a builtin or intrinsic
atype = argtypes_to_type(arginfo.argtypes)
atype === Bottom && return Future(CallMeta(Union{}, Union{}, EFFECTS_THROWS, NoCallInfo())) # accidentally unreachable
return abstract_call_gf_by_type(interp, nothing, arginfo, si, atype, sv, max_methods)::Future
end

Expand Down Expand Up @@ -3785,14 +3787,23 @@ function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
slottypes = frame.slottypes
rt = widenreturn(rt, BestguessInfo(interp, bestguess, nargs, slottypes, currstate))
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
if rt isa InterConditional && bestguess isa Const && bestguess.val isa Bool
slot_id = rt.slot
old_id_type = widenconditional(slottypes[slot_id])
if bestguess.val === true && rt.elsetype !== Bottom
bestguess = InterConditional(slot_id, old_id_type, Bottom)
elseif bestguess.val === false && rt.thentype !== Bottom
bestguess = InterConditional(slot_id, Bottom, old_id_type)
end
# or narrow representation of rt slightly to prepare for tmerge with bestguess
elseif bestguess isa InterConditional && rt isa Const && rt.val isa Bool
slot_id = bestguess.slot
old_id_type = widenconditional(slottypes[slot_id])
if rt.val === true && bestguess.elsetype !== Bottom
rt = InterConditional(slot_id, old_id_type, Bottom)
elseif rt.val === false && bestguess.thentype !== Bottom
rt = InterConditional(slot_id, Bottom, old_id_type)
end
end
# copy limitations to return value
if !isempty(frame.pclimitations)
Expand Down
1 change: 1 addition & 0 deletions Compiler/src/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,7 @@ function handle_call!(todo::Vector{Pair{Int,Any}},
cases === nothing && return nothing
cases, handled_all_cases, fully_covered, joint_effects = cases
atype = argtypes_to_type(sig.argtypes)
atype === Union{} && return nothing # accidentally actually unreachable
handle_cases!(todo, ir, idx, stmt, atype, cases, handled_all_cases, fully_covered, joint_effects)
end

Expand Down
36 changes: 20 additions & 16 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3016,24 +3016,28 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
isvarargtype(argtypes[2]) && return Future(CallMeta(Bool, ArgumentError, EFFECTS_THROWS, NoCallInfo()))
argtypes = argtypes[2:end]
atype = argtypes_to_type(argtypes)
matches = find_method_matches(interp, argtypes, atype; max_methods)
info = NoCallInfo()
if isa(matches, FailedMethodMatch)
rt = Bool # too many matches to analyze
if atype === Union{}
rt = Union{} # accidentally unreachable code
else
(; valid_worlds, applicable) = matches
update_valid_age!(sv, valid_worlds)
napplicable = length(applicable)
if napplicable == 0
rt = Const(false) # never any matches
elseif !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
matches = find_method_matches(interp, argtypes, atype; max_methods)
info = NoCallInfo()
if isa(matches, FailedMethodMatch)
rt = Bool # too many matches to analyze
else
rt = Const(true) # has applicable matches
end
if rt !== Bool
info = VirtualMethodMatchInfo(matches.info)
(; valid_worlds, applicable) = matches
update_valid_age!(sv, valid_worlds)
napplicable = length(applicable)
if napplicable == 0
rt = Const(false) # never any matches
elseif !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
else
rt = Const(true) # has applicable matches
end
if rt !== Bool
info = VirtualMethodMatchInfo(matches.info)
end
end
end
return Future(CallMeta(rt, Union{}, EFFECTS_TOTAL, info))
Expand Down
7 changes: 7 additions & 0 deletions Compiler/src/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,13 @@ end
end
a = Bool
elseif isa(b, ConditionalT)
if isa(a, Const) && isa(a.val, Bool)
if (a.val === true && b.thentype === Any && b.elsetype === Bottom) ||
(a.val === false && b.elsetype === Any && b.thentype === Bottom)
# this Conditional contains distinctly no lattice information, and is simply an alternative representation of the Const Bool used for internal tracking purposes
return true
end
end
return false
end
return (widenlattice(lattice), a, b)
Expand Down
7 changes: 6 additions & 1 deletion Compiler/src/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ has_extended_info(@nospecialize x) = (!isa(x, Type) && !isvarargtype(x)) || isTy
# certain combinations of `a` and `b` where one/both isa/are `Union`/`UnionAll` type(s)s.
isnotbrokensubtype(@nospecialize(a), @nospecialize(b)) = (!iskindtype(b) || !isType(a) || hasuniquerep(a.parameters[1]) || b <: a)

argtypes_to_type(argtypes::Array{Any,1}) = Tuple{anymap(@nospecialize(a) -> isvarargtype(a) ? a : widenconst(a), argtypes)...}
function argtypes_to_type(argtypes::Array{Any,1})
argtypes = anymap(@nospecialize(a) -> isvarargtype(a) ? a : widenconst(a), argtypes)
filter!(@nospecialize(x) -> !isvarargtype(x) || valid_as_lattice(unwrapva(x), true), argtypes)
all(@nospecialize(x) -> isvarargtype(x) || valid_as_lattice(x, true), argtypes) || return Bottom
return Tuple{argtypes...}
end

function isknownlength(t::DataType)
isvatuple(t) || return true
Expand Down
26 changes: 26 additions & 0 deletions Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,26 @@ vacond(cnd, va...) = cnd ? va : 0
vacond(isa(x, Tuple{Int,Int}), x, x)
end |> only == Union{Int,Tuple{Any,Any}}

let A = Core.Const(true)
B = Core.InterConditional(2, Tuple, Union{})
C = Core.InterConditional(2, Any, Union{})
L = ipo_lattice(Compiler.NativeInterpreter())
@test !⊑(L, A, B)
@test ⊑(L, B, A)
@test tmerge(L, A, B) == C
@test ⊑(L, A, C)
end
function tail_is_ntuple((@nospecialize t::Tuple))
if unknown
t isa Tuple
else
tail_is_ntuple(t)
end
end
tail_is_ntuple_val((@nospecialize t::Tuple)) = Val(tail_is_ntuple(t))
@test Base.return_types(tail_is_ntuple, (Tuple,)) |> only === Bool
@test Base.return_types(tail_is_ntuple_val, (Tuple,)) |> only === Val{true}

# https://github.com/JuliaLang/julia/issues/47435
is_closed_ex(e::InvalidStateException) = true
is_closed_ex(e) = false
Expand Down Expand Up @@ -6162,3 +6182,9 @@ end <: Any
end
return out
end == Union{Float64,DomainError}

# issue #56628
@test Compiler.argtypes_to_type(Any[ Int, UnitRange{Int}, Vararg{Pair{Any, Union{}}} ]) === Tuple{Int, UnitRange{Int}}
@test Compiler.argtypes_to_type(Any[ Int, UnitRange{Int}, Vararg{Pair{Any, Union{}}}, Float64 ]) === Tuple{Int, UnitRange{Int}, Float64}
@test Compiler.argtypes_to_type(Any[ Int, UnitRange{Int}, Vararg{Pair{Any, Union{}}}, Float64, Memory{2} ]) === Union{}
@test Base.return_types(Tuple{Tuple{Int, Vararg{Pair{Any, Union{}}}}},) do x; Returns(true)(x...); end |> only === Bool
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ New library functions
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`. ([#45793])
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).
* `Sys.detectwsl()` allows to testing if Julia is running inside WSL at runtime. ([#57069])

New library features
--------------------
Expand Down
14 changes: 6 additions & 8 deletions base/floatfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ significantly more expensive than `x*y+z`. `fma` is used to improve accuracy in
algorithms. See [`muladd`](@ref).
"""
function fma end
function fma_emulated(a::Float16, b::Float16, c::Float16)
Float16(muladd(Float32(a), Float32(b), Float32(c))) #don't use fma if the hardware doesn't have it.
end
function fma_emulated(a::Float32, b::Float32, c::Float32)::Float32
ab = Float64(a) * b
res = ab+c
Expand Down Expand Up @@ -348,19 +351,14 @@ function fma_emulated(a::Float64, b::Float64,c::Float64)
s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
return r+s
end
fma_llvm(x::Float32, y::Float32, z::Float32) = fma_float(x, y, z)
fma_llvm(x::Float64, y::Float64, z::Float64) = fma_float(x, y, z)

# Disable LLVM's fma if it is incorrect, e.g. because LLVM falls back
# onto a broken system libm; if so, use a software emulated fma
@assume_effects :consistent fma(x::Float32, y::Float32, z::Float32) = Core.Intrinsics.have_fma(Float32) ? fma_llvm(x,y,z) : fma_emulated(x,y,z)
@assume_effects :consistent fma(x::Float64, y::Float64, z::Float64) = Core.Intrinsics.have_fma(Float64) ? fma_llvm(x,y,z) : fma_emulated(x,y,z)

function fma(a::Float16, b::Float16, c::Float16)
Float16(muladd(Float32(a), Float32(b), Float32(c))) #don't use fma if the hardware doesn't have it.
@assume_effects :consistent function fma(x::T, y::T, z::T) where {T<:IEEEFloat}
Core.Intrinsics.have_fma(T) ? fma_float(x,y,z) : fma_emulated(x,y,z)
end

# This is necessary at least on 32-bit Intel Linux, since fma_llvm may
# This is necessary at least on 32-bit Intel Linux, since fma_float may
# have called glibc, and some broken glibc fma implementations don't
# properly restore the rounding mode
Rounding.setrounding_raw(Float32, Rounding.JL_FE_TONEAREST)
Expand Down
24 changes: 23 additions & 1 deletion base/sysinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ export BINDIR,
isreadable,
iswritable,
username,
which
which,
detectwsl

import ..Base: show

Expand Down Expand Up @@ -532,6 +533,27 @@ including e.g. a WebAssembly JavaScript embedding in a web browser.
"""
isjsvm(os::Symbol) = (os === :Emscripten)

"""
Sys.detectwsl()
Runtime predicate for testing if Julia is running inside
Windows Subsystem for Linux (WSL).
!!! note
Unlike `Sys.iswindows`, `Sys.islinux` etc., this is a runtime test, and thus
cannot meaningfully be used in `@static if` constructs.
!!! compat "Julia 1.12"
This function requires at least Julia 1.12.
"""
function detectwsl()
# We use the same approach as canonical/snapd do to detect WSL
islinux() && (
isfile("/proc/sys/fs/binfmt_misc/WSLInterop")
|| isdir("/run/WSL")
)
end

for f in (:isunix, :islinux, :isbsd, :isapple, :iswindows, :isfreebsd, :isopenbsd, :isnetbsd, :isdragonfly, :isjsvm)
@eval $f() = $(getfield(@__MODULE__, f)(KERNEL))
end
Expand Down
2 changes: 1 addition & 1 deletion src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2223,7 +2223,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_
output.imaging_mode = jl_options.image_codegen;
output.temporary_roots = jl_alloc_array_1d(jl_array_any_type, 0);
JL_GC_PUSH1(&output.temporary_roots);
auto decls = jl_emit_code(m, mi, src, NULL, output);
auto decls = jl_emit_code(m, mi, src, mi->specTypes, src->rettype, output);
output.temporary_roots = nullptr;
JL_GC_POP(); // GC the global_targets array contents now since reflection doesn't need it

Expand Down
17 changes: 8 additions & 9 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4381,7 +4381,7 @@ static jl_llvm_functions_t
jl_method_instance_t *lam,
jl_code_info_t *src,
jl_value_t *abi,
jl_value_t *rettype,
jl_value_t *jlrettype,
jl_codegen_params_t &params);

static void emit_hasnofield_error_ifnot(jl_codectx_t &ctx, Value *ok, jl_datatype_t *type, jl_cgval_t name);
Expand Down Expand Up @@ -5533,12 +5533,12 @@ static jl_value_t *get_ci_abi(jl_code_instance_t *ci)
return jl_get_ci_mi(ci)->specTypes;
}

static jl_cgval_t emit_call_specfun_other(jl_codectx_t &ctx, jl_code_instance_t *ci, jl_value_t *jlretty, StringRef specFunctionObject, jl_code_instance_t *fromexternal,
static jl_cgval_t emit_call_specfun_other(jl_codectx_t &ctx, jl_code_instance_t *ci, StringRef specFunctionObject, jl_code_instance_t *fromexternal,
ArrayRef<jl_cgval_t> argv, size_t nargs, jl_returninfo_t::CallingConv *cc, unsigned *return_roots, jl_value_t *inferred_retty, Value *age_ok)
{
jl_method_instance_t *mi = jl_get_ci_mi(ci);
bool is_opaque_closure = jl_is_method(mi->def.value) && mi->def.method->is_for_opaque_closure;
return emit_call_specfun_other(ctx, is_opaque_closure, get_ci_abi(ci), jlretty, NULL,
return emit_call_specfun_other(ctx, is_opaque_closure, get_ci_abi(ci), ci->rettype, NULL,
specFunctionObject, fromexternal, argv, nargs, cc, return_roots, inferred_retty, age_ok);
}

Expand Down Expand Up @@ -5688,7 +5688,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, ArrayR
jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed;
unsigned return_roots = 0;
if (specsig)
result = emit_call_specfun_other(ctx, codeinst, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, &cc, &return_roots, rt, age_ok);
result = emit_call_specfun_other(ctx, codeinst, protoname, external ? codeinst : nullptr, argv, nargs, &cc, &return_roots, rt, age_ok);
else
result = emit_call_specfun_boxed(ctx, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, rt, age_ok);
if (need_to_emit) {
Expand Down Expand Up @@ -10029,7 +10029,8 @@ jl_llvm_functions_t jl_emit_code(
orc::ThreadSafeModule &m,
jl_method_instance_t *li,
jl_code_info_t *src,
jl_value_t *abi,
jl_value_t *abi_at,
jl_value_t *abi_rt,
jl_codegen_params_t &params)
{
JL_TIMING(CODEGEN, CODEGEN_LLVM);
Expand All @@ -10038,10 +10039,8 @@ jl_llvm_functions_t jl_emit_code(
assert((params.params == &jl_default_cgparams /* fast path */ || !params.cache ||
compare_cgparams(params.params, &jl_default_cgparams)) &&
"functions compiled with custom codegen params must not be cached");
if (!abi)
abi = li->specTypes;
JL_TRY {
decls = emit_function(m, li, src, abi, src->rettype, params);
decls = emit_function(m, li, src, abi_at, abi_rt, params);
auto stream = *jl_ExecutionEngine->get_dump_emitted_mi_name_stream();
if (stream) {
jl_printf(stream, "%s\t", decls.specFunctionObject.c_str());
Expand Down Expand Up @@ -10112,7 +10111,7 @@ jl_llvm_functions_t jl_emit_codeinst(
return jl_llvm_functions_t(); // user error
}
//assert(jl_egal((jl_value_t*)jl_atomic_load_relaxed(&codeinst->debuginfo), (jl_value_t*)src->debuginfo) && "trying to generate code for a codeinst for an incompatible src");
jl_llvm_functions_t decls = jl_emit_code(m, jl_get_ci_mi(codeinst), src, get_ci_abi(codeinst), params);
jl_llvm_functions_t decls = jl_emit_code(m, jl_get_ci_mi(codeinst), src, get_ci_abi(codeinst), codeinst->rettype, params);
return decls;
}

Expand Down
6 changes: 3 additions & 3 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -3034,19 +3034,19 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
return codeinst;
}

jl_value_t *jl_fptr_const_return(jl_value_t *f, jl_value_t **args, uint32_t nargs, jl_code_instance_t *m)
JL_DLLEXPORT jl_value_t *jl_fptr_const_return(jl_value_t *f, jl_value_t **args, uint32_t nargs, jl_code_instance_t *m)
{
return m->rettype_const;
}

jl_value_t *jl_fptr_args(jl_value_t *f, jl_value_t **args, uint32_t nargs, jl_code_instance_t *m)
JL_DLLEXPORT jl_value_t *jl_fptr_args(jl_value_t *f, jl_value_t **args, uint32_t nargs, jl_code_instance_t *m)
{
jl_fptr_args_t invoke = jl_atomic_load_relaxed(&m->specptr.fptr1);
assert(invoke && "Forgot to set specptr for jl_fptr_args!");
return invoke(f, args, nargs);
}

jl_value_t *jl_fptr_sparam(jl_value_t *f, jl_value_t **args, uint32_t nargs, jl_code_instance_t *m)
JL_DLLEXPORT jl_value_t *jl_fptr_sparam(jl_value_t *f, jl_value_t **args, uint32_t nargs, jl_code_instance_t *m)
{
jl_svec_t *sparams = jl_get_ci_mi(m)->sparam_vals;
assert(sparams != jl_emptysvec);
Expand Down
3 changes: 2 additions & 1 deletion src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ jl_llvm_functions_t jl_emit_code(
orc::ThreadSafeModule &M,
jl_method_instance_t *mi,
jl_code_info_t *src,
jl_value_t *abi,
jl_value_t *abi_at,
jl_value_t *abi_rt,
jl_codegen_params_t &params);

jl_llvm_functions_t jl_emit_codeinst(
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,6 @@
XX(jl_tagged_gensym) \
XX(jl_take_buffer) \
XX(jl_task_get_next) \
XX(jl_task_stack_buffer) \
XX(jl_termios_size) \
XX(jl_test_cpu_feature) \
XX(jl_threadid) \
Expand Down
9 changes: 0 additions & 9 deletions src/julia_gcext.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ JL_DLLEXPORT int jl_gc_conservative_gc_support_enabled(void);
// NOTE: Only valid to call from within a GC context.
JL_DLLEXPORT jl_value_t *jl_gc_internal_obj_base_ptr(void *p) JL_NOTSAFEPOINT;

// Return a non-null pointer to the start of the stack area if the task
// has an associated stack buffer. In that case, *size will also contain
// the size of that stack buffer upon return. Also, if task is a thread's
// current task, that thread's id will be stored in *tid; otherwise,
// *tid will be set to -1.
//
// DEPRECATED: use jl_active_task_stack() instead.
JL_DLLEXPORT void *jl_task_stack_buffer(jl_task_t *task, size_t *size, int *tid);

// Query the active and total stack range for the given task, and set
// *active_start and *active_end respectively *total_start and *total_end
// accordingly. The range for the active part is a best-effort approximation
Expand Down
Loading

0 comments on commit 987b82e

Please sign in to comment.