Skip to content

Commit 56ed5c3

Browse files
tavianatorJukkaL
authored andcommitted
Analyze super() expressions more precisely (#7232)
The previous implementation assumed that all super(T, x) expressions were for instances of the current class. This new implementation uses the precise types of T and x to do a more accurate analysis. Fixes #5794.
1 parent 445020f commit 56ed5c3

File tree

6 files changed

+216
-112
lines changed

6 files changed

+216
-112
lines changed

mypy/checkexpr.py

Lines changed: 143 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,112 +3108,141 @@ def infer_lambda_type_using_context(self, e: LambdaExpr) -> Tuple[Optional[Calla
31083108

31093109
def visit_super_expr(self, e: SuperExpr) -> Type:
31103110
"""Type check a super expression (non-lvalue)."""
3111-
self.check_super_arguments(e)
3112-
t = self.analyze_super(e, False)
3113-
return t
31143111

3115-
def check_super_arguments(self, e: SuperExpr) -> None:
3116-
"""Check arguments in a super(...) call."""
3117-
if ARG_STAR in e.call.arg_kinds:
3112+
# We have an expression like super(T, var).member
3113+
3114+
# First compute the types of T and var
3115+
types = self._super_arg_types(e)
3116+
if isinstance(types, tuple):
3117+
type_type, instance_type = types
3118+
else:
3119+
return types
3120+
3121+
# Now get the MRO
3122+
type_info = type_info_from_type(type_type)
3123+
if type_info is None:
3124+
self.chk.fail(message_registry.UNSUPPORTED_ARG_1_FOR_SUPER, e)
3125+
return AnyType(TypeOfAny.from_error)
3126+
3127+
instance_info = type_info_from_type(instance_type)
3128+
if instance_info is None:
3129+
self.chk.fail(message_registry.UNSUPPORTED_ARG_2_FOR_SUPER, e)
3130+
return AnyType(TypeOfAny.from_error)
3131+
3132+
mro = instance_info.mro
3133+
3134+
# The base is the first MRO entry *after* type_info that has a member
3135+
# with the right name
3136+
try:
3137+
index = mro.index(type_info)
3138+
except ValueError:
3139+
self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e)
3140+
return AnyType(TypeOfAny.from_error)
3141+
3142+
for base in mro[index+1:]:
3143+
if e.name in base.names or base == mro[-1]:
3144+
if e.info and e.info.fallback_to_any and base == mro[-1]:
3145+
# There's an undefined base class, and we're at the end of the
3146+
# chain. That's not an error.
3147+
return AnyType(TypeOfAny.special_form)
3148+
3149+
return analyze_member_access(name=e.name,
3150+
typ=instance_type,
3151+
is_lvalue=False,
3152+
is_super=True,
3153+
is_operator=False,
3154+
original_type=instance_type,
3155+
override_info=base,
3156+
context=e,
3157+
msg=self.msg,
3158+
chk=self.chk,
3159+
in_literal_context=self.is_literal_context())
3160+
3161+
assert False, 'unreachable'
3162+
3163+
def _super_arg_types(self, e: SuperExpr) -> Union[Type, Tuple[Type, Type]]:
3164+
"""
3165+
Computes the types of the type and instance expressions in super(T, instance), or the
3166+
implicit ones for zero-argument super() expressions. Returns a single type for the whole
3167+
super expression when possible (for errors, anys), otherwise the pair of computed types.
3168+
"""
3169+
3170+
if not self.chk.in_checked_function():
3171+
return AnyType(TypeOfAny.unannotated)
3172+
elif len(e.call.args) == 0:
3173+
if self.chk.options.python_version[0] == 2:
3174+
self.chk.fail(message_registry.TOO_FEW_ARGS_FOR_SUPER, e)
3175+
return AnyType(TypeOfAny.from_error)
3176+
elif not e.info:
3177+
# This has already been reported by the semantic analyzer.
3178+
return AnyType(TypeOfAny.from_error)
3179+
elif self.chk.scope.active_class():
3180+
self.chk.fail(message_registry.SUPER_OUTSIDE_OF_METHOD_NOT_SUPPORTED, e)
3181+
return AnyType(TypeOfAny.from_error)
3182+
3183+
# Zero-argument super() is like super(<current class>, <self>)
3184+
current_type = fill_typevars(e.info)
3185+
type_type = TypeType(current_type) # type: Type
3186+
3187+
# Use the type of the self argument, in case it was annotated
3188+
method = self.chk.scope.top_function()
3189+
assert method is not None
3190+
if method.arguments:
3191+
instance_type = method.arguments[0].variable.type or current_type # type: Type
3192+
else:
3193+
self.chk.fail(message_registry.SUPER_ENCLOSING_POSITIONAL_ARGS_REQUIRED, e)
3194+
return AnyType(TypeOfAny.from_error)
3195+
elif ARG_STAR in e.call.arg_kinds:
31183196
self.chk.fail(message_registry.SUPER_VARARGS_NOT_SUPPORTED, e)
3119-
elif e.call.args and set(e.call.arg_kinds) != {ARG_POS}:
3197+
return AnyType(TypeOfAny.from_error)
3198+
elif set(e.call.arg_kinds) != {ARG_POS}:
31203199
self.chk.fail(message_registry.SUPER_POSITIONAL_ARGS_REQUIRED, e)
3200+
return AnyType(TypeOfAny.from_error)
31213201
elif len(e.call.args) == 1:
31223202
self.chk.fail(message_registry.SUPER_WITH_SINGLE_ARG_NOT_SUPPORTED, e)
3123-
elif len(e.call.args) > 2:
3124-
self.chk.fail(message_registry.TOO_MANY_ARGS_FOR_SUPER, e)
3125-
elif self.chk.options.python_version[0] == 2 and len(e.call.args) == 0:
3126-
self.chk.fail(message_registry.TOO_FEW_ARGS_FOR_SUPER, e)
3203+
return AnyType(TypeOfAny.from_error)
31273204
elif len(e.call.args) == 2:
3128-
type_obj_type = self.accept(e.call.args[0])
3205+
type_type = self.accept(e.call.args[0])
31293206
instance_type = self.accept(e.call.args[1])
3130-
if isinstance(type_obj_type, FunctionLike) and type_obj_type.is_type_obj():
3131-
type_info = type_obj_type.type_object()
3132-
elif isinstance(type_obj_type, TypeType):
3133-
item = type_obj_type.item
3134-
if isinstance(item, AnyType):
3135-
# Could be anything.
3136-
return
3137-
if isinstance(item, TupleType):
3138-
# Handle named tuples and other Tuple[...] subclasses.
3139-
item = tuple_fallback(item)
3140-
if not isinstance(item, Instance):
3141-
# A complicated type object type. Too tricky, give up.
3142-
# TODO: Do something more clever here.
3143-
self.chk.fail(message_registry.UNSUPPORTED_ARG_1_FOR_SUPER, e)
3144-
return
3145-
type_info = item.type
3146-
elif isinstance(type_obj_type, AnyType):
3147-
return
3207+
else:
3208+
self.chk.fail(message_registry.TOO_MANY_ARGS_FOR_SUPER, e)
3209+
return AnyType(TypeOfAny.from_error)
3210+
3211+
# Imprecisely assume that the type is the current class
3212+
if isinstance(type_type, AnyType):
3213+
if e.info:
3214+
type_type = TypeType(fill_typevars(e.info))
31483215
else:
3149-
self.msg.first_argument_for_super_must_be_type(type_obj_type, e)
3150-
return
3216+
return AnyType(TypeOfAny.from_another_any, source_any=type_type)
3217+
elif isinstance(type_type, TypeType):
3218+
type_item = type_type.item
3219+
if isinstance(type_item, AnyType):
3220+
if e.info:
3221+
type_type = TypeType(fill_typevars(e.info))
3222+
else:
3223+
return AnyType(TypeOfAny.from_another_any, source_any=type_item)
31513224

3152-
if isinstance(instance_type, (Instance, TupleType, TypeVarType)):
3153-
if isinstance(instance_type, TypeVarType):
3154-
# Needed for generic self.
3155-
instance_type = instance_type.upper_bound
3156-
if not isinstance(instance_type, (Instance, TupleType)):
3157-
# Too tricky, give up.
3158-
# TODO: Do something more clever here.
3159-
self.chk.fail(message_registry.UNSUPPORTED_ARG_2_FOR_SUPER, e)
3160-
return
3161-
if isinstance(instance_type, TupleType):
3162-
# Needed for named tuples and other Tuple[...] subclasses.
3163-
instance_type = tuple_fallback(instance_type)
3164-
if type_info not in instance_type.type.mro:
3165-
self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e)
3166-
elif isinstance(instance_type, TypeType) or (isinstance(instance_type, FunctionLike)
3167-
and instance_type.is_type_obj()):
3168-
# TODO: Check whether this is a valid type object here.
3169-
pass
3170-
elif not isinstance(instance_type, AnyType):
3171-
self.chk.fail(message_registry.UNSUPPORTED_ARG_2_FOR_SUPER, e)
3172-
3173-
def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
3174-
"""Type check a super expression."""
3175-
if e.info and e.info.bases:
3176-
# TODO fix multiple inheritance etc
3177-
if len(e.info.mro) < 2:
3178-
self.chk.fail('Internal error: unexpected mro for {}: {}'.format(
3179-
e.info.name(), e.info.mro), e)
3180-
return AnyType(TypeOfAny.from_error)
3181-
for base in e.info.mro[1:]:
3182-
if e.name in base.names or base == e.info.mro[-1]:
3183-
if e.info.fallback_to_any and base == e.info.mro[-1]:
3184-
# There's an undefined base class, and we're
3185-
# at the end of the chain. That's not an error.
3186-
return AnyType(TypeOfAny.special_form)
3187-
if not self.chk.in_checked_function():
3188-
return AnyType(TypeOfAny.unannotated)
3189-
if self.chk.scope.active_class() is not None:
3190-
self.chk.fail(message_registry.SUPER_OUTSIDE_OF_METHOD_NOT_SUPPORTED, e)
3191-
return AnyType(TypeOfAny.from_error)
3192-
method = self.chk.scope.top_function()
3193-
assert method is not None
3194-
args = method.arguments
3195-
# super() in a function with empty args is an error; we
3196-
# need something in declared_self.
3197-
if not args:
3198-
self.chk.fail(message_registry.SUPER_ENCLOSING_POSITIONAL_ARGS_REQUIRED, e)
3199-
return AnyType(TypeOfAny.from_error)
3200-
declared_self = args[0].variable.type or fill_typevars(e.info)
3201-
return analyze_member_access(name=e.name,
3202-
typ=fill_typevars(e.info),
3203-
is_lvalue=False,
3204-
is_super=True,
3205-
is_operator=False,
3206-
original_type=declared_self,
3207-
override_info=base,
3208-
context=e,
3209-
msg=self.msg,
3210-
chk=self.chk,
3211-
in_literal_context=self.is_literal_context())
3212-
assert False, 'unreachable'
3213-
else:
3214-
# Invalid super. This has been reported by the semantic analyzer.
3225+
if (not isinstance(type_type, TypeType)
3226+
and not (isinstance(type_type, FunctionLike) and type_type.is_type_obj())):
3227+
self.msg.first_argument_for_super_must_be_type(type_type, e)
32153228
return AnyType(TypeOfAny.from_error)
32163229

3230+
# Imprecisely assume that the instance is of the current class
3231+
if isinstance(instance_type, AnyType):
3232+
if e.info:
3233+
instance_type = fill_typevars(e.info)
3234+
else:
3235+
return AnyType(TypeOfAny.from_another_any, source_any=instance_type)
3236+
elif isinstance(instance_type, TypeType):
3237+
instance_item = instance_type.item
3238+
if isinstance(instance_item, AnyType):
3239+
if e.info:
3240+
instance_type = TypeType(fill_typevars(e.info))
3241+
else:
3242+
return AnyType(TypeOfAny.from_another_any, source_any=instance_item)
3243+
3244+
return type_type, instance_type
3245+
32173246
def visit_slice_expr(self, e: SliceExpr) -> Type:
32183247
expected = make_optional_type(self.named_type('builtins.int'))
32193248
for index in [e.begin_index, e.end_index, e.stride]:
@@ -4001,3 +4030,22 @@ def has_bytes_component(typ: Type) -> bool:
40014030
if isinstance(typ, Instance) and typ.type.fullname() == 'builtins.bytes':
40024031
return True
40034032
return False
4033+
4034+
4035+
def type_info_from_type(typ: Type) -> Optional[TypeInfo]:
4036+
"""Gets the TypeInfo for a type, indirecting through things like type variables and tuples."""
4037+
4038+
if isinstance(typ, FunctionLike) and typ.is_type_obj():
4039+
return typ.type_object()
4040+
if isinstance(typ, TypeType):
4041+
typ = typ.item
4042+
if isinstance(typ, TypeVarType):
4043+
typ = typ.upper_bound
4044+
if isinstance(typ, TupleType):
4045+
typ = tuple_fallback(typ)
4046+
if isinstance(typ, Instance):
4047+
return typ.type
4048+
4049+
# A complicated type. Too tricky, give up.
4050+
# TODO: Do something more clever here.
4051+
return None

mypy/checkmember.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,17 @@ def _analyze_member_access(name: str,
136136
elif isinstance(typ, FunctionLike) and typ.is_type_obj():
137137
return analyze_type_callable_member_access(name, typ, mx)
138138
elif isinstance(typ, TypeType):
139-
return analyze_type_type_member_access(name, typ, mx)
139+
return analyze_type_type_member_access(name, typ, mx, override_info)
140140
elif isinstance(typ, TupleType):
141141
# Actually look up from the fallback instance type.
142-
return _analyze_member_access(name, tuple_fallback(typ), mx)
142+
return _analyze_member_access(name, tuple_fallback(typ), mx, override_info)
143143
elif isinstance(typ, (TypedDictType, LiteralType, FunctionLike)):
144144
# Actually look up from the fallback instance type.
145-
return _analyze_member_access(name, typ.fallback, mx)
145+
return _analyze_member_access(name, typ.fallback, mx, override_info)
146146
elif isinstance(typ, NoneType):
147147
return analyze_none_member_access(name, typ, mx)
148148
elif isinstance(typ, TypeVarType):
149-
return _analyze_member_access(name, typ.upper_bound, mx)
149+
return _analyze_member_access(name, typ.upper_bound, mx, override_info)
150150
elif isinstance(typ, DeletedType):
151151
mx.msg.deleted_as_rvalue(typ, mx.context)
152152
return AnyType(TypeOfAny.from_error)
@@ -238,7 +238,10 @@ def analyze_type_callable_member_access(name: str,
238238
assert False, 'Unexpected type {}'.format(repr(ret_type))
239239

240240

241-
def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext) -> Type:
241+
def analyze_type_type_member_access(name: str,
242+
typ: TypeType,
243+
mx: MemberContext,
244+
override_info: Optional[TypeInfo]) -> Type:
242245
# Similar to analyze_type_callable_attribute_access.
243246
item = None
244247
fallback = mx.builtin_type('builtins.type')
@@ -248,7 +251,7 @@ def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext)
248251
item = typ.item
249252
elif isinstance(typ.item, AnyType):
250253
mx = mx.copy_modified(messages=ignore_messages)
251-
return _analyze_member_access(name, fallback, mx)
254+
return _analyze_member_access(name, fallback, mx, override_info)
252255
elif isinstance(typ.item, TypeVarType):
253256
if isinstance(typ.item.upper_bound, Instance):
254257
item = typ.item.upper_bound
@@ -262,7 +265,7 @@ def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext)
262265
item = typ.item.item.type.metaclass_type
263266
if item and not mx.is_operator:
264267
# See comment above for why operators are skipped
265-
result = analyze_class_attribute_access(item, name, mx)
268+
result = analyze_class_attribute_access(item, name, mx, override_info)
266269
if result:
267270
if not (isinstance(result, AnyType) and item.type.fallback_to_any):
268271
return result
@@ -271,7 +274,7 @@ def analyze_type_type_member_access(name: str, typ: TypeType, mx: MemberContext)
271274
mx = mx.copy_modified(messages=ignore_messages)
272275
if item is not None:
273276
fallback = item.type.metaclass_type or fallback
274-
return _analyze_member_access(name, fallback, mx)
277+
return _analyze_member_access(name, fallback, mx, override_info)
275278

276279

277280
def analyze_union_member_access(name: str, typ: UnionType, mx: MemberContext) -> Type:
@@ -603,11 +606,16 @@ class A:
603606

604607
def analyze_class_attribute_access(itype: Instance,
605608
name: str,
606-
mx: MemberContext) -> Optional[Type]:
609+
mx: MemberContext,
610+
override_info: Optional[TypeInfo] = None) -> Optional[Type]:
607611
"""original_type is the type of E in the expression E.var"""
608-
node = itype.type.get(name)
612+
info = itype.type
613+
if override_info:
614+
info = override_info
615+
616+
node = info.get(name)
609617
if not node:
610-
if itype.type.fallback_to_any:
618+
if info.fallback_to_any:
611619
return AnyType(TypeOfAny.special_form)
612620
return None
613621

@@ -628,9 +636,9 @@ def analyze_class_attribute_access(itype: Instance,
628636
# An assignment to final attribute on class object is also always an error,
629637
# independently of types.
630638
if mx.is_lvalue and not mx.chk.get_final_context():
631-
check_final_member(name, itype.type, mx.msg, mx.context)
639+
check_final_member(name, info, mx.msg, mx.context)
632640

633-
if itype.type.is_enum and not (mx.is_lvalue or is_decorated or is_method):
641+
if info.is_enum and not (mx.is_lvalue or is_decorated or is_method):
634642
enum_literal = LiteralType(name, fallback=itype)
635643
return itype.copy_modified(last_known_value=enum_literal)
636644

@@ -691,7 +699,7 @@ def analyze_class_attribute_access(itype: Instance,
691699

692700
if isinstance(node.node, TypeVarExpr):
693701
mx.msg.fail(message_registry.CANNOT_USE_TYPEVAR_AS_EXPRESSION.format(
694-
itype.type.name(), name), mx.context)
702+
info.name(), name), mx.context)
695703
return AnyType(TypeOfAny.from_error)
696704

697705
if isinstance(node.node, TypeInfo):

mypy/newsemanal/semanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3327,7 +3327,7 @@ def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None:
33273327
expr.fullname = sym.fullname
33283328

33293329
def visit_super_expr(self, expr: SuperExpr) -> None:
3330-
if not self.type:
3330+
if not self.type and not expr.call.args:
33313331
self.fail('"super" used outside class', expr)
33323332
return
33333333
expr.info = self.type

mypy/semanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2956,7 +2956,7 @@ def visit_name_expr(self, expr: NameExpr) -> None:
29562956
expr.fullname = n.fullname
29572957

29582958
def visit_super_expr(self, expr: SuperExpr) -> None:
2959-
if not self.type:
2959+
if not self.type and not expr.call.args:
29602960
self.fail('"super" used outside class', expr)
29612961
return
29622962
expr.info = self.type

test-data/unit/check-class-namedtuple.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ Y(y=1, x='1').method()
493493

494494
class CallsBaseInit(X):
495495
def __init__(self, x: str) -> None:
496-
super().__init__(x)
496+
super().__init__(x) # E: Too many arguments for "__init__" of "object"
497497

498498
[case testNewNamedTupleWithMethods]
499499
from typing import NamedTuple

0 commit comments

Comments
 (0)