Skip to content

Commit a9ae1d2

Browse files
committed
adjust to the upstream irinterp refactoring
xref: JuliaLang/julia#48913
1 parent e0b24da commit a9ae1d2

File tree

3 files changed

+55
-51
lines changed

3 files changed

+55
-51
lines changed

src/codegen/forward_demand.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,10 @@ Internal method which generates the code for forward mode diffentiation
189189
190190
191191
- `ir` the IR being differnetation
192-
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
192+
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
193193
paired with the order (first deriviative, second derivative etc)
194194
195-
- `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
195+
- `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
196196
decides if the custom `transform!` should be applied to a `stmt` or not
197197
Default: `false` for all statements
198198
- `transform!(ir, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
@@ -289,10 +289,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
289289
end
290290

291291

292-
function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
292+
function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::MethodInstance,
293+
to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
293294
forward_diff_no_inf!(ir, to_diff; kwargs...)
294295

295296
# Step 3: Re-inference
297+
296298
ir = compact!(ir)
297299

298300
extra_reprocess = CC.BitSet()
@@ -302,9 +304,13 @@ function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::V
302304
end
303305
end
304306

305-
interp′ = enable_reinference(interp)
306-
irsv = IRInterpretationState(interp′, ir, mi, world, ir.argtypes[1:mi.def.nargs])
307-
rt = CC._ir_abstract_constant_propagation(interp′, irsv; extra_reprocess)
307+
method_info = CC.MethodInfo(src)
308+
argtypes = ir.argtypes[1:mi.def.nargs]
309+
world = CC.get_world_counter(interp)
310+
irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world)
311+
rt = CC._ir_abstract_constant_propagation(enable_reinference(interp), irsv; extra_reprocess)
312+
313+
ir = compact!(ir)
308314

309315
return ir
310316
end

src/stage2/forward.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
66
interp = ADInterpreter(; forward=true, backward=false)
77
match = Base._which(tt)
88
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
9+
mi = frame.linfo
910

10-
ir = copy((interp.opt[0][frame.linfo].inferred).ir::IRCode)
11+
src = CC.copy(interp.unopt[0][mi].src)
12+
ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode)
1113

1214
# Find all Return Nodes
1315
vals = Pair{SSAValue, Int}[]
@@ -43,10 +45,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
4345
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
4446
end
4547

48+
ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!)
4649

47-
irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs])
48-
ir = forward_diff!(ir, interp, frame.linfo, CC.get_world_counter(interp), vals; visit_custom!, transform!)
49-
50-
ir = compact!(ir)
5150
return OpaqueClosure(ir)
5251
end

src/stage2/interpreter.jl

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ function Compiler3.get_codeinstance(graph::ADGraph, cursor::ADCursor)
3131
end
3232
=#
3333

34-
using Core.Compiler: AbstractInterpreter, NativeInterpreter, InferenceState,
35-
InferenceResult, CodeInstance, WorldRange, ArgInfo, StmtInfo
34+
using Core: MethodInstance, CodeInstance
35+
using .CC: AbstractInterpreter, ArgInfo, Effects, InferenceResult, InferenceState,
36+
IRInterpretationState, NativeInterpreter, OptimizationState, StmtInfo, WorldRange
3637

3738
const OptCache = Dict{MethodInstance, CodeInstance}
3839
const UnoptCache = Dict{Union{MethodInstance, InferenceResult}, Cthulhu.InferredSource}
@@ -120,7 +121,7 @@ function Cthulhu.lookup(interp::ADInterpreter, curs::ADCursor, optimize::Bool; a
120121
opt = codeinst.inferred
121122
if opt !== nothing
122123
opt = opt::Cthulhu.OptimizedSource
123-
src = Core.Compiler.copy(opt.ir)
124+
src = CC.copy(opt.ir)
124125
codeinf = opt.src
125126
infos = src.stmts.info
126127
slottypes = src.argtypes
@@ -162,7 +163,6 @@ function Cthulhu.custom_toggles(interp::ADInterpreter)
162163
end
163164

164165
# TODO: Something is going very wrong here
165-
using Core.Compiler: Effects, OptimizationState
166166
function Cthulhu.get_effects(interp::ADInterpreter, mi::MethodInstance, opt::Bool)
167167
if haskey(interp.unopt[0], mi)
168168
return interp.unopt[0][mi].effects
@@ -171,7 +171,7 @@ function Cthulhu.get_effects(interp::ADInterpreter, mi::MethodInstance, opt::Boo
171171
end
172172
end
173173

174-
function Core.Compiler.is_same_frame(interp::ADInterpreter, linfo::MethodInstance, frame::InferenceState)
174+
function CC.is_same_frame(interp::ADInterpreter, linfo::MethodInstance, frame::InferenceState)
175175
linfo === frame.linfo || return false
176176
return interp.current_level === frame.interp.current_level
177177
end
@@ -224,7 +224,7 @@ function Cthulhu.navigate(curs::ADCursor, callsite::Cthulhu.Callsite)
224224
return ADCursor(curs.level, Cthulhu.get_mi(callsite))
225225
end
226226

227-
function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Compiler.CallInfo), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool)
227+
function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::CC.CallInfo), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool)
228228
if isa(info, RecurseInfo)
229229
newargtypes = argtypes[2:end]
230230
callinfos = Cthulhu.process_info(interp, info.info, newargtypes, Cthulhu.unwrapType(widenconst(rt)), optimize)
@@ -252,33 +252,33 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Co
252252
elseif isa(info, CompClosInfo)
253253
return Any[CompClosCallInfo(rt)]
254254
end
255-
return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, Core.Compiler.CallInfo, Cthulhu.ArgTypes, Any, Bool},
255+
return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, CC.CallInfo, Cthulhu.ArgTypes, Any, Bool},
256256
interp, info, argtypes, rt, optimize)
257257
end
258258

259-
Core.Compiler.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter)
260-
Core.Compiler.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter)
261-
Core.Compiler.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter)
262-
Core.Compiler.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter)
259+
CC.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter)
260+
CC.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter)
261+
CC.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter)
262+
CC.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter)
263263

264264
# No need to do any locking since we're not putting our results into the runtime cache
265-
Core.Compiler.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
266-
Core.Compiler.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
265+
CC.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
266+
CC.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
267267

268268
struct CodeInfoView
269269
d::Dict{MethodInstance, Any}
270270
end
271271

272-
function Core.Compiler.code_cache(ei::ADInterpreter)
272+
function CC.code_cache(ei::ADInterpreter)
273273
while ei.current_level > lastindex(ei.opt)
274274
push!(ei.opt, Dict{MethodInstance, Any}())
275275
end
276276
ei.opt[ei.current_level]
277277
end
278-
Core.Compiler.may_optimize(ei::ADInterpreter) = true
279-
Core.Compiler.may_compress(ei::ADInterpreter) = false
280-
Core.Compiler.may_discard_trees(ei::ADInterpreter) = false
281-
function Core.Compiler.get(view::CodeInfoView, mi::MethodInstance, default)
278+
CC.may_optimize(ei::ADInterpreter) = true
279+
CC.may_compress(ei::ADInterpreter) = false
280+
CC.may_discard_trees(ei::ADInterpreter) = false
281+
function CC.get(view::CodeInfoView, mi::MethodInstance, default)
282282
r = get(view.d, mi, nothing)
283283
if r === nothing
284284
return default
@@ -298,23 +298,23 @@ end
298298
Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing)
299299

300300
#=
301-
function Core.Compiler.const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
301+
function CC.const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
302302
return true
303303
end
304304
=#
305305

306-
function Core.Compiler.finish(state::InferenceState, interp::ADInterpreter)
307-
res = @invoke Core.Compiler.finish(state::InferenceState, interp::AbstractInterpreter)
308-
key = Core.Compiler.any(state.result.overridden_by_const) ? state.result : state.linfo
306+
function CC.finish(state::InferenceState, interp::ADInterpreter)
307+
res = @invoke CC.finish(state::InferenceState, interp::AbstractInterpreter)
308+
key = CC.any(state.result.overridden_by_const) ? state.result : state.linfo
309309
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(
310310
copy(state.src),
311311
copy(state.stmt_info),
312-
isdefined(Core.Compiler, :Effects) ? state.ipo_effects : nothing,
312+
state.ipo_effects,
313313
state.result.result)
314314
return res
315315
end
316316

317-
function Core.Compiler.transform_result_for_cache(interp::ADInterpreter,
317+
function CC.transform_result_for_cache(interp::ADInterpreter,
318318
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
319319
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
320320
end
@@ -325,28 +325,27 @@ function CC.inlining_policy(interp::ADInterpreter,
325325
if isa(info, FRuleCallInfo)
326326
return nothing
327327
end
328-
if isdefined(CC, :SemiConcreteResult) && isa(src, CC.SemiConcreteResult)
328+
if isa(src, CC.SemiConcreteResult)
329329
return src
330330
end
331331
@assert isa(src, Cthulhu.OptimizedSource) || isnothing(src)
332332
if isa(src, Cthulhu.OptimizedSource)
333333
if CC.is_stmt_inline(stmt_flag) || src.isinlineable
334334
return src.ir
335335
end
336-
else
337-
# the default inlining policy may try additional effor to find the source in a local cache
338-
return @invoke CC.inlining_policy(interp::AbstractInterpreter,
339-
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
336+
return nothing
340337
end
341-
return nothing
338+
# the default inlining policy may try additional effor to find the source in a local cache
339+
return @invoke CC.inlining_policy(interp::AbstractInterpreter,
340+
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
342341
end
343342

344343
function dummy() end
345344
const dummym = first(methods(dummy))
346345

347346
function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
348347
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
349-
sv::IRCode, max_methods::Int)
348+
sv::IRInterpretationState, max_methods::Int)
350349

351350
if interp.reinference
352351
# Create a dummy inference state to serve as the root
@@ -359,41 +358,41 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
359358
return r
360359
end
361360

362-
return CallMeta(Any, CC.Effects(), CC.NoCallInfo())
361+
return CallMeta(Any, Effects(), CC.NoCallInfo())
363362
end
364363

365364
#=
366-
function Core.Compiler.optimize(interp::ADInterpreter, opt::OptimizationState,
365+
function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
367366
params::OptimizationParams, caller::InferenceResult)
368367
369368
# TODO: Enable some amount of inlining
370369
#@timeit "optimizer" ir = run_passes(opt.src, opt, caller)
371370
372371
sv = opt
373372
ci = opt.src
374-
ir = Core.Compiler.convert_to_ircode(ci, sv)
375-
ir = Core.Compiler.slot2reg(ir, ci, sv)
373+
ir = CC.convert_to_ircode(ci, sv)
374+
ir = CC.slot2reg(ir, ci, sv)
376375
# TODO: Domsorting can produce an updated domtree - no need to recompute here
377-
ir = Core.Compiler.compact!(ir)
378-
return Core.Compiler.finish(interp, opt, params, ir, caller)
376+
ir = CC.compact!(ir)
377+
return CC.finish(interp, opt, params, ir, caller)
379378
end
380379
=#
381380

382-
function Core.Compiler.finish!(interp::ADInterpreter, caller::InferenceResult)
381+
function CC.finish!(interp::ADInterpreter, caller::InferenceResult)
383382
effects = caller.ipo_effects
384383
caller.src = Cthulhu.create_cthulhu_source(caller.src, effects)
385384
end
386385

387386
function ir2codeinst(ir::IRCode, inst::CodeInstance, ci::CodeInfo)
388387
CodeInstance(inst.def, inst.rettype, isdefined(inst, :rettype_const) ? inst.rettype_const : nothing,
389-
Cthulhu.OptimizedSource(Core.Compiler.copy(ir), ci, inst.inferred.isinlineable, Core.Compiler.decode_effects(inst.purity_bits)),
388+
Cthulhu.OptimizedSource(CC.copy(ir), ci, inst.inferred.isinlineable, CC.decode_effects(inst.purity_bits)),
390389
Int32(0), inst.min_world, inst.max_world, inst.ipo_purity_bits, inst.purity_bits,
391390
inst.argescapes, inst.relocatability)
392391
end
393392

394393
using Core: OpaqueClosure
395394
function codegen(interp::ADInterpreter, curs::ADCursor, cache=Dict{ADCursor, OpaqueClosure}())
396-
ir = Core.Compiler.copy(Cthulhu.get_optimized_codeinst(interp, curs).inferred.ir)
395+
ir = CC.copy(Cthulhu.get_optimized_codeinst(interp, curs).inferred.ir)
397396
codeinst = interp.opt[curs.level][curs.mi]
398397
ci = codeinst.inferred.src
399398
if curs.level >= 1

0 commit comments

Comments
 (0)