Skip to content

Commit 2a0d58a

Browse files
authored
abstract_apply: Don't drop effects of iterate'd calls (#47846)
We were accidentally dropping the effects of calls from `iterate` calls performed during abstract_iteration. This allowed calls that were not actually eligible for (semi-)concrete evaluation to go through that path anyway. This could cause incorrect results (see test), though it was usually fine, since iterate call tend to not have side effects. It was noticed however in #47688, because it forced irinterp down a path that was not meant to be reachable (resulting in a TODO error message). For good measure, let's also address this todo (since it is reachable by external absint if they want), but the missing effect propagation was the more serious bug here.
1 parent 6bfc6ac commit 2a0d58a

File tree

5 files changed

+72
-31
lines changed

5 files changed

+72
-31
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,6 +1342,14 @@ function ssa_def_slot(@nospecialize(arg), sv::InferenceState)
13421342
return arg
13431343
end
13441344

1345+
struct AbstractIterationResult
1346+
cti::Vector{Any}
1347+
info::MaybeAbstractIterationInfo
1348+
ai_effects::Effects
1349+
end
1350+
AbstractIterationResult(cti::Vector{Any}, info::MaybeAbstractIterationInfo) =
1351+
AbstractIterationResult(cti, info, EFFECTS_TOTAL)
1352+
13451353
# `typ` is the inferred type for expression `arg`.
13461354
# if the expression constructs a container (e.g. `svec(x,y,z)`),
13471355
# refine its type to an array of element types.
@@ -1352,14 +1360,14 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
13521360
if isa(typ, PartialStruct)
13531361
widet = typ.typ
13541362
if isa(widet, DataType) && widet.name === Tuple.name
1355-
return typ.fields, nothing
1363+
return AbstractIterationResult(typ.fields, nothing)
13561364
end
13571365
end
13581366

13591367
if isa(typ, Const)
13601368
val = typ.val
13611369
if isa(val, SimpleVector) || isa(val, Tuple)
1362-
return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here!
1370+
return AbstractIterationResult(Any[ Const(val[i]) for i in 1:length(val) ], nothing) # avoid making a tuple Generator here!
13631371
end
13641372
end
13651373

@@ -1374,12 +1382,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
13741382
if isa(tti, Union)
13751383
utis = uniontypes(tti)
13761384
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
1377-
return Any[Vararg{Any}], nothing
1385+
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
13781386
end
13791387
ltp = length((utis[1]::DataType).parameters)
13801388
for t in utis
13811389
if length((t::DataType).parameters) != ltp
1382-
return Any[Vararg{Any}], nothing
1390+
return AbstractIterationResult(Any[Vararg{Any}], nothing)
13831391
end
13841392
end
13851393
result = Any[ Union{} for _ in 1:ltp ]
@@ -1390,12 +1398,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
13901398
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
13911399
end
13921400
end
1393-
return result, nothing
1401+
return AbstractIterationResult(result, nothing)
13941402
elseif tti0 <: Tuple
13951403
if isa(tti0, DataType)
1396-
return Any[ p for p in tti0.parameters ], nothing
1404+
return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing)
13971405
elseif !isa(tti, DataType)
1398-
return Any[Vararg{Any}], nothing
1406+
return AbstractIterationResult(Any[Vararg{Any}], nothing)
13991407
else
14001408
len = length(tti.parameters)
14011409
last = tti.parameters[len]
@@ -1404,12 +1412,14 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
14041412
if va
14051413
elts[len] = Vararg{elts[len]}
14061414
end
1407-
return elts, nothing
1415+
return AbstractIterationResult(elts, nothing)
14081416
end
1409-
elseif tti0 === SimpleVector || tti0 === Any
1410-
return Any[Vararg{Any}], nothing
1417+
elseif tti0 === SimpleVector
1418+
return AbstractIterationResult(Any[Vararg{Any}], nothing)
1419+
elseif tti0 === Any
1420+
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
14111421
elseif tti0 <: Array
1412-
return Any[Vararg{eltype(tti0)}], nothing
1422+
return AbstractIterationResult(Any[Vararg{eltype(tti0)}], nothing)
14131423
else
14141424
return abstract_iteration(interp, itft, typ, sv)
14151425
end
@@ -1420,7 +1430,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
14201430
if isa(itft, Const)
14211431
iteratef = itft.val
14221432
else
1423-
return Any[Vararg{Any}], nothing
1433+
return AbstractIterationResult(Any[Vararg{Any}], nothing, EFFECTS_UNKNOWN′)
14241434
end
14251435
@assert !isvarargtype(itertype)
14261436
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), StmtInfo(true), sv)
@@ -1430,7 +1440,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
14301440
# WARNING: Changes to the iteration protocol must be reflected here,
14311441
# this is not just an optimization.
14321442
# TODO: this doesn't realize that Array, SimpleVector, Tuple, and NamedTuple do not use the iterate protocol
1433-
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)])
1443+
stateordonet === Bottom && return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, call.effects, info)], true))
14341444
valtype = statetype = Bottom
14351445
ret = Any[]
14361446
calls = CallMeta[call]
@@ -1440,7 +1450,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
14401450
# length iterators, or interesting prefix
14411451
while true
14421452
if stateordonet_widened === Nothing
1443-
return ret, AbstractIterationInfo(calls)
1453+
return AbstractIterationResult(ret, AbstractIterationInfo(calls, true))
14441454
end
14451455
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).max_tuple_splat
14461456
break
@@ -1452,7 +1462,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
14521462
# If there's no new information in this statetype, don't bother continuing,
14531463
# the iterator won't be finite.
14541464
if (typeinf_lattice(interp), nstatetype, statetype)
1455-
return Any[Bottom], nothing
1465+
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_THROWS)
14561466
end
14571467
valtype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(1))
14581468
push!(ret, valtype)
@@ -1482,7 +1492,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
14821492
# ... but cannot terminate
14831493
if !may_have_terminated
14841494
# ... and cannot have terminated prior to this loop
1485-
return Any[Bottom], nothing
1495+
return AbstractIterationResult(Any[Bottom], AbstractIterationInfo(calls, false), EFFECTS_UNKNOWN′)
14861496
else
14871497
# iterator may have terminated prior to this loop, but not during it
14881498
valtype = Bottom
@@ -1492,13 +1502,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
14921502
end
14931503
valtype = tmerge(valtype, nounion.parameters[1])
14941504
statetype = tmerge(statetype, nounion.parameters[2])
1495-
stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv).rt
1505+
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv)
1506+
push!(calls, call)
1507+
stateordonet = call.rt
14961508
stateordonet_widened = widenconst(stateordonet)
14971509
end
14981510
if valtype !== Union{}
14991511
push!(ret, Vararg{valtype})
15001512
end
1501-
return ret, nothing
1513+
return AbstractIterationResult(ret, AbstractIterationInfo(calls, false))
15021514
end
15031515

15041516
# do apply(af, fargs...), where af is a function value
@@ -1529,13 +1541,9 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
15291541
infos′ = Vector{MaybeAbstractIterationInfo}[]
15301542
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
15311543
if !isvarargtype(ti)
1532-
cti_info = precise_container_type(interp, itft, ti, sv)
1533-
cti = cti_info[1]::Vector{Any}
1534-
info = cti_info[2]::MaybeAbstractIterationInfo
1544+
(;cti, info, ai_effects) = precise_container_type(interp, itft, ti, sv)
15351545
else
1536-
cti_info = precise_container_type(interp, itft, unwrapva(ti), sv)
1537-
cti = cti_info[1]::Vector{Any}
1538-
info = cti_info[2]::MaybeAbstractIterationInfo
1546+
(;cti, info, ai_effects) = precise_container_type(interp, itft, unwrapva(ti), sv)
15391547
# We can't represent a repeating sequence of the same types,
15401548
# so tmerge everything together to get one type that represents
15411549
# everything.
@@ -1548,6 +1556,12 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
15481556
end
15491557
cti = Any[Vararg{argt}]
15501558
end
1559+
effects = merge_effects(effects, ai_effects)
1560+
if info !== nothing
1561+
for call in info.each
1562+
effects = merge_effects(effects, call.effects)
1563+
end
1564+
end
15511565
if any(@nospecialize(t) -> t === Bottom, cti)
15521566
continue
15531567
end

base/compiler/ssair/inlining.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ function rewrite_apply_exprargs!(todo::Vector{Pair{Int,Any}},
729729
def = argexprs[i]
730730
def_type = argtypes[i]
731731
thisarginfo = arginfos[i-arg_start]
732-
if thisarginfo === nothing
732+
if thisarginfo === nothing || !thisarginfo.complete
733733
if def_type isa PartialStruct
734734
# def_type.typ <: Tuple is assumed
735735
def_argtypes = def_type.fields
@@ -1134,9 +1134,9 @@ function inline_apply!(todo::Vector{Pair{Int,Any}},
11341134
for i = (arg_start + 1):length(argtypes)
11351135
thisarginfo = nothing
11361136
if !is_valid_type_for_apply_rewrite(argtypes[i], state.params)
1137-
if isa(info, ApplyCallInfo) && info.arginfo[i-arg_start] !== nothing
1138-
thisarginfo = info.arginfo[i-arg_start]
1139-
else
1137+
isa(info, ApplyCallInfo) || return nothing
1138+
thisarginfo = info.arginfo[i-arg_start]
1139+
if thisarginfo === nothing || !thisarginfo.complete
11401140
return nothing
11411141
end
11421142
end

base/compiler/ssair/irinterp.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,17 @@ function kill_def_use!(tpdum::TwoPhaseDefUseMap, def::Int, use::Int)
6666
if !tpdum.complete
6767
tpdum.ssa_uses[def] -= 1
6868
else
69-
@assert false && "TODO"
69+
range = tpdum.ssa_uses[def]:(def == length(tpdum.ssa_uses) ? length(tpdum.data) : (tpdum.ssa_uses[def + 1] - 1))
70+
# TODO: Sorted
71+
useidx = findfirst(idx->tpdum.data[idx] == use, range)
72+
@assert useidx !== nothing
73+
idx = range[useidx]
74+
while idx < lastindex(range)
75+
ndata = tpdum.data[idx+1]
76+
ndata == 0 && break
77+
tpdum.data[idx] = ndata
78+
end
79+
tpdum.data[idx + 1] = 0
7080
end
7181
end
7282
kill_def_use!(tpdum::TwoPhaseDefUseMap, def::SSAValue, use::Int) =
@@ -262,11 +272,11 @@ function process_terminator!(ir::IRCode, idx::Int, bb::Int,
262272
end
263273
return false
264274
elseif isa(inst, GotoNode)
265-
backedge = inst.label < bb
275+
backedge = inst.label <= bb
266276
!backedge && push!(ip, inst.label)
267277
return backedge
268278
elseif isa(inst, GotoIfNot)
269-
backedge = inst.dest < bb
279+
backedge = inst.dest <= bb
270280
!backedge && push!(ip, inst.dest)
271281
push!(ip, bb + 1)
272282
return backedge

base/compiler/stmtinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ Each (abstract) call to `iterate`, corresponds to one entry in `ainfo.each::Vect
114114
"""
115115
struct AbstractIterationInfo
116116
each::Vector{CallMeta}
117+
complete::Bool
117118
end
118119

119120
const MaybeAbstractIterationInfo = Union{Nothing, AbstractIterationInfo}

test/compiler/inference.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4633,3 +4633,19 @@ end |> only === Type{Float64}
46334633
# Issue #46839: `abstract_invoke` should handle incorrect call type
46344634
@test only(Base.return_types(()->invoke(BitSet, Any, x), ())) === Union{}
46354635
@test only(Base.return_types(()->invoke(BitSet, Union{Tuple{Int32},Tuple{Int64}}, 1), ())) === Union{}
4636+
4637+
# Issue #47688: Abstract iteration should take into account `iterate` effects
4638+
global it_count47688 = 0
4639+
struct CountsIterate47688{N}; end
4640+
function Base.iterate(::CountsIterate47688{N}, n=0) where N
4641+
global it_count47688 += 1
4642+
n <= N ? (n, n+1) : nothing
4643+
end
4644+
foo47688() = tuple(CountsIterate47688{5}()...)
4645+
bar47688() = foo47688()
4646+
@test only(Base.return_types(bar47688)) == NTuple{6, Int}
4647+
@test it_count47688 == 0
4648+
@test isa(bar47688(), NTuple{6, Int})
4649+
@test it_count47688 == 7
4650+
@test isa(foo47688(), NTuple{6, Int})
4651+
@test it_count47688 == 14

0 commit comments

Comments
 (0)