Skip to content

WIP: we can now make use of __class_getitem__ for non-generic types #11558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2932,6 +2932,9 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
# Visit the index, just to make sure we have a type for it available
self.accept(index)

# Base method name for instance-based item access, might be changed for types:
method_name = '__getitem__'

if isinstance(left_type, UnionType):
original_type = original_type or left_type
# Don't combine literal types, since we may need them for type narrowing.
Expand Down Expand Up @@ -2967,12 +2970,22 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
elif (isinstance(left_type, TypeVarType)
and not self.has_member(left_type.upper_bound, "__getitem__")):
return self.visit_index_with_type(left_type.upper_bound, e, original_type)
else:
result, method_type = self.check_method_call_by_name(
'__getitem__', left_type, [e.index], [ARG_POS], e,
original_type=original_type)
e.method_type = method_type
return result
elif (isinstance(left_type, TypeType)
or (isinstance(left_type, CallableType) and left_type.is_type_obj())):
# When we do `SomeClass[1]`, we actually call `__class_getitem__`,
# not just `__getitem__`.
method_name = '__class_getitem__'

result, method_type = self.check_method_call_by_name(
method=method_name,
base_type=left_type,
args=[e.index],
arg_kinds=[ARG_POS],
context=e,
original_type=original_type,
)
e.method_type = method_type
return result

def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type:
begin: Sequence[Optional[int]] = [None]
Expand Down
5 changes: 4 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def analyze_type_callable_member_access(name: str,
if isinstance(ret_type, TupleType):
ret_type = tuple_fallback(ret_type)
if isinstance(ret_type, Instance):
if not mx.is_operator:
if not mx.is_operator or name == '__class_getitem__':
# When Python sees an operator (eg `3 == 4`), it automatically translates that
# into something like `int.__eq__(3, 4)` instead of `(3).__eq__(4)` as an
# optimization.
Expand All @@ -250,6 +250,9 @@ def analyze_type_callable_member_access(name: str,
# the corresponding method in the current instance to avoid this edge case.
# See https://github.com/python/mypy/pull/1787 for more info.
# TODO: do not rely on same type variables being present in all constructor overloads.

# We also allow `SomeClass[1]` acccess via `__class_getitem__`,
# it is very special. It only works this way for non-generic types.
result = analyze_class_attribute_access(ret_type, name, mx,
original_vars=typ.items[0].variables)
if result:
Expand Down
7 changes: 7 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,12 @@ def fix_instance(t: Instance, fail: MsgCallback, note: MsgCallback,

Also emit a suitable error if this is not due to implicit Any's.
"""
if not t.type.is_generic() and t.type.has_readable_member('__class_getitem__'):
# Corner case: we might have a non-generic class with `__class_getitem__`
# which is used for something else: not type application.
# So, in this case: we allow using this type without type arguments.
return

if len(t.args) == 0:
if use_generic_error:
fullname: Optional[str] = None
Expand All @@ -1081,6 +1087,7 @@ def fix_instance(t: Instance, fail: MsgCallback, note: MsgCallback,
unexpanded_type)
t.args = (any_type,) * len(t.type.type_vars)
return

# Invalid number of type parameters.
n = len(t.type.type_vars)
s = '{} type arguments'.format(n)
Expand Down