Skip to content

Commit db23898

Browse files
Kenoaviatesk
authored andcommitted
inference: improve :nothrow modeling for :static_parameter (#46820)
1 parent e360d72 commit db23898

File tree

8 files changed

+121
-66
lines changed

8 files changed

+121
-66
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2189,12 +2189,16 @@ function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes::
21892189
head = e.head
21902190
if head === :static_parameter
21912191
n = e.args[1]::Int
2192+
nothrow = false
21922193
if 1 <= n <= length(sv.sptypes)
21932194
rt = sv.sptypes[n]
2195+
if is_maybeundefsp(rt)
2196+
rt = unwrap_maybeundefsp(rt)
2197+
else
2198+
nothrow = true
2199+
end
21942200
end
2195-
if !isa(rt, Const)
2196-
merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; nothrow=false))
2197-
end
2201+
merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; nothrow))
21982202
return rt
21992203
elseif head === :boundscheck
22002204
if isa(sv, InferenceState)
@@ -2456,8 +2460,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
24562460
elseif isexpr(sym, :static_parameter)
24572461
n = sym.args[1]::Int
24582462
if 1 <= n <= length(sv.sptypes)
2459-
spty = sv.sptypes[n]
2460-
if isa(spty, Const)
2463+
if !is_maybeundefsp(sv.sptypes, n)
24612464
t = Const(true)
24622465
end
24632466
end

base/compiler/inferencestate.jl

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -348,23 +348,100 @@ function InferenceState(result::InferenceResult, cache::Symbol, interp::Abstract
348348
return InferenceState(result, src, cache, interp)
349349
end
350350

351+
"""
352+
constrains_param(var::TypeVar, sig, covariant::Bool)
353+
354+
Check if `var` will be constrained to have a definite value
355+
in any concrete leaftype subtype of `sig`.
356+
357+
It is used as a helper to determine whether type intersection is guaranteed to be able to
358+
find a value for a particular type parameter.
359+
A necessary condition for type intersection to not assign a parameter is that it only
360+
appears in a `Union[All]` and during subtyping some other union component (that does not
361+
constrain the type parameter) is selected.
362+
"""
363+
function constrains_param(var::TypeVar, @nospecialize(typ), covariant::Bool)
364+
typ === var && return true
365+
while typ isa UnionAll
366+
covariant && constrains_param(var, typ.var.ub, covariant) && return true
367+
# typ.var.lb doesn't constrain var
368+
typ = typ.body
369+
end
370+
if typ isa Union
371+
# for unions, verify that both options would constrain var
372+
ba = constrains_param(var, typ.a, covariant)
373+
bb = constrains_param(var, typ.b, covariant)
374+
(ba && bb) && return true
375+
elseif typ isa DataType
376+
# return true if any param constrains var
377+
fc = length(typ.parameters)
378+
if fc > 0
379+
if typ.name === Tuple.name
380+
# vararg tuple needs special handling
381+
for i in 1:(fc - 1)
382+
p = typ.parameters[i]
383+
constrains_param(var, p, covariant) && return true
384+
end
385+
lastp = typ.parameters[fc]
386+
vararg = unwrap_unionall(lastp)
387+
if vararg isa Core.TypeofVararg && isdefined(vararg, :N)
388+
constrains_param(var, vararg.N, covariant) && return true
389+
# T = vararg.parameters[1] doesn't constrain var
390+
else
391+
constrains_param(var, lastp, covariant) && return true
392+
end
393+
else
394+
for i in 1:fc
395+
p = typ.parameters[i]
396+
constrains_param(var, p, false) && return true
397+
end
398+
end
399+
end
400+
end
401+
return false
402+
end
403+
404+
"""
405+
MaybeUndefSP(typ)
406+
is_maybeundefsp(typ) -> Bool
407+
unwrap_maybeundefsp(typ) -> Any
408+
409+
A special wrapper that represents a static parameter that could be undefined at runtime.
410+
This does not participate in the native type system nor the inference lattice,
411+
and it thus should be always unwrapped when performing any type or lattice operations on it.
412+
"""
413+
struct MaybeUndefSP
414+
typ
415+
MaybeUndefSP(@nospecialize typ) = new(typ)
416+
end
417+
is_maybeundefsp(@nospecialize typ) = isa(typ, MaybeUndefSP)
418+
unwrap_maybeundefsp(@nospecialize typ) = isa(typ, MaybeUndefSP) ? typ.typ : typ
419+
is_maybeundefsp(sptypes::Vector{Any}, idx::Int) = is_maybeundefsp(sptypes[idx])
420+
unwrap_maybeundefsp(sptypes::Vector{Any}, idx::Int) = unwrap_maybeundefsp(sptypes[idx])
421+
422+
const EMPTY_SPTYPES = Any[]
423+
351424
function sptypes_from_meth_instance(linfo::MethodInstance)
352-
toplevel = !isa(linfo.def, Method)
353-
if !toplevel && isempty(linfo.sparam_vals) && isa(linfo.def.sig, UnionAll)
425+
def = linfo.def
426+
isa(def, Method) || return EMPTY_SPTYPES # toplevel
427+
sig = def.sig
428+
if isempty(linfo.sparam_vals)
429+
isa(sig, UnionAll) || return EMPTY_SPTYPES
354430
# linfo is unspecialized
355431
sp = Any[]
356-
sig = linfo.def.sig
357-
while isa(sig, UnionAll)
358-
push!(sp, sig.var)
359-
sig = sig.body
432+
sig = sig
433+
while isa(sig, UnionAll)
434+
push!(sp, sig.var)
435+
sig = sig.body
360436
end
361437
else
362438
sp = collect(Any, linfo.sparam_vals)
363439
end
364440
for i = 1:length(sp)
365441
v = sp[i]
366442
if v isa TypeVar
367-
temp = linfo.def.sig
443+
maybe_undef = !constrains_param(v, linfo.specTypes, true)
444+
temp = sig
368445
for j = 1:i-1
369446
temp = temp.body
370447
end
@@ -402,12 +479,13 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
402479
tv = TypeVar(v.name, lb, ub)
403480
ty = UnionAll(tv, Type{tv})
404481
end
482+
@label ty_computed
483+
maybe_undef && (ty = MaybeUndefSP(ty))
405484
elseif isvarargtype(v)
406485
ty = Int
407486
else
408487
ty = Const(v)
409488
end
410-
@label ty_computed
411489
sp[i] = ty
412490
end
413491
return sp

base/compiler/optimize.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,9 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
267267
if isa(stmt, Expr)
268268
(; head, args) = stmt
269269
if head === :static_parameter
270-
etyp = (isa(src, IRCode) ? src.sptypes : src.ir.sptypes)[args[1]::Int]
271270
# if we aren't certain enough about the type, it might be an UndefVarError at runtime
272-
nothrow = isa(etyp, Const)
271+
sptypes = isa(src, IRCode) ? src.sptypes : src.ir.sptypes
272+
nothrow = !is_maybeundefsp(sptypes, args[1]::Int)
273273
return (true, nothrow, nothrow)
274274
end
275275
if head === :call
@@ -377,7 +377,7 @@ function argextype(
377377
sptypes::Vector{Any}, slottypes::Vector{Any})
378378
if isa(x, Expr)
379379
if x.head === :static_parameter
380-
return sptypes[x.args[1]::Int]
380+
return unwrap_maybeundefsp(sptypes, x.args[1]::Int)
381381
elseif x.head === :boundscheck
382382
return Bool
383383
elseif x.head === :copyast

base/compiler/ssair/slot2ssa.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ end
216216
function typ_for_val(@nospecialize(x), ci::CodeInfo, sptypes::Vector{Any}, idx::Int, slottypes::Vector{Any})
217217
if isa(x, Expr)
218218
if x.head === :static_parameter
219-
return sptypes[x.args[1]::Int]
219+
return unwrap_maybeundefsp(sptypes, x.args[1]::Int)
220220
elseif x.head === :boundscheck
221221
return Bool
222222
elseif x.head === :copyast

base/compiler/ssair/verify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ function verify_ir(ir::IRCode, print::Bool=true,
268268
elseif stmt.head === :foreigncall
269269
isforeigncall = true
270270
elseif stmt.head === :isdefined && length(stmt.args) == 1 &&
271-
(stmt.args[1] isa GlobalRef || (stmt.args[1] isa Expr && stmt.args[1].head === :static_parameter))
271+
(stmt.args[1] isa GlobalRef || isexpr(stmt.args[1], :static_parameter))
272272
# a GlobalRef or static_parameter isdefined check does not evaluate its argument
273273
continue
274274
elseif stmt.head === :call

stdlib/Test/src/Test.jl

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,54 +1970,11 @@ function detect_unbound_args(mods...;
19701970
return collect(ambs)
19711971
end
19721972

1973-
# find if var will be constrained to have a definite value
1974-
# in any concrete leaftype subtype of typ
1975-
function constrains_param(var::TypeVar, @nospecialize(typ), covariant::Bool)
1976-
typ === var && return true
1977-
while typ isa UnionAll
1978-
covariant && constrains_param(var, typ.var.ub, covariant) && return true
1979-
# typ.var.lb doesn't constrain var
1980-
typ = typ.body
1981-
end
1982-
if typ isa Union
1983-
# for unions, verify that both options would constrain var
1984-
ba = constrains_param(var, typ.a, covariant)
1985-
bb = constrains_param(var, typ.b, covariant)
1986-
(ba && bb) && return true
1987-
elseif typ isa DataType
1988-
# return true if any param constrains var
1989-
fc = length(typ.parameters)
1990-
if fc > 0
1991-
if typ.name === Tuple.name
1992-
# vararg tuple needs special handling
1993-
for i in 1:(fc - 1)
1994-
p = typ.parameters[i]
1995-
constrains_param(var, p, covariant) && return true
1996-
end
1997-
lastp = typ.parameters[fc]
1998-
vararg = Base.unwrap_unionall(lastp)
1999-
if vararg isa Core.TypeofVararg && isdefined(vararg, :N)
2000-
constrains_param(var, vararg.N, covariant) && return true
2001-
# T = vararg.parameters[1] doesn't constrain var
2002-
else
2003-
constrains_param(var, lastp, covariant) && return true
2004-
end
2005-
else
2006-
for i in 1:fc
2007-
p = typ.parameters[i]
2008-
constrains_param(var, p, false) && return true
2009-
end
2010-
end
2011-
end
2012-
end
2013-
return false
2014-
end
2015-
20161973
function has_unbound_vars(@nospecialize sig)
20171974
while sig isa UnionAll
20181975
var = sig.var
20191976
sig = sig.body
2020-
if !constrains_param(var, sig, true)
1977+
if !Core.Compiler.constrains_param(var, sig, true)
20211978
return true
20221979
end
20231980
end

test/compiler/effects.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -775,9 +775,13 @@ gotoifnot_throw_check_48583(x) = x ? x : 0
775775
# unknown :static_parameter should taint :nothrow
776776
# https://github.com/JuliaLang/julia/issues/46771
777777
unknown_sparam_throw(::Union{Nothing, Type{T}}) where T = (T; nothing)
778+
unknown_sparam_nothrow1(x::Ref{T}) where T = (T; nothing)
779+
unknown_sparam_nothrow2(x::Ref{Ref{T}}) where T = (T; nothing)
778780
@test Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_throw, (Type{Int},)))
781+
@test Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_throw, (Type{<:Integer},)))
782+
@test !Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_throw, (Type,)))
779783
@test !Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_throw, (Nothing,)))
780-
781-
unknown_sparam_nothrow(x::Ref{T}) where {T} = (T; nothing)
782-
@test_broken Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_nothrow, (Ref,)))
783-
784+
@test !Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_throw, (Union{Type{Int},Nothing},)))
785+
@test !Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_throw, (Any,)))
786+
@test Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_nothrow1, (Ref,)))
787+
@test Core.Compiler.is_nothrow(Base.infer_effects(unknown_sparam_nothrow2, (Ref{Ref{T}} where T,)))

test/compiler/inference.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4761,3 +4761,16 @@ g_no_bail_effects_any(x::Any) = f_no_bail_effects_any(x)
47614761

47624762
# issue #48374
47634763
@test (() -> Union{<:Nothing})() == Nothing
4764+
4765+
# :static_parameter accuracy
4766+
unknown_sparam_throw(::Union{Nothing, Type{T}}) where T = @isdefined(T) ? T::Type : nothing
4767+
unknown_sparam_nothrow1(x::Ref{T}) where T = @isdefined(T) ? T::Type : nothing
4768+
unknown_sparam_nothrow2(x::Ref{Ref{T}}) where T = @isdefined(T) ? T::Type : nothing
4769+
@test only(Base.return_types(unknown_sparam_throw, (Type{Int},))) == Type{Int}
4770+
@test only(Base.return_types(unknown_sparam_throw, (Type{<:Integer},))) == Type{<:Integer}
4771+
@test only(Base.return_types(unknown_sparam_throw, (Type,))) == Type
4772+
@test_broken only(Base.return_types(unknown_sparam_throw, (Nothing,))) === Nothing
4773+
@test_broken only(Base.return_types(unknown_sparam_throw, (Union{Type{Int},Nothing},))) === Union{Nothing,Type{Int}}
4774+
@test only(Base.return_types(unknown_sparam_throw, (Any,))) === Union{Nothing,Type}
4775+
@test only(Base.return_types(unknown_sparam_nothrow1, (Ref,))) === Type
4776+
@test only(Base.return_types(unknown_sparam_nothrow2, (Ref{Ref{T}} where T,))) === Type

0 commit comments

Comments
 (0)