Skip to content

Commit 56bea36

Browse files
committed
very wip: inference: allow semi-concrete interpret to perform recursive inference
fix #48679
1 parent 9b9b99f commit 56bea36

File tree

10 files changed

+281
-196
lines changed

10 files changed

+281
-196
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 100 additions & 92 deletions
Large diffs are not rendered by default.

base/compiler/compiler.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,25 @@ include("compiler/ssair/ir.jl")
154154
include("compiler/abstractlattice.jl")
155155

156156
include("compiler/inferenceresult.jl")
157+
158+
# TODO define the interface for this abstract type
159+
abstract type AbsIntState end
160+
function frame_instance end
161+
function frame_module(sv::AbsIntState)
162+
mi = frame_instance(sv)
163+
def = mi.def
164+
isa(def, Module) && return def
165+
return def.module
166+
end
167+
function frame_parent end
168+
function frame_cached end
169+
function frame_src end
170+
function callers_in_cycle end
171+
# function recur_state end
172+
# pclimitations(sv::AbsIntState) = recur_state(sv).pclimitations
173+
# limitations(sv::AbsIntState) = recur_state(sv).limitations
174+
# callers_in_cycle(sv::AbsIntState) = recur_state(sv).callers_in_cycle
175+
157176
include("compiler/inferencestate.jl")
158177

159178
include("compiler/typeutils.jl")

base/compiler/inferencestate.jl

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,18 @@ function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
7878
end
7979
end
8080

81-
mutable struct InferenceState
81+
struct AbsIntRecursionState
82+
pclimitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
83+
limitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on return
84+
callers_in_cycle::Vector{AbsIntState}
85+
end
86+
function AbsIntRecursionState()
87+
return AbsIntRecursionState(IdSet{AbsIntState}(),
88+
IdSet{AbsIntState}(),
89+
Vector{AbsIntState}())
90+
end
91+
92+
mutable struct InferenceState <: AbsIntState
8293
#= information about this method instance =#
8394
linfo::MethodInstance
8495
world::UInt
@@ -195,23 +206,25 @@ end
195206
is_inferred(sv::InferenceState) = is_inferred(sv.result)
196207
is_inferred(result::InferenceResult) = result.result !== nothing
197208

209+
frame_instance(sv::InferenceState) = sv.linfo
210+
frame_parent(sv::InferenceState) = sv.parent
211+
frame_cached(sv::InferenceState) = sv.cached
212+
frame_src(sv::InferenceState) = sv.src
213+
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
214+
198215
function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
199216
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
200217
end
201218

202-
merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
203-
merge_effects!(interp, caller, Effects(callee))
204-
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing
205-
206-
is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
219+
is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect)
207220
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
208221
def = linfo.def
209222
return isa(def, Method) && is_effect_overridden(def, effect)
210223
end
211224
is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect)
212225
is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect)
213226

214-
add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return
227+
add_remark!(::AbstractInterpreter, ::AbsIntState, remark) = return
215228

216229
struct InferenceLoopState
217230
sig
@@ -222,13 +235,13 @@ struct InferenceLoopState
222235
end
223236
end
224237

225-
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
226-
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
238+
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState)
239+
return sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
227240
end
228-
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
241+
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
229242
return state.rt === Any && !is_foldable(state.effects)
230243
end
231-
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
244+
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
232245
return state.rt === Any
233246
end
234247

@@ -347,21 +360,21 @@ end
347360
children before their parents (i.e. ascending the tree from the given
348361
InferenceState). Note that cycles may be visited in any order.
349362
"""
350-
struct InfStackUnwind
351-
inf::InferenceState
363+
struct InfStackUnwind{SV<:AbsIntState}
364+
inf::SV
352365
end
353366
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
354-
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
367+
function iterate(unw::InfStackUnwind{SV}, (infstate, cyclei)::Tuple{SV, Int}) where SV<:AbsIntState
355368
# iterate through the cycle before walking to the parent
356-
if cyclei < length(infstate.callers_in_cycle)
369+
if cyclei < length(callers_in_cycle(infstate))
357370
cyclei += 1
358-
infstate = infstate.callers_in_cycle[cyclei]
371+
infstate = callers_in_cycle(infstate)[cyclei]
359372
else
360373
cyclei = 0
361-
infstate = infstate.parent
374+
infstate = frame_parent(infstate)
362375
end
363376
infstate === nothing && return nothing
364-
(infstate::InferenceState, (infstate, cyclei))
377+
(infstate, (infstate, cyclei))
365378
end
366379

367380
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
@@ -500,12 +513,12 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
500513
return sptypes
501514
end
502515

503-
_topmod(sv::InferenceState) = _topmod(sv.mod)
516+
_topmod(sv::InferenceState) = _topmod(frame_module(sv))
504517

505518
# work towards converging the valid age range for sv
506519
function update_valid_age!(sv::InferenceState, valid_worlds::WorldRange)
507520
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
508-
@assert(sv.world in valid_worlds, "invalid age range update")
521+
@assert sv.world in valid_worlds "invalid age range update"
509522
return valid_worlds
510523
end
511524

@@ -543,42 +556,31 @@ end
543556

544557
# temporarily accumulate our edges to later add as backedges in the callee
545558
function add_backedge!(caller::InferenceState, mi::MethodInstance)
546-
edges = get_stmt_edges!(caller)
547-
if edges !== nothing
548-
push!(edges, mi)
549-
end
550-
return nothing
559+
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
560+
return push!(get_stmt_edges!(caller), mi)
551561
end
552562

553563
function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance)
554-
edges = get_stmt_edges!(caller)
555-
if edges !== nothing
556-
push!(edges, invokesig, mi)
557-
end
558-
return nothing
564+
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
565+
return push!(get_stmt_edges!(caller), invokesig, mi)
559566
end
560567

561568
# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
562569
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
563-
edges = get_stmt_edges!(caller)
564-
if edges !== nothing
565-
push!(edges, mt, typ)
566-
end
567-
return nothing
570+
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
571+
return push!(get_stmt_edges!(caller), mt, typ)
568572
end
569573

570-
function get_stmt_edges!(caller::InferenceState)
571-
if !isa(caller.linfo.def, Method)
572-
return nothing # don't add backedges to toplevel exprs
573-
end
574-
edges = caller.stmt_edges[caller.currpc]
574+
function get_stmt_edges!(caller::InferenceState, currpc::Int=caller.currpc)
575+
stmt_edges = caller.stmt_edges
576+
edges = stmt_edges[currpc]
575577
if edges === nothing
576-
edges = caller.stmt_edges[caller.currpc] = []
578+
edges = stmt_edges[currpc] = []
577579
end
578580
return edges
579581
end
580582

581-
function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
583+
function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
582584
edges = frame.stmt_edges[currpc]
583585
edges === nothing || empty!(edges)
584586
return nothing

base/compiler/optimize.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ struct InliningState{Interp<:AbstractInterpreter}
126126
world::UInt
127127
interp::Interp
128128
end
129-
function InliningState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
130-
et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds)
131-
return InliningState(params, et, frame.world, interp)
129+
function InliningState(sv::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
130+
et = EdgeTracker(sv.stmt_edges[1]::Vector{Any}, sv.valid_worlds)
131+
return InliningState(params, et, sv.world, interp)
132132
end
133133
function InliningState(params::OptimizationParams, interp::AbstractInterpreter)
134134
return InliningState(params, nothing, get_world_counter(interp), interp)
@@ -151,12 +151,12 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
151151
cfg::Union{Nothing,CFG}
152152
insert_coverage::Bool
153153
end
154-
function OptimizationState(frame::InferenceState, params::OptimizationParams,
154+
function OptimizationState(sv::InferenceState, params::OptimizationParams,
155155
interp::AbstractInterpreter, recompute_cfg::Bool=true)
156-
inlining = InliningState(frame, params, interp)
157-
cfg = recompute_cfg ? nothing : frame.cfg
158-
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
159-
frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage)
156+
inlining = InliningState(sv, params, interp)
157+
cfg = recompute_cfg ? nothing : sv.cfg
158+
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, frame_module(sv),
159+
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
160160
end
161161
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
162162
interp::AbstractInterpreter)
@@ -387,9 +387,9 @@ function argextype(
387387
return Const(x)
388388
end
389389
end
390+
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
390391
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]
391392

392-
393393
"""
394394
finish(interp::AbstractInterpreter, opt::OptimizationState,
395395
params::OptimizationParams, ir::IRCode, caller::InferenceResult)

0 commit comments

Comments
 (0)