diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 4d10507b834f3..e110482ff9ac5 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -13,44 +13,26 @@ const _REF_NAME = Ref.body.name # See if the inference result of the current statement's result value might affect # the final answer for the method (aside from optimization potential and exceptions). # To do that, we need to check both for slot assignment and SSA usage. -call_result_unused(frame::InferenceState, currpc::Int) = - isexpr(frame.src.code[currpc], :call) && isempty(frame.ssavalue_uses[currpc]) +call_result_unused(sv::InferenceState, currpc::Int) = + isexpr(sv.src.code[currpc], :call) && isempty(sv.ssavalue_uses[currpc]) call_result_unused(si::StmtInfo) = !si.used -function get_max_methods(mod::Module, interp::AbstractInterpreter) - max_methods = ccall(:jl_get_module_max_methods, Cint, (Any,), mod) % Int - max_methods < 0 ? InferenceParams(interp).max_methods : max_methods +function get_max_methods(sv::AbsIntState, interp::AbstractInterpreter) + max_methods = ccall(:jl_get_module_max_methods, Cint, (Any,), frame_module(sv)) % Int + return max_methods < 0 ? InferenceParams(interp).max_methods : max_methods end -function get_max_methods(@nospecialize(f), mod::Module, interp::AbstractInterpreter) +function get_max_methods(@nospecialize(f), sv::AbsIntState, interp::AbstractInterpreter) if f !== nothing fmm = typeof(f).name.max_methods fmm !== UInt8(0) && return Int(fmm) end - return get_max_methods(mod, interp) -end - -function should_infer_this_call(interp::AbstractInterpreter, sv::InferenceState) - if InferenceParams(interp).unoptimize_throw_blocks - # Disable inference of calls in throw blocks, since we're unlikely to - # need their types. 
There is one exception however: If up until now, the - # function has not seen any side effects, we would like to make sure there - # aren't any in the throw block either to enable other optimizations. - if is_stmt_throw_block(get_curr_ssaflag(sv)) - should_infer_for_effects(sv) || return false - end - end - return true -end - -function should_infer_for_effects(sv::InferenceState) - effects = sv.ipo_effects - return is_terminates(effects) && is_effect_free(effects) + return get_max_methods(sv, interp) end function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), - sv::InferenceState, max_methods::Int) + sv::AbsIntState, max_methods::Int) โŠ‘โ‚š = โŠ‘(ipo_lattice(interp)) if !should_infer_this_call(interp, sv) add_remark!(interp, sv, "Skipped call in throw block") @@ -178,7 +160,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), @assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context" seen += 1 rettype = tmerge(๐•ƒโ‚š, rettype, this_rt) - if has_conditional(๐•ƒโ‚š) && this_conditional !== Bottom && is_lattice_bool(๐•ƒโ‚š, rettype) && fargs !== nothing + if has_conditional(๐•ƒโ‚š, sv) && this_conditional !== Bottom && is_lattice_bool(๐•ƒโ‚š, rettype) && fargs !== nothing if conditionals === nothing conditionals = Any[Bottom for _ in 1:length(argtypes)], Any[Bottom for _ in 1:length(argtypes)] @@ -214,7 +196,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # Also considering inferring the compilation signature for this method, so # it is available to the compiler in case it ends up needing it. 
- if infer_compilation_signature(interp) && 1 == seen == napplicable && rettype !== Any && rettype !== Union{} && !is_removable_if_unused(all_effects) + if (isa(sv, InferenceState) && infer_compilation_signature(interp) && + (1 == seen == napplicable) && rettype !== Any && rettype !== Bottom && + !is_removable_if_unused(all_effects)) match = applicable[1]::MethodMatch method = match.method sig = match.spec_types @@ -238,10 +222,16 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), rettype = Any end add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv) - if !isempty(sv.pclimitations) # remove self, if present - delete!(sv.pclimitations, sv) - for caller in sv.callers_in_cycle - delete!(sv.pclimitations, caller) + if isa(sv, InferenceState) + # TODO (#48913) implement a proper recursion handling for irinterp: + # This works just because currently the `:terminate` condition guarantees that + # irinterp doesn't fail into unresolved cycles, but it's not a good solution. + # We should revisit this once we have a better story for handling cycles in irinterp. 
+ if !isempty(sv.pclimitations) # remove self, if present + delete!(sv.pclimitations, sv) + for caller in callers_in_cycle(sv) + delete!(sv.pclimitations, caller) + end end end return CallMeta(rettype, all_effects, info) @@ -349,7 +339,7 @@ function find_matching_methods(๐•ƒ::AbstractLattice, end """ - from_interprocedural!(๐•ƒโ‚š::AbstractLattice, rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt + from_interprocedural!(๐•ƒโ‚š::AbstractLattice, rt, sv::AbsIntState, arginfo::ArgInfo, maybecondinfo) -> newrt Converts inter-procedural return type `rt` into a local lattice element `newrt`, that is appropriate in the context of current local analysis frame `sv`, especially: @@ -368,7 +358,7 @@ In such cases `maybecondinfo` should be either of: When we deal with multiple `MethodMatch`es, it's better to precompute `maybecondinfo` by `tmerge`ing argument signature type of each method call. """ -function from_interprocedural!(๐•ƒโ‚š::AbstractLattice, @nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo)) +function from_interprocedural!(๐•ƒโ‚š::AbstractLattice, @nospecialize(rt), sv::AbsIntState, arginfo::ArgInfo, @nospecialize(maybecondinfo)) rt = collect_limitations!(rt, sv) if isa(rt, InterMustAlias) rt = from_intermustalias(rt, arginfo) @@ -407,11 +397,13 @@ function from_intermustalias(rt::InterMustAlias, arginfo::ArgInfo) return widenmustalias(rt) end -function from_interconditional(๐•ƒโ‚š::AbstractLattice, @nospecialize(typ), - sv::InferenceState, (; fargs, argtypes)::ArgInfo, @nospecialize(maybecondinfo)) - ๐•ƒ = widenlattice(๐•ƒโ‚š) - has_conditional(๐•ƒโ‚š) || return widenconditional(typ) +function from_interconditional(๐•ƒโ‚š::AbstractLattice, + typ, sv::AbsIntState, arginfo::ArgInfo, maybecondinfo) + @nospecialize typ maybecondinfo + has_conditional(๐•ƒโ‚š, sv) || return widenconditional(typ) + (; fargs, argtypes) = arginfo fargs === nothing && return widenconditional(typ) + ๐•ƒ = widenlattice(๐•ƒโ‚š) 
slot = 0 alias = nothing thentype = elsetype = Any @@ -505,7 +497,7 @@ end function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), all_effects::Effects, edges::Vector{MethodInstance}, matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype), - sv::InferenceState) + sv::AbsIntState) # don't bother to add backedges when both type and effects information are already # maximized to the top since a new method couldn't refine or widen them anyway if rettype === Any @@ -515,7 +507,8 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype) all_effects = Effects(all_effects; nonoverlayed=false) end if (# ignore the `:noinbounds` property if `:consistent`-cy is tainted already - sv.ipo_effects.consistent === ALWAYS_FALSE || all_effects.consistent === ALWAYS_FALSE || + (sv isa InferenceState && sv.ipo_effects.consistent === ALWAYS_FALSE) || + all_effects.consistent === ALWAYS_FALSE || # or this `:noinbounds` doesn't taint it !stmt_taints_inbounds_consistency(sv)) all_effects = Effects(all_effects; noinbounds=false) @@ -541,7 +534,9 @@ const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Ann const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence." const RECURSION_MSG_HARDLIMIT = "Bounded recursion detected under hardlimit. Call was widened to force convergence." 
-function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState) +function abstract_call_method(interp::AbstractInterpreter, + method::Method, @nospecialize(sig), sparams::SimpleVector, + hardlimit::Bool, si::StmtInfo, sv::AbsIntState) if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base add_remark!(interp, sv, "Refusing to infer into `depwarn`") return MethodCallResult(Any, false, false, nothing, Effects()) @@ -554,9 +549,10 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp edgecycle = edgelimited = false topmost = nothing - for infstate in InfStackUnwind(sv) - if method === infstate.linfo.def - if infstate.linfo.specTypes::Type == sig::Type + for svโ€ฒ in AbsIntStackUnwind(sv) + infmi = frame_instance(svโ€ฒ) + if method === infmi.def + if infmi.specTypes::Type == sig::Type # avoid widening when detecting self-recursion # TODO: merge call cycle and return right away if call_result_unused(si) @@ -572,8 +568,8 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp break end topmost === nothing || continue - if edge_matches_sv(interp, infstate, method, sig, sparams, hardlimit, sv) - topmost = infstate + if edge_matches_sv(interp, svโ€ฒ, method, sig, sparams, hardlimit, sv) + topmost = svโ€ฒ edgecycle = true end end @@ -585,11 +581,12 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp msig = unwrap_unionall(method.sig)::DataType spec_len = length(msig.parameters) + 1 ls = length(sigtuple.parameters) + mi = frame_instance(sv) - if method === sv.linfo.def + if method === mi.def # Under direct self-recursion, permit much greater use of reducers. 
# here we assume that complexity(specTypes) :>= complexity(sig) - comparison = sv.linfo.specTypes + comparison = mi.specTypes l_comparison = length((unwrap_unionall(comparison)::DataType).parameters) spec_len = max(spec_len, l_comparison) else @@ -603,7 +600,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp end # see if the type is actually too big (relative to the caller), and limit it if required - newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, InferenceParams(interp).tuple_complexity_limit_depth, spec_len) + newsig = limit_type_size(sig, comparison, hardlimit ? comparison : mi.specTypes, InferenceParams(interp).tuple_complexity_limit_depth, spec_len) if newsig !== sig # continue inference, but note that we've limited parameter complexity @@ -618,9 +615,16 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp return MethodCallResult(Any, true, true, nothing, Effects()) end add_remark!(interp, sv, washardlimit ? RECURSION_MSG_HARDLIMIT : RECURSION_MSG) - topmost = topmost::InferenceState - parentframe = topmost.parent - poison_callstack(sv, parentframe === nothing ? topmost : parentframe) + # TODO (#48913) implement a proper recursion handling for irinterp: + # This works just because currently the `:terminate` condition guarantees that + # irinterp doesn't fail into unresolved cycles, but it's not a good solution. + # We should revisit this once we have a better story for handling cycles in irinterp. + if isa(topmost, InferenceState) + parentframe = frame_parent(topmost) + if isa(sv, InferenceState) && isa(parentframe, InferenceState) + poison_callstack!(sv, parentframe === nothing ? 
topmost : parentframe) + end + end sig = newsig sparams = svec() edgelimited = true @@ -680,7 +684,9 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp return MethodCallResult(rt, edgecycle, edgelimited, edge, effects) end -function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState) +function edge_matches_sv(interp::AbstractInterpreter, frame::AbsIntState, + method::Method, @nospecialize(sig), sparams::SimpleVector, + hardlimit::Bool, sv::AbsIntState) # The `method_for_inference_heuristics` will expand the given method's generator if # necessary in order to retrieve this field from the generated `CodeInfo`, if it exists. # The other `CodeInfo`s we inspect will already have this field inflated, so we just @@ -688,12 +694,12 @@ function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, met world = get_world_counter(interp) callee_method2 = method_for_inference_heuristics(method, sig, sparams, world) # Union{Method, Nothing} - inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match + inf_method2 = method_for_inference_limit_heuristics(frame) # limit only if user token match inf_method2 isa Method || (inf_method2 = nothing) if callee_method2 !== inf_method2 return false end - if !hardlimit || InferenceParams(sv.interp).ignore_recursion_hardlimit + if !hardlimit || InferenceParams(interp).ignore_recursion_hardlimit # if this is a soft limit, # also inspect the parent of this edge, # to see if they are the same Method as sv @@ -702,11 +708,10 @@ function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, met # check in the cycle list first # all items in here are mutual parents of all others - if !any(p::InferenceState->matches_sv(p, sv), frame.callers_in_cycle) - let parent = frame.parent + if !any(p::AbsIntState->matches_sv(p, sv), 
callers_in_cycle(frame)) + let parent = frame_parent(frame) parent !== nothing || return false - parent = parent::InferenceState - (parent.cached || parent.parent !== nothing) || return false + (is_cached(parent) || frame_parent(parent) !== nothing) || return false matches_sv(parent, sv) || return false end end @@ -714,7 +719,7 @@ function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, met # If the method defines a recursion relation, give it a chance # to tell us that this recursion is actually ok. if isdefined(method, :recursion_relation) - if Core._apply_pure(method.recursion_relation, Any[method, callee_method2, sig, frame.linfo.specTypes]) + if Core._apply_pure(method.recursion_relation, Any[method, callee_method2, sig, frame_instance(frame).specTypes]) return false end end @@ -739,35 +744,35 @@ function method_for_inference_heuristics(method::Method, @nospecialize(sig), spa return nothing end -function matches_sv(parent::InferenceState, sv::InferenceState) - sv_method2 = sv.src.method_for_inference_limit_heuristics # limit only if user token match +function matches_sv(parent::AbsIntState, sv::AbsIntState) + sv_method2 = method_for_inference_limit_heuristics(sv) # limit only if user token match sv_method2 isa Method || (sv_method2 = nothing) - parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match + parent_method2 = method_for_inference_limit_heuristics(parent) # limit only if user token match parent_method2 isa Method || (parent_method2 = nothing) - return parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 + return frame_instance(parent).def === frame_instance(sv).def && sv_method2 === parent_method2 end -function is_edge_recursed(edge::MethodInstance, sv::InferenceState) - return any(InfStackUnwind(sv)) do infstate - return edge === infstate.linfo +function is_edge_recursed(edge::MethodInstance, caller::AbsIntState) + return any(AbsIntStackUnwind(caller)) do sv::AbsIntState + 
return edge === frame_instance(sv) end end -function is_method_recursed(method::Method, sv::InferenceState) - return any(InfStackUnwind(sv)) do infstate - return method === infstate.linfo.def +function is_method_recursed(method::Method, caller::AbsIntState) + return any(AbsIntStackUnwind(caller)) do sv::AbsIntState + return method === frame_instance(sv).def end end -function is_constprop_edge_recursed(edge::MethodInstance, sv::InferenceState) - return any(InfStackUnwind(sv)) do infstate - return edge === infstate.linfo && any(infstate.result.overridden_by_const) +function is_constprop_edge_recursed(edge::MethodInstance, caller::AbsIntState) + return any(AbsIntStackUnwind(caller)) do sv::AbsIntState + return edge === frame_instance(sv) && is_constproped(sv) end end -function is_constprop_method_recursed(method::Method, sv::InferenceState) - return any(InfStackUnwind(sv)) do infstate - return method === infstate.linfo.def && any(infstate.result.overridden_by_const) +function is_constprop_method_recursed(method::Method, caller::AbsIntState) + return any(AbsIntStackUnwind(caller)) do sv::AbsIntState + return method === frame_instance(sv).def && is_constproped(sv) end end @@ -792,7 +797,7 @@ end # - false: eligible for semi-concrete evaluation # - nothing: not eligible for either of it function concrete_eval_eligible(interp::AbstractInterpreter, - @nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState) + @nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::AbsIntState) # disable all concrete-evaluation if this function call is tainted by some overlayed # method since currently there is no direct way to execute overlayed methods if inbounds_option() === :off @@ -842,7 +847,7 @@ end function concrete_eval_call(interp::AbstractInterpreter, @nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, si::StmtInfo, - sv::InferenceState, invokecall::Union{Nothing,InvokeCall}=nothing) + sv::AbsIntState, 
invokecall::Union{Nothing,InvokeCall}=nothing) eligible = concrete_eval_eligible(interp, f, result, arginfo, sv) eligible === nothing && return false if eligible @@ -869,7 +874,7 @@ end any_conditional(argtypes::Vector{Any}) = any(@nospecialize(x)->isa(x, Conditional), argtypes) any_conditional(arginfo::ArgInfo) = any_conditional(arginfo.argtypes) -function const_prop_enabled(interp::AbstractInterpreter, sv::InferenceState, match::MethodMatch) +function const_prop_enabled(interp::AbstractInterpreter, sv::AbsIntState, match::MethodMatch) if !InferenceParams(interp).ipo_constant_propagation add_remark!(interp, sv, "[constprop] Disabled by parameter") return false @@ -893,7 +898,7 @@ struct ConstCallResults new(rt, const_result, effects, edge) end -# TODO MustAlias forwarding +# TODO implement MustAlias forwarding struct ConditionalArgtypes <: ForwardableArgtypes arginfo::ArgInfo @@ -958,9 +963,23 @@ function matching_cache_argtypes(๐•ƒ::AbstractLattice, linfo::MethodInstance, a return pick_const_args!(๐•ƒ, cache_argtypes, overridden_by_const, given_argtypes) end +# check if there is a cycle and duplicated inference of `mi` +function is_constprop_recursed(result::MethodCallResult, mi::MethodInstance, sv::AbsIntState) + result.edgecycle || return false + if result.edgelimited + return is_constprop_method_recursed(mi.def::Method, sv) + else + # if the type complexity limiting didn't decide to limit the call signature (as + # indicated by `result.edgelimited === false`), we can relax the cycle detection + # by comparing `MethodInstance`s and allow inference to propagate different + # constant elements if the recursion is finite over the lattice + return is_constprop_edge_recursed(mi, sv) + end +end + function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, match::MethodMatch, - sv::InferenceState, invokecall::Union{Nothing,InvokeCall}=nothing) + sv::AbsIntState, 
invokecall::Union{Nothing,InvokeCall}=nothing) if !const_prop_enabled(interp, sv, match) return nothing end @@ -974,19 +993,28 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, isa(res, ConstCallResults) && return res mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, si, match, sv) mi === nothing && return nothing + if is_constprop_recursed(result, mi, sv) + add_remark!(interp, sv, "[constprop] Edge cycle encountered") + return nothing + end # try semi-concrete evaluation if res::Bool && !any_conditional(arginfo) - mi_cache = WorldView(code_cache(interp), sv.world) + world = frame_world(sv) + mi_cache = WorldView(code_cache(interp), world) code = get(mi_cache, mi, nothing) if code !== nothing - ir = codeinst_to_ir(interp, code) - if isa(ir, IRCode) - irinterp = switch_to_irinterp(interp) - irsv = IRInterpretationState(irinterp, ir, mi, sv.world, arginfo.argtypes) - rt, nothrow = ir_abstract_constant_propagation(irinterp, irsv) - @assert !(rt isa Conditional || rt isa MustAlias) "invalid lattice element returned from IR interpretation" - if !isa(rt, Type) || typeintersect(rt, Bool) === Union{} - new_effects = Effects(result.effects; nothrow=nothrow) + irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world) + if irsv !== nothing + irsv.parent = sv + rt, nothrow = ir_abstract_constant_propagation(interp, irsv) + @assert !(rt isa Conditional || rt isa MustAlias) "invalid lattice element returned from irinterp" + if !(isa(rt, Type) && hasintersect(rt, Bool)) + ir = irsv.ir + # TODO (#48913) enable double inlining pass when there are any calls + # that are newly resovled by irinterp + # state = InliningState(interp) + # ir = ssa_inlining_pass!(irsv.ir, state, propagate_inbounds(irsv)) + new_effects = Effects(result.effects; nothrow) return ConstCallResults(rt, SemiConcreteResult(mi, ir, new_effects), new_effects, mi) end end @@ -997,18 +1025,8 @@ function 
abstract_call_method_with_const_args(interp::AbstractInterpreter, ๐•ƒแตข = typeinf_lattice(interp) inf_result = cache_lookup(๐•ƒแตข, mi, arginfo.argtypes, inf_cache) if inf_result === nothing - # if there might be a cycle, check to make sure we don't end up - # calling ourselves here. - if result.edgecycle && (result.edgelimited ? - is_constprop_method_recursed(match.method, sv) : - # if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`) - # we can relax the cycle detection by comparing `MethodInstance`s and allow inference to - # propagate different constant elements if the recursion is finite over the lattice - is_constprop_edge_recursed(mi, sv)) - add_remark!(interp, sv, "[constprop] Edge cycle encountered") - return nothing - end - argtypes = has_conditional(๐•ƒแตข) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes) + # fresh constant prop' + argtypes = has_conditional(๐•ƒแตข, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes) inf_result = InferenceResult(mi, argtypes, typeinf_lattice(interp)) if !any(inf_result.overridden_by_const) add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes") @@ -1026,6 +1044,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, end @assert inf_result.result !== nothing else + # found the cache for this constant prop' if inf_result.result === nothing add_remark!(interp, sv, "[constprop] Found cached constant inference in a cycle") return nothing @@ -1038,7 +1057,7 @@ end # (hopefully without doing too much work), returns `MethodInstance`, or nothing otherwise function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, - match::MethodMatch, sv::InferenceState) + match::MethodMatch, sv::AbsIntState) method = match.method force = force_const_prop(interp, f, method) force || 
const_prop_entry_heuristic(interp, result, si, sv) || return nothing @@ -1050,8 +1069,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, return nothing end all_overridden = is_all_overridden(interp, arginfo, sv) - if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, - is_nothrow(sv.ipo_effects), sv) + if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv) add_remark!(interp, sv, "[constprop] Disabled by function heuristic") return nothing end @@ -1069,7 +1087,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, return mi end -function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodCallResult, si::StmtInfo, sv::InferenceState) +function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodCallResult, si::StmtInfo, sv::AbsIntState) if call_result_unused(si) && result.edgecycle add_remark!(interp, sv, "[constprop] Disabled by entry heuristic (edgecycle with unused result)") return false @@ -1108,12 +1126,12 @@ end # determines heuristically whether if constant propagation can be worthwhile # by checking if any of given `argtypes` is "interesting" enough to be propagated -function const_prop_argument_heuristic(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState) +function const_prop_argument_heuristic(interp::AbstractInterpreter, arginfo::ArgInfo, sv::AbsIntState) ๐•ƒแตข = typeinf_lattice(interp) argtypes = arginfo.argtypes for i in 1:length(argtypes) a = argtypes[i] - if has_conditional(๐•ƒแตข) && isa(a, Conditional) && arginfo.fargs !== nothing + if has_conditional(๐•ƒแตข, sv) && isa(a, Conditional) && arginfo.fargs !== nothing is_const_prop_profitable_conditional(a, arginfo.fargs, sv) && return true else a = widenslotwrapper(a) @@ -1146,11 +1164,11 @@ function find_constrained_arg(cnd::Conditional, fargs::Vector{Any}, sv::Inferenc end # checks if all argtypes has additional information other than 
what `Type` can provide -function is_all_overridden(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState) +function is_all_overridden(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::AbsIntState) ๐•ƒแตข = typeinf_lattice(interp) for i in 1:length(argtypes) a = argtypes[i] - if has_conditional(๐•ƒแตข) && isa(a, Conditional) && fargs !== nothing + if has_conditional(๐•ƒแตข, sv) && isa(a, Conditional) && fargs !== nothing is_const_prop_profitable_conditional(a, fargs, sv) || return false else is_forwardable_argtype(๐•ƒแตข, widenslotwrapper(a)) || return false @@ -1166,8 +1184,8 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method: istopfunction(f, :setproperty!) end -function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, - nargs::Int, all_overridden::Bool, still_nothrow::Bool, _::InferenceState) +function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), + arginfo::ArgInfo, nargs::Int, all_overridden::Bool, sv::AbsIntState) argtypes = arginfo.argtypes if nargs > 1 ๐•ƒแตข = typeinf_lattice(interp) @@ -1177,6 +1195,7 @@ function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecializ if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) # For static arrays, allow the constprop if we could possibly # deduce nothrow as a result. + still_nothrow = isa(sv, InferenceState) ? is_nothrow(sv.ipo_effects) : false if !still_nothrow || ismutabletype(arrty) return false end @@ -1214,7 +1233,7 @@ end # where we would spend a lot of time, but are probably unlikely to get an improved # result anyway. 
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, - mi::MethodInstance, arginfo::ArgInfo, sv::InferenceState) + mi::MethodInstance, arginfo::ArgInfo, sv::AbsIntState) method = mi.def::Method if method.is_for_opaque_closure # Not inlining an opaque closure can be very expensive, so be generous @@ -1258,7 +1277,6 @@ end # This is only for use with `Conditional`. # In general, usage of this is wrong. -ssa_def_slot(@nospecialize(arg), sv::IRCode) = nothing function ssa_def_slot(@nospecialize(arg), sv::InferenceState) code = sv.src.code init = sv.currpc @@ -1322,7 +1340,7 @@ AbstractIterationResult(cti::Vector{Any}, info::MaybeAbstractIterationInfo) = # Union of Tuples of the same length is converted to Tuple of Unions. # returns an array of types function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(typ), - sv::Union{InferenceState, IRCode}) + sv::AbsIntState) if isa(typ, PartialStruct) widet = typ.typ if isa(widet, DataType) && widet.name === Tuple.name @@ -1392,7 +1410,7 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft) end # simulate iteration protocol on container type up to fixpoint -function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), sv::Union{InferenceState, IRCode}) +function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), sv::AbsIntState) if isa(itft, Const) iteratef = itft.val else @@ -1481,8 +1499,7 @@ end # do apply(af, fargs...), where af is a function value function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, - sv::Union{InferenceState, IRCode}, - max_methods::Int = get_max_methods(sv.mod, interp)) + sv::AbsIntState, max_methods::Int=get_max_methods(sv, interp)) itft = argtype_by_index(argtypes, 2) aft = argtype_by_index(argtypes, 3) (itft === Bottom || aft === Bottom) && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo()) 
@@ -1664,12 +1681,12 @@ end end function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs, argtypes)::ArgInfo, - sv::Union{InferenceState, IRCode}, max_methods::Int) + sv::AbsIntState, max_methods::Int) @nospecialize f la = length(argtypes) ๐•ƒแตข = typeinf_lattice(interp) โŠ‘แตข = โŠ‘(๐•ƒแตข) - if has_conditional(๐•ƒแตข) && f === Core.ifelse && fargs isa Vector{Any} && la == 4 + if has_conditional(๐•ƒแตข, sv) && f === Core.ifelse && fargs isa Vector{Any} && la == 4 cnd = argtypes[2] if isa(cnd, Conditional) newcnd = widenconditional(cnd) @@ -1708,7 +1725,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs end end end - elseif has_conditional(๐•ƒแตข) && (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any}) + elseif has_conditional(๐•ƒแตข, sv) && (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any}) # perform very limited back-propagation of type information for `is` and `isa` if f === isa a = ssa_def_slot(fargs[2], sv) @@ -1852,7 +1869,7 @@ function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{An return CallMeta(Any, EFFECTS_UNKNOWN, NoCallInfo()) end -function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo, sv::InferenceState) +function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo, sv::AbsIntState) ftโ€ฒ = argtype_by_index(argtypes, 2) ft = widenconst(ftโ€ฒ) ft === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo()) @@ -1915,7 +1932,7 @@ function invoke_rewrite(xs::Vector{Any}) return newxs end -function abstract_finalizer(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState) +function abstract_finalizer(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::AbsIntState) if length(argtypes) == 3 finalizer_argvec = Any[argtypes[2], argtypes[3]] call = abstract_call(interp, ArgInfo(nothing, finalizer_argvec), 
StmtInfo(false), sv, 1) @@ -1926,8 +1943,8 @@ end # call where the function is known exactly function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, sv::Union{InferenceState, IRCode}, - max_methods::Int = isa(sv, InferenceState) ? get_max_methods(f, sv.mod, interp) : 0) + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, + max_methods::Int = get_max_methods(f, sv, interp)) (; fargs, argtypes) = arginfo la = length(argtypes) @@ -2066,7 +2083,7 @@ end # call where the function is any lattice element function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, si::StmtInfo, - sv::Union{InferenceState, IRCode}, max_methods::Union{Int, Nothing} = isa(sv, IRCode) ? 0 : nothing) + sv::AbsIntState, max_methods::Union{Int, Nothing} = nothing) argtypes = arginfo.argtypes ft = widenslotwrapper(argtypes[1]) f = singleton_type(ft) @@ -2089,10 +2106,10 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, si::StmtIn return CallMeta(Any, Effects(), NoCallInfo()) end # non-constant function, but the number of arguments is known and the `f` is not a builtin or intrinsic - max_methods = max_methods === nothing ? get_max_methods(sv.mod, interp) : max_methods + max_methods = max_methods === nothing ? get_max_methods(sv, interp) : max_methods return abstract_call_gf_by_type(interp, nothing, arginfo, si, argtypes_to_type(argtypes), sv, max_methods) end - max_methods = max_methods === nothing ? get_max_methods(f, sv.mod, interp) : max_methods + max_methods = max_methods === nothing ? 
get_max_methods(f, sv, interp) : max_methods return abstract_call_known(interp, f, arginfo, si, sv, max_methods) end @@ -2132,10 +2149,10 @@ function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool) return unwraptv(T) end -function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::VarTable, sv::InferenceState) +function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing}, sv::AbsIntState) f = abstract_eval_value(interp, e.args[2], vtypes, sv) # rt = sp_type_rewrap(e.args[3], sv.linfo, true) - at = Any[ sp_type_rewrap(argt, sv.linfo, false) for argt in e.args[4]::SimpleVector ] + at = Any[ sp_type_rewrap(argt, frame_instance(sv), false) for argt in e.args[4]::SimpleVector ] pushfirst!(at, f) # this may be the wrong world for the call, # but some of the result is likely to be valid anyways @@ -2144,7 +2161,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V nothing end -function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable, Nothing}, sv::Union{InferenceState, IRCode}) +function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing}, sv::AbsIntState) rt = Any head = e.head if head === :static_parameter @@ -2186,23 +2203,27 @@ function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes:: return rt end -function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable, Nothing}, sv::Union{InferenceState, IRCode}) +function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable,Nothing}, sv::AbsIntState) if isa(e, QuoteNode) return Const(e.value) elseif isa(e, SSAValue) return abstract_eval_ssavalue(e, sv) elseif isa(e, SlotNumber) - vtyp = vtypes[slot_id(e)] - if vtyp.undef - merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; nothrow=false)) + if vtypes !== nothing + vtyp = 
vtypes[slot_id(e)] + if vtyp.undef + merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; nothrow=false)) + end + return vtyp.typ end - return vtyp.typ + merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; nothrow=false)) + return Any elseif isa(e, Argument) - if !isa(vtypes, Nothing) + if vtypes !== nothing return vtypes[slot_id(e)].typ else - @assert isa(sv, IRCode) - return sv.argtypes[e.n] + @assert isa(sv, IRInterpretationState) + return sv.ir.argtypes[e.n] # TODO frame_argtypes(sv)[e.n] and remove the assertion end elseif isa(e, GlobalRef) return abstract_eval_globalref(interp, e, sv) @@ -2211,7 +2232,7 @@ function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize( return Const(e) end -function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable, Nothing}, sv::Union{InferenceState, IRCode}) +function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable,Nothing}, sv::AbsIntState) if isa(e, Expr) return abstract_eval_value_expr(interp, e, vtypes, sv) else @@ -2220,7 +2241,7 @@ function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtyp end end -function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::Union{VarTable, Nothing}, sv::Union{InferenceState, IRCode}) +function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::Union{VarTable,Nothing}, sv::AbsIntState) n = length(ea) argtypes = Vector{Any}(undef, n) @inbounds for i = 1:n @@ -2239,33 +2260,39 @@ struct RTEffects RTEffects(@nospecialize(rt), effects::Effects) = new(rt, effects) end -function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable, Nothing}, - sv::Union{InferenceState, IRCode}, mi::Union{MethodInstance, Nothing})::RTEffects +function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState) + si = StmtInfo(!call_result_unused(sv, sv.currpc)) + (; rt, effects, info) = 
abstract_call(interp, arginfo, si, sv)
+    sv.stmt_info[sv.currpc] = info
+    # mark this call statement as DCE-eligible
+    # TODO better to do this in a single pass based on the `info` object at the end of abstractinterpret?
+    if is_removable_if_unused(effects)
+        add_curr_ssaflag!(sv, IR_FLAG_EFFECT_FREE)
+    else
+        sub_curr_ssaflag!(sv, IR_FLAG_EFFECT_FREE)
+    end
+    return RTEffects(rt, effects)
+end
+
+function abstract_eval_call(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing},
+                            sv::AbsIntState)
+    ea = e.args
+    argtypes = collect_argtypes(interp, ea, vtypes, sv)
+    if argtypes === nothing
+        return RTEffects(Bottom, Effects())
+    end
+    arginfo = ArgInfo(ea, argtypes)
+    return abstract_call(interp, arginfo, sv)
+end
+
+function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing},
+                                      sv::AbsIntState)
     effects = EFFECTS_UNKNOWN
     ehead = e.head
     𝕃ᵢ = typeinf_lattice(interp)
     ⊑ᵢ = ⊑(𝕃ᵢ)
     if ehead === :call
-        ea = e.args
-        argtypes = collect_argtypes(interp, ea, vtypes, sv)
-        if argtypes === nothing
-            rt = Bottom
-            effects = Effects()
-        else
-            arginfo = ArgInfo(ea, argtypes)
-            si = StmtInfo(isa(sv, IRCode) ? true : !call_result_unused(sv, sv.currpc))
-            (; rt, effects, info) = abstract_call(interp, arginfo, si, sv)
-            if isa(sv, InferenceState)
-                sv.stmt_info[sv.currpc] = info
-                # mark this call statement as DCE-elgible
-                # TODO better to do this in a single pass based on the `info` object at the end of abstractinterpret?
- if is_removable_if_unused(effects) - add_curr_ssaflag!(sv, IR_FLAG_EFFECT_FREE) - else - sub_curr_ssaflag!(sv, IR_FLAG_EFFECT_FREE) - end - end - end + (; rt, effects) = abstract_eval_call(interp, e, vtypes, sv) t = rt elseif ehead === :new t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv)) @@ -2365,9 +2392,9 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp if argtypes === nothing t = Bottom else - miโ€ฒ = isa(sv, InferenceState) ? sv.linfo : mi - t = _opaque_closure_tfunc(๐•ƒแตข, argtypes[1], argtypes[2], argtypes[3], - argtypes[4], argtypes[5:end], miโ€ฒ) + mi = frame_instance(sv) + t = opaque_closure_tfunc(๐•ƒแตข, argtypes[1], argtypes[2], argtypes[3], + argtypes[4], argtypes[5:end], mi) if isa(t, PartialOpaque) && isa(sv, InferenceState) && !call_result_unused(sv, sv.currpc) # Infer this now so that the specialization is available to # optimization. @@ -2380,7 +2407,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp end end elseif ehead === :foreigncall - (;rt, effects) = abstract_eval_foreigncall(interp, e, vtypes, sv, mi) + (; rt, effects) = abstract_eval_foreigncall(interp, e, vtypes, sv) t = rt if isa(sv, InferenceState) # mark this call statement as DCE-elgible @@ -2411,7 +2438,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp sym = e.args[1] t = Bool effects = EFFECTS_TOTAL - if isa(sym, SlotNumber) + if isa(sym, SlotNumber) && vtypes !== nothing vtyp = vtypes[slot_id(sym)] if vtyp.typ === Bottom t = Const(false) # never assigned previously @@ -2419,7 +2446,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp t = Const(true) # definitely assigned previously end elseif isa(sym, Symbol) - if isdefined(sv.mod, sym) + if isdefined(frame_module(sv), sym) t = Const(true) elseif InferenceParams(interp).assume_bindings_static t = Const(false) @@ -2465,10 +2492,10 @@ function 
refine_partial_type(@nospecialize t) return t end -function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable, Nothing}, sv::Union{InferenceState, IRCode}, mi::Union{MethodInstance, Nothing}=nothing) +function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing}, sv::AbsIntState) abstract_eval_value(interp, e.args[1], vtypes, sv) - miโ€ฒ = isa(sv, InferenceState) ? sv.linfo : mi - t = sp_type_rewrap(e.args[2], miโ€ฒ, true) + mi = frame_instance(sv) + t = sp_type_rewrap(e.args[2], mi, true) for i = 3:length(e.args) if abstract_eval_value(interp, e.args[i], vtypes, sv) === Bottom return RTEffects(Bottom, EFFECTS_THROWS) @@ -2493,7 +2520,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes: return RTEffects(t, effects) end -function abstract_eval_phi(interp::AbstractInterpreter, phi::PhiNode, vtypes::Union{VarTable, Nothing}, sv::Union{InferenceState, IRCode}) +function abstract_eval_phi(interp::AbstractInterpreter, phi::PhiNode, vtypes::Union{VarTable,Nothing}, sv::AbsIntState) rt = Union{} for i in 1:length(phi.values) isassigned(phi.values, i) || continue @@ -2503,8 +2530,8 @@ function abstract_eval_phi(interp::AbstractInterpreter, phi::PhiNode, vtypes::Un return rt end -function stmt_taints_inbounds_consistency(sv::InferenceState) - sv.src.propagate_inbounds && return true +function stmt_taints_inbounds_consistency(sv::AbsIntState) + propagate_inbounds(sv) && return true return (get_curr_ssaflag(sv) & IR_FLAG_INBOUNDS) != 0 end @@ -2515,9 +2542,9 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), end return abstract_eval_special_value(interp, e, vtypes, sv) end - (;rt, effects) = abstract_eval_statement_expr(interp, e, vtypes, sv, nothing) + (; rt, effects) = abstract_eval_statement_expr(interp, e, vtypes, sv) if !effects.noinbounds - if !sv.src.propagate_inbounds + if !propagate_inbounds(sv) # The callee read our inbounds 
flag, but unless we propagate inbounds, # we ourselves don't read our parent's inbounds. effects = Effects(effects; noinbounds=true) @@ -2555,7 +2582,7 @@ function abstract_eval_globalref(g::GlobalRef) end abstract_eval_global(M::Module, s::Symbol) = abstract_eval_globalref(GlobalRef(M, s)) -function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, frame::Union{InferenceState, IRCode}) +function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::AbsIntState) rt = abstract_eval_globalref(g) consistent = inaccessiblememonly = ALWAYS_FALSE nothrow = false @@ -2573,7 +2600,7 @@ function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, fram consistent = inaccessiblememonly = ALWAYS_TRUE rt = Union{} end - merge_effects!(interp, frame, Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly)) + merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly)) return rt end @@ -2586,7 +2613,7 @@ function handle_global_assignment!(interp::AbstractInterpreter, frame::Inference end abstract_eval_ssavalue(s::SSAValue, sv::InferenceState) = abstract_eval_ssavalue(s, sv.ssavaluetypes) -abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any}) + function abstract_eval_ssavalue(s::SSAValue, ssavaluetypes::Vector{Any}) typ = ssavaluetypes[s.id] if typ === NOT_FOUND diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 9229f54f143f2..0a1b852b052f9 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -152,7 +152,6 @@ include("compiler/ssair/domtree.jl") include("compiler/ssair/ir.jl") include("compiler/abstractlattice.jl") - include("compiler/inferenceresult.jl") include("compiler/inferencestate.jl") diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index c079553fca06a..790af05eebada 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ 
-86,7 +86,7 @@ function va_process_argtypes(@nospecialize(va_handler!), ๐•ƒ::AbstractLattice, nargs = Int(def.nargs) if isva || isvarargtype(given_argtypes[end]) isva_given_argtypes = Vector{Any}(undef, nargs) - for i = 1:(nargs - isva) + for i = 1:(nargs-isva) isva_given_argtypes[i] = argtype_by_index(given_argtypes, i) end if isva @@ -110,10 +110,8 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe isva = !toplevel && method.isva linfo_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...] nargs::Int = toplevel ? 0 : method.nargs - if !withfirst - # For opaque closure, the closure environment is processed elsewhere - nargs -= 1 - end + # For opaque closure, the closure environment is processed elsewhere + withfirst || (nargs -= 1) cache_argtypes = Vector{Any}(undef, nargs) # First, if we're dealing with a varargs method, then we set the last element of `args` # to the appropriate `Tuple` type or `PartialStruct` instance. diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 48d25243417fe..0e0409f755a0b 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -1,14 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license -""" - const VarTable = Vector{VarState} - -The extended lattice that maps local variables to inferred type represented as `AbstractLattice`. -Each index corresponds to the `id` of `SlotNumber` which identifies each local variable. -Note that `InferenceState` will maintain multiple `VarTable`s at each SSA statement -to enable flow-sensitive analysis. 
-"""
-const VarTable = Vector{VarState}
+# data structures
+# ===============
 
 mutable struct BitSetBoundedMinPrioritySet <: AbstractSet{Int}
     elems::BitSet
@@ -78,6 +71,130 @@ function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
     end
 end
 
+mutable struct TwoPhaseVectorView <: AbstractVector{Int}
+    const data::Vector{Int}
+    count::Int
+    const range::UnitRange{Int}
+end
+size(tpvv::TwoPhaseVectorView) = (tpvv.count,)
+function getindex(tpvv::TwoPhaseVectorView, i::Int)
+    checkbounds(tpvv, i)
+    @inbounds tpvv.data[first(tpvv.range) + i - 1]
+end
+function push!(tpvv::TwoPhaseVectorView, v::Int)
+    tpvv.count += 1
+    tpvv.data[first(tpvv.range) + tpvv.count - 1] = v
+    return nothing
+end
+
+"""
+    mutable struct TwoPhaseDefUseMap
+
+This struct is intended as a memory- and GC-pressure-efficient mechanism
+for incrementally computing def-use maps. The idea is that the def-use map
+is constructed in two passes over the IR. In the first, we simply count
+the number of uses, computing the number of uses for each def as well as the
+total number of uses. In the second pass, we actually fill in the def-use
+information.
+
+The idea is that either of these two phases can be combined with other useful
+work that needs to scan the instruction stream anyway, while avoiding the
+significant allocation pressure of e.g. allocating an array for every SSA value
+or attempting to dynamically move things around as new uses are discovered.
+
+The def-use map is presented as a vector of vectors. For every def, indexing
+into the map will return a vector of uses.
+""" +mutable struct TwoPhaseDefUseMap <: AbstractVector{TwoPhaseVectorView} + ssa_uses::Vector{Int} + data::Vector{Int} + complete::Bool +end + +function complete!(tpdum::TwoPhaseDefUseMap) + cumsum = 0 + for i = 1:length(tpdum.ssa_uses) + this_val = cumsum + 1 + cumsum += tpdum.ssa_uses[i] + tpdum.ssa_uses[i] = this_val + end + resize!(tpdum.data, cumsum) + fill!(tpdum.data, 0) + tpdum.complete = true +end + +function TwoPhaseDefUseMap(nssas::Int) + ssa_uses = zeros(Int, nssas) + data = Int[] + complete = false + return TwoPhaseDefUseMap(ssa_uses, data, complete) +end + +function count!(tpdum::TwoPhaseDefUseMap, arg::SSAValue) + @assert !tpdum.complete + tpdum.ssa_uses[arg.id] += 1 +end + +function kill_def_use!(tpdum::TwoPhaseDefUseMap, def::Int, use::Int) + if !tpdum.complete + tpdum.ssa_uses[def] -= 1 + else + range = tpdum.ssa_uses[def]:(def == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[def + 1] - 1)) + # TODO: Sorted + useidx = findfirst(idx->tpdum.data[idx] == use, range) + @assert useidx !== nothing + idx = range[useidx] + while idx < lastindex(range) + ndata = tpdum.data[idx+1] + ndata == 0 && break + tpdum.data[idx] = ndata + end + tpdum.data[idx + 1] = 0 + end +end +kill_def_use!(tpdum::TwoPhaseDefUseMap, def::SSAValue, use::Int) = + kill_def_use!(tpdum, def.id, use) + +function getindex(tpdum::TwoPhaseDefUseMap, idx::Int) + @assert tpdum.complete + range = tpdum.ssa_uses[idx]:(idx == length(tpdum.ssa_uses) ? 
length(tpdum.data) : (tpdum.ssa_uses[idx + 1] - 1)) + # TODO: Make logarithmic + nelems = 0 + for i in range + tpdum.data[i] == 0 && break + nelems += 1 + end + return TwoPhaseVectorView(tpdum.data, nelems, range) +end + +mutable struct LazyGenericDomtree{IsPostDom} + ir::IRCode + domtree::GenericDomTree{IsPostDom} + LazyGenericDomtree{IsPostDom}(ir::IRCode) where {IsPostDom} = new{IsPostDom}(ir) +end +function get!(x::LazyGenericDomtree{IsPostDom}) where {IsPostDom} + isdefined(x, :domtree) && return x.domtree + return @timeit "domtree 2" x.domtree = IsPostDom ? + construct_postdomtree(x.ir.cfg.blocks) : + construct_domtree(x.ir.cfg.blocks) +end + +const LazyDomtree = LazyGenericDomtree{false} +const LazyPostDomtree = LazyGenericDomtree{true} + +# InferenceState +# ============== + +""" + const VarTable = Vector{VarState} + +The extended lattice that maps local variables to inferred type represented as `AbstractLattice`. +Each index corresponds to the `id` of `SlotNumber` which identifies each local variable. +Note that `InferenceState` will maintain multiple `VarTable`s at each SSA statement +to enable flow-sensitive analysis. 
+""" +const VarTable = Vector{VarState} + mutable struct InferenceState #= information about this method instance =# linfo::MethodInstance @@ -87,6 +204,7 @@ mutable struct InferenceState slottypes::Vector{Any} src::CodeInfo cfg::CFG + method_info::MethodInfo #= intermediate states for local abstract interpretation =# currbb::Int @@ -106,7 +224,7 @@ mutable struct InferenceState cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller callers_in_cycle::Vector{InferenceState} dont_work_on_me::Bool - parent::Union{Nothing, InferenceState} + parent # ::Union{Nothing,AbsIntState} #= results =# result::InferenceResult # remember where to put the result @@ -135,6 +253,7 @@ mutable struct InferenceState sptypes = sptypes_from_meth_instance(linfo) code = src.code::Vector{Any} cfg = compute_basic_blocks(code) + method_info = MethodInfo(src) currbb = currpc = 1 ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1) @@ -183,7 +302,7 @@ mutable struct InferenceState cache !== :no && push!(get_inference_cache(interp), result) return new( - linfo, world, mod, sptypes, slottypes, src, cfg, + linfo, world, mod, sptypes, slottypes, src, cfg, method_info, currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info, pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, result, valid_worlds, bestguess, ipo_effects, @@ -195,43 +314,6 @@ end is_inferred(sv::InferenceState) = is_inferred(sv.result) is_inferred(result::InferenceResult) = result.result !== nothing -function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects) - caller.ipo_effects = merge_effects(caller.ipo_effects, effects) -end - -merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) = - merge_effects!(interp, caller, Effects(callee)) -merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing - 
-is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect) -function is_effect_overridden(linfo::MethodInstance, effect::Symbol) - def = linfo.def - return isa(def, Method) && is_effect_overridden(def, effect) -end -is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect) -is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect) - -add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return - -struct InferenceLoopState - sig - rt - effects::Effects - function InferenceLoopState(@nospecialize(sig), @nospecialize(rt), effects::Effects) - new(sig, rt, effects) - end -end - -function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode}) - return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig) -end -function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode}) - return state.rt === Any && !is_foldable(state.effects) -end -function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode}) - return state.rt === Any -end - was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND function compute_trycatch(code::Vector{Any}, ip::BitSet) @@ -341,29 +423,6 @@ function should_insert_coverage(mod::Module, src::CodeInfo) return false end -""" - Iterate through all callers of the given InferenceState in the abstract - interpretation stack (including the given InferenceState itself), vising - children before their parents (i.e. ascending the tree from the given - InferenceState). Note that cycles may be visited in any order. 
-""" -struct InfStackUnwind - inf::InferenceState -end -iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0)) -function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int}) - # iterate through the cycle before walking to the parent - if cyclei < length(infstate.callers_in_cycle) - cyclei += 1 - infstate = infstate.callers_in_cycle[cyclei] - else - cyclei = 0 - infstate = infstate.parent - end - infstate === nothing && return nothing - (infstate::InferenceState, (infstate, cyclei)) -end - function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter) # prepare an InferenceState object for inferring lambda world = get_world_counter(interp) @@ -511,14 +570,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance) return sptypes end -_topmod(sv::InferenceState) = _topmod(sv.mod) - -# work towards converging the valid age range for sv -function update_valid_age!(sv::InferenceState, valid_worlds::WorldRange) - valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds) - @assert(sv.world in valid_worlds, "invalid age range update") - return valid_worlds -end +_topmod(sv::InferenceState) = _topmod(frame_module(sv)) function record_ssa_assign!(๐•ƒแตข::AbstractLattice, ssa_id::Int, @nospecialize(new), frame::InferenceState) ssavaluetypes = frame.ssavaluetypes @@ -552,44 +604,16 @@ function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, curr return frame end -# temporarily accumulate our edges to later add as backedges in the callee -function add_backedge!(caller::InferenceState, mi::MethodInstance) - edges = get_stmt_edges!(caller) - if edges !== nothing - push!(edges, mi) - end - return nothing -end - -function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance) - edges = get_stmt_edges!(caller) - if edges !== nothing - push!(edges, invokesig, mi) - end - return nothing -end - -# used to temporarily accumulate our no method errors to 
later add as backedges in the callee method table -function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ)) - edges = get_stmt_edges!(caller) - if edges !== nothing - push!(edges, mt, typ) - end - return nothing -end - -function get_stmt_edges!(caller::InferenceState) - if !isa(caller.linfo.def, Method) - return nothing # don't add backedges to toplevel exprs - end - edges = caller.stmt_edges[caller.currpc] +function get_stmt_edges!(caller::InferenceState, currpc::Int=caller.currpc) + stmt_edges = caller.stmt_edges + edges = stmt_edges[currpc] if edges === nothing - edges = caller.stmt_edges[caller.currpc] = [] + edges = stmt_edges[currpc] = [] end return edges end -function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc) +function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc) edges = frame.stmt_edges[currpc] edges === nothing || empty!(edges) return nothing @@ -608,10 +632,6 @@ function print_callstack(sv::InferenceState) end end -get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc] -add_curr_ssaflag!(sv::InferenceState, flag::UInt8) = sv.src.ssaflags[sv.currpc] |= flag -sub_curr_ssaflag!(sv::InferenceState, flag::UInt8) = sv.src.ssaflags[sv.currpc] &= ~flag - function narguments(sv::InferenceState, include_va::Bool=true) def = sv.linfo.def nargs = length(sv.result.argtypes) @@ -620,3 +640,223 @@ function narguments(sv::InferenceState, include_va::Bool=true) end return nargs end + +# IRInterpretationState +# ===================== + +# TODO add `result::InferenceResult` and put the irinterp result into the inference cache? 
+mutable struct IRInterpretationState + const method_info::MethodInfo + const ir::IRCode + const mi::MethodInstance + const world::UInt + curridx::Int + const argtypes_refined::Vector{Bool} + const sptypes::Vector{VarState} + const tpdum::TwoPhaseDefUseMap + const ssa_refined::BitSet + const lazydomtree::LazyDomtree + valid_worlds::WorldRange + const edges::Vector{Any} + parent # ::Union{Nothing,AbsIntState} + + function IRInterpretationState(interp::AbstractInterpreter, + method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any}, + world::UInt, min_world::UInt, max_world::UInt) + curridx = 1 + given_argtypes = Vector{Any}(undef, length(argtypes)) + for i = 1:length(given_argtypes) + given_argtypes[i] = widenslotwrapper(argtypes[i]) + end + given_argtypes = va_process_argtypes(optimizer_lattice(interp), given_argtypes, mi) + argtypes_refined = Bool[!โŠ‘(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i]) + for i = 1:length(given_argtypes)] + empty!(ir.argtypes) + append!(ir.argtypes, given_argtypes) + tpdum = TwoPhaseDefUseMap(length(ir.stmts)) + ssa_refined = BitSet() + lazydomtree = LazyDomtree(ir) + valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? 
get_world_counter() : max_world) + edges = Any[] + parent = nothing + return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum, + ssa_refined, lazydomtree, valid_worlds, edges, parent) + end +end + +function IRInterpretationState(interp::AbstractInterpreter, + code::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt) + @assert code.def === mi + src = @atomic :monotonic code.inferred + if isa(src, Vector{UInt8}) + src = ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), mi.def, C_NULL, src)::CodeInfo + else + isa(src, CodeInfo) || return nothing + end + method_info = MethodInfo(src) + ir = inflate_ir(src, mi) + return IRInterpretationState(interp, method_info, ir, mi, argtypes, world, + src.min_world, src.max_world) +end + +# AbsIntState +# =========== + +const AbsIntState = Union{InferenceState,IRInterpretationState} + +frame_instance(sv::InferenceState) = sv.linfo +frame_instance(sv::IRInterpretationState) = sv.mi + +function frame_module(sv::AbsIntState) + mi = frame_instance(sv) + def = mi.def + isa(def, Module) && return def + return def.module +end + +frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState} +frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState} + +is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const) +is_constproped(::IRInterpretationState) = true + +is_cached(sv::InferenceState) = sv.cached +is_cached(::IRInterpretationState) = false + +method_info(sv::InferenceState) = sv.method_info +method_info(sv::IRInterpretationState) = sv.method_info + +propagate_inbounds(sv::AbsIntState) = method_info(sv).propagate_inbounds +method_for_inference_limit_heuristics(sv::AbsIntState) = method_info(sv).method_for_inference_limit_heuristics + +frame_world(sv::InferenceState) = sv.world +frame_world(sv::IRInterpretationState) = sv.world + +callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle +callers_in_cycle(sv::IRInterpretationState) = () + 
+is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect) +function is_effect_overridden(linfo::MethodInstance, effect::Symbol) + def = linfo.def + return isa(def, Method) && is_effect_overridden(def, effect) +end +is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect) +is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect) + +has_conditional(๐•ƒ::AbstractLattice, ::InferenceState) = has_conditional(๐•ƒ) +has_conditional(::AbstractLattice, ::IRInterpretationState) = false + +# work towards converging the valid age range for sv +function update_valid_age!(sv::AbsIntState, valid_worlds::WorldRange) + valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds) + @assert sv.world in valid_worlds "invalid age range update" + return valid_worlds +end + +""" + AbsIntStackUnwind(sv::AbsIntState) + +Iterate through all callers of the given `AbsIntState` in the abstract interpretation stack +(including the given `AbsIntState` itself), visiting children before their parents (i.e. +ascending the tree from the given `AbsIntState`). +Note that cycles may be visited in any order. 
+""" +struct AbsIntStackUnwind + sv::AbsIntState +end +iterate(unw::AbsIntStackUnwind) = (unw.sv, (unw.sv, 0)) +function iterate(unw::AbsIntStackUnwind, (sv, cyclei)::Tuple{AbsIntState, Int}) + # iterate through the cycle before walking to the parent + if cyclei < length(callers_in_cycle(sv)) + cyclei += 1 + parent = callers_in_cycle(sv)[cyclei] + else + cyclei = 0 + parent = frame_parent(sv) + end + parent === nothing && return nothing + return (parent, (parent, cyclei)) +end + +# temporarily accumulate our edges to later add as backedges in the callee +function add_backedge!(caller::InferenceState, mi::MethodInstance) + isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance + return push!(get_stmt_edges!(caller), mi) +end +function add_backedge!(irsv::IRInterpretationState, mi::MethodInstance) + return push!(irsv.edges, mi) +end + +function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance) + isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance + return push!(get_stmt_edges!(caller), invokesig, mi) +end +function add_invoke_backedge!(irsv::IRInterpretationState, @nospecialize(invokesig::Type), mi::MethodInstance) + return push!(irsv.edges, invokesig, mi) +end + +# used to temporarily accumulate our no method errors to later add as backedges in the callee method table +function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ)) + isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance + return push!(get_stmt_edges!(caller), mt, typ) +end +function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospecialize(typ)) + return push!(irsv.edges, mt, typ) +end + +get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc] +get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag] + +add_curr_ssaflag!(sv::InferenceState, 
flag::UInt8) = sv.src.ssaflags[sv.currpc] |= flag +add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx][:flag] |= flag + +sub_curr_ssaflag!(sv::InferenceState, flag::UInt8) = sv.src.ssaflags[sv.currpc] &= ~flag +sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx][:flag] &= ~flag + +merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects) = + caller.ipo_effects = merge_effects(caller.ipo_effects, effects) +merge_effects!(::AbstractInterpreter, ::IRInterpretationState, ::Effects) = return + +struct InferenceLoopState + sig + rt + effects::Effects + function InferenceLoopState(@nospecialize(sig), @nospecialize(rt), effects::Effects) + new(sig, rt, effects) + end +end + +bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState) = + sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig) +bail_out_toplevel_call(::AbstractInterpreter, ::InferenceLoopState, ::IRInterpretationState) = false + +bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState) = + state.rt === Any && !is_foldable(state.effects) +bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::IRInterpretationState) = + state.rt === Any && !is_foldable(state.effects) + +bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState) = + state.rt === Any +bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::IRInterpretationState) = + state.rt === Any + +function should_infer_this_call(interp::AbstractInterpreter, sv::InferenceState) + if InferenceParams(interp).unoptimize_throw_blocks + # Disable inference of calls in throw blocks, since we're unlikely to + # need their types. There is one exception however: If up until now, the + # function has not seen any side effects, we would like to make sure there + # aren't any in the throw block either to enable other optimizations. 
+ if is_stmt_throw_block(get_curr_ssaflag(sv)) + should_infer_for_effects(sv) || return false + end + end + return true +end +function should_infer_for_effects(sv::InferenceState) + effects = sv.ipo_effects + return is_terminates(effects) && is_effect_free(effects) +end +should_infer_this_call(::AbstractInterpreter, ::IRInterpretationState) = true + +add_remark!(::AbstractInterpreter, ::InferenceState, remark) = return +add_remark!(::AbstractInterpreter, ::IRInterpretationState, remark) = return diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 1b5f45fac3e2e..84817b1bf6531 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -125,9 +125,9 @@ struct InliningState{Interp<:AbstractInterpreter} world::UInt interp::Interp end -function InliningState(frame::InferenceState, interp::AbstractInterpreter) - et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds) - return InliningState(et, frame.world, interp) +function InliningState(sv::InferenceState, interp::AbstractInterpreter) + et = EdgeTracker(sv.stmt_edges[1]::Vector{Any}, sv.valid_worlds) + return InliningState(et, sv.world, interp) end function InliningState(interp::AbstractInterpreter) return InliningState(nothing, get_world_counter(interp), interp) @@ -150,12 +150,12 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter} cfg::Union{Nothing,CFG} insert_coverage::Bool end -function OptimizationState(frame::InferenceState, interp::AbstractInterpreter, +function OptimizationState(sv::InferenceState, interp::AbstractInterpreter, recompute_cfg::Bool=true) - inlining = InliningState(frame, interp) - cfg = recompute_cfg ? nothing : frame.cfg - return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod, - frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage) + inlining = InliningState(sv, interp) + cfg = recompute_cfg ? 
nothing : sv.cfg + return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod, + sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage) end function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::AbstractInterpreter) # prepare src for running optimization passes if it isn't already @@ -389,9 +389,9 @@ function argextype( return Const(x) end end +abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any}) abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s] - """ finish(interp::AbstractInterpreter, opt::OptimizationState, ir::IRCode, caller::InferenceResult) diff --git a/base/compiler/ssair/irinterp.jl b/base/compiler/ssair/irinterp.jl index 49c2f54278545..d171cceb842e9 100644 --- a/base/compiler/ssair/irinterp.jl +++ b/base/compiler/ssair/irinterp.jl @@ -1,189 +1,73 @@ -mutable struct TwoPhaseVectorView <: AbstractVector{Int} - const data::Vector{Int} - count::Int - const range::UnitRange{Int} -end -size(tpvv::TwoPhaseVectorView) = (tpvv.count,) -function getindex(tpvv::TwoPhaseVectorView, i::Int) - checkbounds(tpvv, i) - @inbounds tpvv.data[first(tpvv.range) + i - 1] -end -function push!(tpvv::TwoPhaseVectorView, v::Int) - tpvv.count += 1 - tpvv.data[first(tpvv.range) + tpvv.count - 1] = v - return nothing -end - -""" - mutable struct TwoPhaseDefUseMap - -This struct is intended as a memory- and GC-pressure-efficient mechanism -for incrementally computing def-use maps. The idea is that the def-use map -is constructed into two passes over the IR. In the first, we simply count the -the number of uses, computing the number of uses for each def as well as the -total number of uses. In the second pass, we actually fill in the def-use -information. - -The idea is that either of these two phases can be combined with other useful -work that needs to scan the instruction stream anyway, while avoiding the -significant allocation pressure of e.g. 
allocating an array for every SSA value -or attempting to dynamically move things around as new uses are discovered. - -The def-use map is presented as a vector of vectors. For every def, indexing -into the map will return a vector of uses. -""" -mutable struct TwoPhaseDefUseMap <: AbstractVector{TwoPhaseVectorView} - ssa_uses::Vector{Int} - data::Vector{Int} - complete::Bool -end - -function complete!(tpdum::TwoPhaseDefUseMap) - cumsum = 0 - for i = 1:length(tpdum.ssa_uses) - this_val = cumsum + 1 - cumsum += tpdum.ssa_uses[i] - tpdum.ssa_uses[i] = this_val - end - resize!(tpdum.data, cumsum) - fill!(tpdum.data, 0) - tpdum.complete = true -end - -function TwoPhaseDefUseMap(nssas::Int) - ssa_uses = zeros(Int, nssas) - data = Int[] - complete = false - return TwoPhaseDefUseMap(ssa_uses, data, complete) -end - -function count!(tpdum::TwoPhaseDefUseMap, arg::SSAValue) - @assert !tpdum.complete - tpdum.ssa_uses[arg.id] += 1 -end - -function kill_def_use!(tpdum::TwoPhaseDefUseMap, def::Int, use::Int) - if !tpdum.complete - tpdum.ssa_uses[def] -= 1 - else - range = tpdum.ssa_uses[def]:(def == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[def + 1] - 1)) - # TODO: Sorted - useidx = findfirst(idx->tpdum.data[idx] == use, range) - @assert useidx !== nothing - idx = range[useidx] - while idx < lastindex(range) - ndata = tpdum.data[idx+1] - ndata == 0 && break - tpdum.data[idx] = ndata - end - tpdum.data[idx + 1] = 0 - end -end -kill_def_use!(tpdum::TwoPhaseDefUseMap, def::SSAValue, use::Int) = - kill_def_use!(tpdum, def.id, use) - -function getindex(tpdum::TwoPhaseDefUseMap, idx::Int) - @assert tpdum.complete - range = tpdum.ssa_uses[idx]:(idx == length(tpdum.ssa_uses) ? 
length(tpdum.data) : (tpdum.ssa_uses[idx + 1] - 1)) - # TODO: Make logarithmic - nelems = 0 - for i in range - tpdum.data[i] == 0 && break - nelems += 1 - end - return TwoPhaseVectorView(tpdum.data, nelems, range) -end - -struct IRInterpretationState - ir::IRCode - mi::MethodInstance - world::UInt - argtypes_refined::Vector{Bool} - tpdum::TwoPhaseDefUseMap - ssa_refined::BitSet - lazydomtree::LazyDomtree - function IRInterpretationState(interp::AbstractInterpreter, - ir::IRCode, mi::MethodInstance, world::UInt, argtypes::Vector{Any}) - argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi) - for i = 1:length(argtypes) - argtypes[i] = widenslotwrapper(argtypes[i]) - end - argtypes_refined = Bool[!โŠ‘(optimizer_lattice(interp), ir.argtypes[i], argtypes[i]) for i = 1:length(argtypes)] - empty!(ir.argtypes) - append!(ir.argtypes, argtypes) - tpdum = TwoPhaseDefUseMap(length(ir.stmts)) - ssa_refined = BitSet() - lazydomtree = LazyDomtree(ir) - return new(ir, mi, world, argtypes_refined, tpdum, ssa_refined, lazydomtree) - end -end - -function codeinst_to_ir(interp::AbstractInterpreter, code::CodeInstance) - src = @atomic :monotonic code.inferred - mi = code.def - if isa(src, Vector{UInt8}) - src = ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), mi.def, C_NULL, src)::CodeInfo - else - isa(src, CodeInfo) || return nothing - end - return inflate_ir(src, mi) -end +# This file is a part of Julia. 
License is MIT: https://julialang.org/license +# TODO (#48913) remove this overload to enable interprocedural call inference from irinterp function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), - arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), - sv::IRCode, max_methods::Int) + arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), + sv::IRInterpretationState, max_methods::Int) return CallMeta(Any, Effects(), NoCallInfo()) end -function collect_limitations!(@nospecialize(typ), ::IRCode) - @assert !isa(typ, LimitedAccuracy) "semi-concrete eval on recursive call graph" +function collect_limitations!(@nospecialize(typ), ::IRInterpretationState) + @assert !isa(typ, LimitedAccuracy) "irinterp is unable to handle heavy recursion" return typ end function concrete_eval_invoke(interp::AbstractInterpreter, inst::Expr, mi::MethodInstance, irsv::IRInterpretationState) - mi_cache = WorldView(code_cache(interp), irsv.world) + world = frame_world(irsv) + mi_cache = WorldView(code_cache(interp), world) code = get(mi_cache, mi, nothing) - if code === nothing - return Pair{Any, Bool}(nothing, false) - end - argtypes = collect_argtypes(interp, inst.args[2:end], nothing, irsv.ir) - argtypes === nothing && return Pair{Any, Bool}(Union{}, false) + code === nothing && return Pair{Any,Bool}(nothing, false) + argtypes = collect_argtypes(interp, inst.args[2:end], nothing, irsv) + argtypes === nothing && return Pair{Any,Bool}(Bottom, false) effects = decode_effects(code.ipo_purity_bits) if is_foldable(effects) && is_all_const_arg(argtypes, #=start=#1) args = collect_const_args(argtypes, #=start=#1) - world = get_world_counter(interp) - value = try - Core._call_in_world_total(world, args...) - catch - return Pair{Any, Bool}(Union{}, false) + value = let world = get_world_counter(interp) + try + Core._call_in_world_total(world, args...) 
+ catch + return Pair{Any,Bool}(Bottom, false) + end end - return Pair{Any, Bool}(Const(value), true) + return Pair{Any,Bool}(Const(value), true) else - irโ€ฒ = codeinst_to_ir(interp, code) - if irโ€ฒ !== nothing - irsvโ€ฒ = IRInterpretationState(interp, irโ€ฒ, mi, irsv.world, argtypes) - return _ir_abstract_constant_propagation(interp, irsvโ€ฒ) + if is_constprop_edge_recursed(mi, irsv) + return Pair{Any,Bool}(nothing, is_nothrow(effects)) + end + newirsv = IRInterpretationState(interp, code, mi, argtypes, world) + if newirsv !== nothing + newirsv.parent = irsv + return _ir_abstract_constant_propagation(interp, newirsv) end + return Pair{Any,Bool}(nothing, is_nothrow(effects)) end - return Pair{Any, Bool}(nothing, is_nothrow(effects)) end +abstract_eval_ssavalue(s::SSAValue, sv::IRInterpretationState) = abstract_eval_ssavalue(s, sv.ir) + function abstract_eval_phi_stmt(interp::AbstractInterpreter, phi::PhiNode, ::Int, irsv::IRInterpretationState) - return abstract_eval_phi(interp, phi, nothing, irsv.ir) + return abstract_eval_phi(interp, phi, nothing, irsv) end function propagate_control_effects!(interp::AbstractInterpreter, idx::Int, stmt::GotoIfNot, - irsv::IRInterpretationState, reprocess::Union{Nothing, BitSet, BitSetBoundedMinPrioritySet}) + irsv::IRInterpretationState, extra_reprocess::Union{Nothing,BitSet,BitSetBoundedMinPrioritySet}) # Nothing to do for most abstract interpreters, but if the abstract # interpreter has control-dependent lattice effects, it can override # this method. return false end -function reprocess_instruction!(interp::AbstractInterpreter, - idx::Int, bb::Union{Int, Nothing}, @nospecialize(inst), @nospecialize(typ), - irsv::IRInterpretationState, reprocess::Union{Nothing, BitSet, BitSetBoundedMinPrioritySet}) +function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState) + si = StmtInfo(true) # TODO better job here? 
+ (; rt, effects, info) = abstract_call(interp, arginfo, si, irsv) + irsv.ir.stmts[irsv.curridx][:info] = info + return RTEffects(rt, effects) +end + +function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union{Int,Nothing}, + @nospecialize(inst), @nospecialize(typ), irsv::IRInterpretationState, + extra_reprocess::Union{Nothing,BitSet,BitSetBoundedMinPrioritySet}) ir = irsv.ir if isa(inst, GotoIfNot) cond = inst.cond @@ -192,22 +76,22 @@ function reprocess_instruction!(interp::AbstractInterpreter, function update_phi!(from::Int, to::Int) if length(ir.cfg.blocks[to].preds) == 0 # Kill the entire block - for idx in ir.cfg.blocks[to].stmts - ir.stmts[idx][:inst] = nothing - ir.stmts[idx][:type] = Union{} - ir.stmts[idx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW + for bidx = ir.cfg.blocks[to].stmts + ir.stmts[bidx][:inst] = nothing + ir.stmts[bidx][:type] = Bottom + ir.stmts[bidx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW end return end - for idx in ir.cfg.blocks[to].stmts - stmt = ir.stmts[idx][:inst] - isa(stmt, Nothing) && continue # allowed between `PhiNode`s - isa(stmt, PhiNode) || break - for (i, edge) in enumerate(stmt.edges) + for sidx = ir.cfg.blocks[to].stmts + sinst = ir.stmts[sidx][:inst] + isa(sinst, Nothing) && continue # allowed between `PhiNode`s + isa(sinst, PhiNode) || break + for (eidx, edge) in enumerate(sinst.edges) if edge == from - deleteat!(stmt.edges, i) - deleteat!(stmt.values, i) - push!(irsv.ssa_refined, idx) + deleteat!(sinst.edges, eidx) + deleteat!(sinst.values, eidx) + push!(irsv.ssa_refined, sidx) break end end @@ -230,27 +114,24 @@ function reprocess_instruction!(interp::AbstractInterpreter, end return true end - return propagate_control_effects!(interp, idx, inst, irsv, reprocess) + return propagate_control_effects!(interp, idx, inst, irsv, extra_reprocess) end rt = nothing if isa(inst, Expr) head = inst.head if head === :call || head === :foreigncall || head === :new || head === :splatnew - (; rt, 
effects) = abstract_eval_statement_expr(interp, inst, nothing, ir, irsv.mi) + (; rt, effects) = abstract_eval_statement_expr(interp, inst, nothing, irsv) ir.stmts[idx][:flag] |= flags_for_effects(effects) if is_foldable(effects) && isa(rt, Const) && is_inlineable_constant(rt.val) ir.stmts[idx][:inst] = quoted(rt.val) end elseif head === :invoke - miโ€ฒ = inst.args[1]::MethodInstance - if miโ€ฒ !== irsv.mi # prevent infinite loop - rt, nothrow = concrete_eval_invoke(interp, inst, miโ€ฒ, irsv) - if nothrow - ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW - if isa(rt, Const) && is_inlineable_constant(rt.val) - ir.stmts[idx][:inst] = quoted(rt.val) - end + rt, nothrow = concrete_eval_invoke(interp, inst, inst.args[1]::MethodInstance, irsv) + if nothrow + ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW + if isa(rt, Const) && is_inlineable_constant(rt.val) + ir.stmts[idx][:inst] = quoted(rt.val) end end elseif head === :throw_undef_if_not || # TODO: Terminate interpretation early if known false? @@ -258,7 +139,6 @@ function reprocess_instruction!(interp::AbstractInterpreter, head === :gc_preserve_end return false else - ccall(:jl_, Cvoid, (Any,), inst) error("reprocess_instruction!: unhandled expression found") end elseif isa(inst, PhiNode) @@ -273,8 +153,7 @@ function reprocess_instruction!(interp::AbstractInterpreter, elseif isa(inst, GlobalRef) # GlobalRef is not refinable else - ccall(:jl_, Cvoid, (Any,), inst) - error() + error("reprocess_instruction!: unhandled instruction found") end if rt !== nothing && !โŠ‘(optimizer_lattice(interp), typ, rt) ir.stmts[idx][:type] = rt @@ -283,10 +162,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, return false end -# Process the terminator and add the successor to `ip`. Returns whether a backedge was seen. -function process_terminator!(ir::IRCode, idx::Int, bb::Int, - all_rets::Vector{Int}, ip::BitSetBoundedMinPrioritySet) - inst = ir.stmts[idx][:inst] +# Process the terminator and add the successor to `bb_ip`. 
Returns whether a backedge was seen. +function process_terminator!(ir::IRCode, @nospecialize(inst), idx::Int, bb::Int, + all_rets::Vector{Int}, bb_ip::BitSetBoundedMinPrioritySet) if isa(inst, ReturnNode) if isdefined(inst, :val) push!(all_rets, idx) @@ -294,43 +172,44 @@ function process_terminator!(ir::IRCode, idx::Int, bb::Int, return false elseif isa(inst, GotoNode) backedge = inst.label <= bb - !backedge && push!(ip, inst.label) + backedge || push!(bb_ip, inst.label) return backedge elseif isa(inst, GotoIfNot) backedge = inst.dest <= bb - !backedge && push!(ip, inst.dest) - push!(ip, bb + 1) + backedge || push!(bb_ip, inst.dest) + push!(bb_ip, bb+1) return backedge elseif isexpr(inst, :enter) dest = inst.args[1]::Int @assert dest > bb - push!(ip, dest) - push!(ip, bb + 1) + push!(bb_ip, dest) + push!(bb_ip, bb+1) return false else - push!(ip, bb + 1) + push!(bb_ip, bb+1) return false end end -default_reprocess(interp::AbstractInterpreter, irsv::IRInterpretationState) = nothing +default_reprocess(::AbstractInterpreter, ::IRInterpretationState) = nothing function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState; extra_reprocess::Union{Nothing,BitSet} = default_reprocess(interp, irsv)) (; ir, tpdum, ssa_refined) = irsv bbs = ir.cfg.blocks - ip = BitSetBoundedMinPrioritySet(length(bbs)) - push!(ip, 1) + bb_ip = BitSetBoundedMinPrioritySet(length(bbs)) + push!(bb_ip, 1) all_rets = Int[] # Fast path: Scan both use counts and refinement in one single pass of # of the instructions. In the absence of backedges, this will # converge. 
- while !isempty(ip) - bb = popfirst!(ip) + while !isempty(bb_ip) + bb = popfirst!(bb_ip) stmts = bbs[bb].stmts lstmt = last(stmts) for idx = stmts + irsv.curridx = idx inst = ir.stmts[idx][:inst] typ = ir.stmts[idx][:type] any_refined = false @@ -357,11 +236,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR idx, bb, inst, typ, irsv, extra_reprocess) push!(ssa_refined, idx) end - if idx == lstmt - if process_terminator!(ir, idx, bb, all_rets, ip) - @goto residual_scan - end - end + idx == lstmt && process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) && @goto residual_scan if typ === Bottom && !isa(inst, PhiNode) break end @@ -377,11 +252,12 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR end # Slow Path Phase 1.A: Complete use scanning - while !isempty(ip) - bb = popfirst!(ip) + while !isempty(bb_ip) + bb = popfirst!(bb_ip) stmts = bbs[bb].stmts lstmt = last(stmts) for idx = stmts + irsv.curridx = idx inst = ir.stmts[idx][:inst] for ur in userefs(inst) val = ur[] @@ -393,18 +269,19 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR count!(tpdum, val) end end - idx == lstmt && process_terminator!(ir, idx, bb, all_rets, ip) + idx == lstmt && process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) end end # Slow Path Phase 1.B: Assemble def-use map complete!(tpdum) - push!(ip, 1) - while !isempty(ip) - bb = popfirst!(ip) + push!(bb_ip, 1) + while !isempty(bb_ip) + bb = popfirst!(bb_ip) stmts = bbs[bb].stmts lstmt = last(stmts) for idx = stmts + irsv.curridx = idx inst = ir.stmts[idx][:inst] for ur in userefs(inst) val = ur[] @@ -412,7 +289,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR push!(tpdum[val.id], idx) end end - idx == lstmt && process_terminator!(ir, idx, bb, all_rets, ip) + idx == lstmt && process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) end end @@ -424,6 +301,7 @@ function 
_ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR end while !isempty(stmt_ip) idx = popfirst!(stmt_ip) + irsv.curridx = idx inst = ir.stmts[idx][:inst] typ = ir.stmts[idx][:type] if reprocess_instruction!(interp, @@ -434,7 +312,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR end begin @label compute_rt - ultimate_rt = Union{} + ultimate_rt = Bottom for idx in all_rets bb = block_for_inst(ir.cfg, idx) if bb != 1 && length(ir.cfg.blocks[bb].preds) == 0 @@ -448,26 +326,32 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR end nothrow = true - for i = 1:length(ir.stmts) - if (ir.stmts[i][:flag] & IR_FLAG_NOTHROW) == 0 + for idx = 1:length(ir.stmts) + if (ir.stmts[idx][:flag] & IR_FLAG_NOTHROW) == 0 nothrow = false break end end - return Pair{Any, Bool}(maybe_singleton_const(ultimate_rt), nothrow) + if last(irsv.valid_worlds) >= get_world_counter() + # if we aren't cached, we don't need this edge + # but our caller might, so let's just make it anyways + store_backedges(frame_instance(irsv), irsv.edges) + end + + return Pair{Any,Bool}(maybe_singleton_const(ultimate_rt), nothrow) end function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState) + irinterp = switch_to_irinterp(interp) if __measure_typeinf__[] inf_frame = Timings.InferenceFrameInfo(irsv.mi, irsv.world, VarState[], Any[], length(irsv.ir.argtypes)) Timings.enter_new_timer(inf_frame) - v = _ir_abstract_constant_propagation(interp, irsv) + ret = _ir_abstract_constant_propagation(irinterp, irsv) append!(inf_frame.slottypes, irsv.ir.argtypes) Timings.exit_current_timer(inf_frame) - return v + return ret else - T = _ir_abstract_constant_propagation(interp, irsv) - return T + return _ir_abstract_constant_propagation(irinterp, irsv) end end diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 7a9c877b6c2f3..a2508992bf290 100644 --- 
a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -604,21 +604,6 @@ function is_old(compact, @nospecialize(old_node_ssa)) !already_inserted(compact, old_node_ssa) end -mutable struct LazyGenericDomtree{IsPostDom} - ir::IRCode - domtree::GenericDomTree{IsPostDom} - LazyGenericDomtree{IsPostDom}(ir::IRCode) where {IsPostDom} = new{IsPostDom}(ir) -end -function get!(x::LazyGenericDomtree{IsPostDom}) where {IsPostDom} - isdefined(x, :domtree) && return x.domtree - return @timeit "domtree 2" x.domtree = IsPostDom ? - construct_postdomtree(x.ir.cfg.blocks) : - construct_domtree(x.ir.cfg.blocks) -end - -const LazyDomtree = LazyGenericDomtree{false} -const LazyPostDomtree = LazyGenericDomtree{true} - function perform_lifting!(compact::IncrementalCompact, visited_phinodes::Vector{AnySSAValue}, @nospecialize(cache_key), lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}, diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 89ae4a1c26df6..84794a27ec034 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -1365,7 +1365,7 @@ end PT = Const(Pair) return instanceof_tfunc(apply_type_tfunc(๐•ƒ, PT, T, T))[1] end -function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::InferenceState) +function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState) nargs = length(argtypes) if !isempty(argtypes) && isvarargtype(argtypes[nargs]) nargs - 1 <= 6 || return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo()) @@ -1973,7 +1973,7 @@ function array_elmtype(@nospecialize ary) return Any end -@nospecs function _opaque_closure_tfunc(๐•ƒ::AbstractLattice, arg, lb, ub, source, env::Vector{Any}, linfo::MethodInstance) +@nospecs function opaque_closure_tfunc(๐•ƒ::AbstractLattice, arg, lb, ub, source, env::Vector{Any}, linfo::MethodInstance) argt, argt_exact = instanceof_tfunc(arg) lbt, lb_exact = instanceof_tfunc(lb) if !lb_exact @@ -2363,7 
+2363,7 @@ function builtin_nothrow(๐•ƒ::AbstractLattice, @nospecialize(f), argtypes::Vect end function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, - sv::Union{InferenceState,IRCode,Nothing}) + sv::Union{AbsIntState, Nothing}) ๐•ƒแตข = typeinf_lattice(interp) if f === tuple return tuple_tfunc(๐•ƒแตข, argtypes) @@ -2544,7 +2544,7 @@ end # TODO: this function is a very buggy and poor model of the return_type function # since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type, # while this assumes that it is an absolutely precise and accurate and exact model of both -function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::Union{InferenceState, IRCode}) +function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState) if length(argtypes) == 3 tt = widenslotwrapper(argtypes[3]) if isa(tt, Const) || (isType(tt) && !has_free_typevars(tt)) @@ -2603,7 +2603,7 @@ end # a simplified model of abstract_call_gf_by_type for applicable function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any}, - sv::InferenceState, max_methods::Int) + sv::AbsIntState, max_methods::Int) length(argtypes) < 2 && return CallMeta(Union{}, EFFECTS_UNKNOWN, NoCallInfo()) isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo()) argtypes = argtypes[2:end] @@ -2649,7 +2649,7 @@ end add_tfunc(applicable, 1, INT_INF, @nospecs((๐•ƒ::AbstractLattice, f, args...)->Bool), 40) # a simplified model of abstract_invoke for Core._hasmethod -function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState) +function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::AbsIntState) if length(argtypes) == 3 && !isvarargtype(argtypes[3]) ftโ€ฒ = argtype_by_index(argtypes, 2) ft = widenconst(ftโ€ฒ) diff --git a/base/compiler/typeinfer.jl 
b/base/compiler/typeinfer.jl index 03b074bbec318..1eec73d0435bd 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -206,6 +206,7 @@ const __measure_typeinf__ = fill(false) # Wrapper around _typeinf that optionally records the exclusive time for each invocation. function typeinf(interp::AbstractInterpreter, frame::InferenceState) + interp = switch_from_irinterp(interp) if __measure_typeinf__[] Timings.enter_new_timer(frame) v = _typeinf(interp, frame) @@ -564,23 +565,22 @@ function finish(me::InferenceState, interp::AbstractInterpreter) end # record the backedges -function store_backedges(frame::InferenceResult, edges::Vector{Any}) - toplevel = !isa(frame.linfo.def, Method) - if !toplevel - store_backedges(frame.linfo, edges) - end - nothing +function store_backedges(caller::InferenceResult, edges::Vector{Any}) + isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance + return store_backedges(caller.linfo, edges) end -function store_backedges(frame::MethodInstance, edges::Vector{Any}) - for (; sig, caller) in BackedgeIterator(edges) - if isa(caller, MethodInstance) - ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame) +function store_backedges(caller::MethodInstance, edges::Vector{Any}) + for itr in BackedgeIterator(edges) + callee = itr.caller + if isa(callee, MethodInstance) + ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller) else - typeassert(caller, MethodTable) - ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame) + typeassert(callee, MethodTable) + ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller) end end + return nothing end function record_slot_assign!(sv::InferenceState) @@ -784,15 +784,20 @@ function merge_call_chain!(interp::AbstractInterpreter, parent::InferenceState, union_caller_cycle!(ancestor, child) child = parent child === ancestor && break - parent 
= child.parent::InferenceState + parent = frame_parent(child) + while !isa(parent, InferenceState) + # XXX we may miss some edges here? + parent = frame_parent(parent::IRInterpretationState) + end + parent = parent::InferenceState end end function is_same_frame(interp::AbstractInterpreter, mi::MethodInstance, frame::InferenceState) - return mi === frame.linfo + return mi === frame_instance(frame) end -function poison_callstack(infstate::InferenceState, topmost::InferenceState) +function poison_callstack!(infstate::InferenceState, topmost::InferenceState) push!(infstate.pclimitations, topmost) nothing end @@ -804,33 +809,38 @@ end # frame's `callers_in_cycle` field and adding the appropriate backedges. Finally, # we return `mi`'s pre-existing frame. If no cycles are found, `nothing` is # returned instead. -function resolve_call_cycle!(interp::AbstractInterpreter, mi::MethodInstance, parent::InferenceState) +function resolve_call_cycle!(interp::AbstractInterpreter, mi::MethodInstance, parent::AbsIntState) + # TODO (#48913) implement a proper recursion handling for irinterp: + # This works just because currently the `:terminate` condition guarantees that + # irinterp doesn't fail into unresolved cycles, but it's not a good solution. + # We should revisit this once we have a better story for handling cycles in irinterp. 
+ isa(parent, InferenceState) || return false frame = parent uncached = false while isa(frame, InferenceState) - uncached |= !frame.cached # ensure we never add an uncached frame to a cycle + uncached |= !is_cached(frame) # ensure we never add an uncached frame to a cycle if is_same_frame(interp, mi, frame) if uncached # our attempt to speculate into a constant call lead to an undesired self-cycle # that cannot be converged: poison our call-stack (up to the discovered duplicate frame) # with the limited flag and abort (set return type to Any) now - poison_callstack(parent, frame) + poison_callstack!(parent, frame) return true end merge_call_chain!(interp, parent, frame, frame) return frame end - for caller in frame.callers_in_cycle + for caller in callers_in_cycle(frame) if is_same_frame(interp, mi, caller) if uncached - poison_callstack(parent, frame) + poison_callstack!(parent, frame) return true end merge_call_chain!(interp, parent, frame, caller) return caller end end - frame = frame.parent + frame = frame_parent(frame) end return false end @@ -851,7 +861,7 @@ struct EdgeCallResult end # compute (and cache) an inferred AST and return the current best estimate of the result type -function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize(atype), sparams::SimpleVector, caller::InferenceState) +function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize(atype), sparams::SimpleVector, caller::AbsIntState) mi = specialize_method(method, atype, sparams)::MethodInstance code = get(code_cache(interp), mi, nothing) if code isa CodeInstance # return existing rettype if the code is already inferred @@ -890,9 +900,9 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize add_remark!(interp, caller, "Inference is disabled for the target module") return EdgeCallResult(Any, nothing, Effects()) end - if !caller.cached && caller.parent === nothing + if !is_cached(caller) && frame_parent(caller) === nothing # 
this caller exists to return to the user - # (if we asked resolve_call_cyle, it might instead detect that there is a cycle that it can't merge) + # (if we asked resolve_call_cycle!, it might instead detect that there is a cycle that it can't merge) frame = false else frame = resolve_call_cycle!(interp, mi, caller) @@ -908,7 +918,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize unlock_mi_inference(interp, mi) return EdgeCallResult(Any, nothing, Effects()) end - if caller.cached || caller.parent !== nothing # don't involve uncached functions in cycle resolution + if is_cached(caller) || frame_parent(caller) !== nothing # don't involve uncached functions in cycle resolution frame.parent = caller end typeinf(interp, frame) diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 91d585cdd76ff..c987e03df5261 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -32,6 +32,14 @@ struct StmtInfo used::Bool end +struct MethodInfo + propagate_inbounds::Bool + method_for_inference_limit_heuristics::Union{Nothing,Method} +end +MethodInfo(src::CodeInfo) = MethodInfo( + src.propagate_inbounds, + src.method_for_inference_limit_heuristics::Union{Nothing,Method}) + """ v::VarState @@ -465,13 +473,23 @@ optimizer_lattice(interp::NativeInterpreter) = OptimizerLattice(SimpleInferenceL """ switch_to_irinterp(interp::AbstractInterpreter) -> irinterp::AbstractInterpreter -Optionally convert `interp` to new `irinterp::AbstractInterpreter` to perform semi-concrete -interpretation. `NativeInterpreter` uses this interface to switch its lattice to -`optimizer_lattice` during semi-concrete interpretation on `IRCode`. +This interface allows `ir_abstract_constant_propagation` to convert `interp` to a new +`irinterp::AbstractInterpreter` to perform semi-concrete interpretation. +`NativeInterpreter` uses this interface to switch its lattice to `optimizer_lattice` during +semi-concrete interpretation on `IRCode`. 
""" switch_to_irinterp(interp::AbstractInterpreter) = interp switch_to_irinterp(interp::NativeInterpreter) = NativeInterpreter(interp; irinterp=true) +""" + switch_from_irinterp(irinterp::AbstractInterpreter) -> interp::AbstractInterpreter + +The inverse operation of `switch_to_irinterp`, allowing `typeinf` to convert `irinterp` back +to a new `interp::AbstractInterpreter` to perform ordinary abstract interpretation. +""" +switch_from_irinterp(irinterp::AbstractInterpreter) = irinterp +switch_from_irinterp(irinterp::NativeInterpreter) = NativeInterpreter(irinterp; irinterp=false) + abstract type CallInfo end @nospecialize diff --git a/test/compiler/datastructures.jl b/test/compiler/datastructures.jl index a25a884373ab4..8dbaee61503d0 100644 --- a/test/compiler/datastructures.jl +++ b/test/compiler/datastructures.jl @@ -7,7 +7,7 @@ using Test table = Core.Compiler.method_table(interp) sig = Tuple{typeof(*), Any, Any} result1 = Core.Compiler.findall(sig, table; limit=-1) - result2 = Core.Compiler.findall(sig, table; limit=Core.Compiler.get_max_methods(*, @__MODULE__, interp)) + result2 = Core.Compiler.findall(sig, table; limit=Core.Compiler.InferenceParams().max_methods) @test result1 !== nothing && !Core.Compiler.isempty(result1.matches) @test result2 === nothing end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 12fedf2792a61..1634345f70459 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4762,3 +4762,64 @@ fhasmethod(::Integer, ::Int32) = 3 @test only(Base.return_types(()) do; Val(hasmethod(tuple, Tuple{Vararg{Int}})); end) === Val{true} @test only(Base.return_types(()) do; Val(hasmethod(sin, Tuple{Int, Vararg{Int}})); end) == Val{false} @test only(Base.return_types(()) do; Val(hasmethod(sin, Tuple{Int, Int, Vararg{Int}})); end) === Val{false} + +# TODO (#48913) enable interprocedural call inference from irinterp +# # interprocedural call inference from irinterp +# @noinline Base.@assume_effects :total 
issue48679_unknown_any(x) = Base.inferencebarrier(x) + +# @noinline _issue48679(y::Union{Nothing,T}) where {T} = T::Type +# Base.@constprop :aggressive function issue48679(x, b) +# if b +# x = issue48679_unknown_any(x) +# end +# return _issue48679(x) +# end +# @test Base.return_types((Float64,)) do x +# issue48679(x, false) +# end |> only == Type{Float64} + +# Base.@constprop :aggressive @noinline _issue48679_const(b, y::Union{Nothing,T}) where {T} = b ? nothing : T::Type +# Base.@constprop :aggressive function issue48679_const(x, b) +# if b +# x = issue48679_unknown_any(x) +# end +# return _issue48679_const(b, x) +# end +# @test Base.return_types((Float64,)) do x +# issue48679_const(x, false) +# end |> only == Type{Float64} + +# `invoke` call in irinterp +@noinline _irinterp_invoke(x::Any) = :any +@noinline _irinterp_invoke(x::T) where T = T +Base.@constprop :aggressive Base.@assume_effects :foldable function irinterp_invoke(x::T, b) where T + return @invoke _irinterp_invoke(x::(b ? T : Any)) +end +@test Base.return_types((Int,)) do x + irinterp_invoke(x, true) +end |> only == Type{Int} + +# recursion detection for semi-concrete interpretation +# avoid direct infinite loop via `concrete_eval_invoke` +Base.@assume_effects :foldable function recur_irinterp1(x, y) + if rand(Bool) + return x, y + end + return recur_irinterp1(x+1, y) +end +@test Base.return_types((Symbol,)) do y + recur_irinterp1(0, y) +end |> only === Tuple{Int,Symbol} +@test last(recur_irinterp1(0, :y)) === :y +# avoid indirect infinite loop via `concrete_eval_invoke` +Base.@assume_effects :foldable function recur_irinterp2(x, y) + if rand(Bool) + return x, y + end + return _recur_irinterp2(x+1, y) +end +Base.@assume_effects :foldable _recur_irinterp2(x, y) = @noinline recur_irinterp2(x, y) +@test Base.return_types((Symbol,)) do y + recur_irinterp2(0, y) +end |> only === Tuple{Int,Symbol} +@test last(recur_irinterp2(0, :y)) === :y