diff --git a/mypy/build.py b/mypy/build.py index ba2d1b1b3d35..c0b9aff5ab32 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -2217,7 +2217,10 @@ def type_checker(self) -> TypeChecker: return self._type_checker def type_map(self) -> Dict[Expression, Type]: - return self.type_checker().type_map + # We can extract the master type map directly since at this + # point no temporary type maps can be active. + assert len(self.type_checker()._type_maps) == 1 + return self.type_checker()._type_maps[0] def type_check_second_pass(self) -> bool: if self.options.semantic_analysis_only: diff --git a/mypy/checker.py b/mypy/checker.py index cff637dbb57a..109a3b1f15d2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -161,8 +161,15 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): errors: Errors # Utility for generating messages msg: MessageBuilder - # Types of type checked nodes - type_map: Dict[Expression, Type] + # Types of type checked nodes. The first item is the "master" type + # map that will store the final, exported types. Additional items + # are temporary type maps used during type inference, and these + # will be eventually popped and either discarded or merged into + # the master type map. + # + # Avoid accessing this directly, but prefer the lookup_type(), + # has_type() etc. helpers instead. + _type_maps: List[Dict[Expression, Type]] # Helper for managing conditional types binder: ConditionalTypeBinder @@ -246,7 +253,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option self.partial_reported = set() self.var_decl_frames = {} self.deferred_nodes = [] - self.type_map = {} + self._type_maps = [{}] self.module_refs = set() self.pass_num = 0 self.current_node_deferred = False @@ -283,7 +290,9 @@ def reset(self) -> None: self.partial_reported.clear() self.module_refs.clear() self.binder = ConditionalTypeBinder() - self.type_map.clear() + self._type_maps[1:] = [] + self._type_maps[0].clear() + self.temp_type_map = None assert self.inferred_attribute_types is None assert self.partial_types == [] @@ -2227,9 +2236,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: if len(s.lvalues) > 1: # Chained assignment (e.g. x = y = ...). # Make sure that rvalue type will not be reinferred. - if s.rvalue not in self.type_map: + if not self.has_type(s.rvalue): self.expr_checker.accept(s.rvalue) - rvalue = self.temp_node(self.type_map[s.rvalue], s) + rvalue = self.temp_node(self.lookup_type(s.rvalue), s) for lv in s.lvalues[:-1]: with self.enter_final_context(s.is_final_def): self.check_assignment(lv, rvalue, s.type is None) @@ -2935,7 +2944,9 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E infer_lvalue_type=infer_lvalue_type, rv_type=item, undefined_rvalue=True) for t, lv in zip(transposed, self.flatten_lvalues(lvalues)): - t.append(self.type_map.pop(lv, AnyType(TypeOfAny.special_form))) + # We can access _type_maps directly since temporary type maps are + # only created within expressions. + t.append(self._type_maps[0].pop(lv, AnyType(TypeOfAny.special_form))) union_types = tuple(make_simplified_union(col) for col in transposed) for expr, items in assignments.items(): # Bind a union of types collected in 'assignments' to every expression. @@ -3986,8 +3997,8 @@ def visit_decorator(self, e: Decorator) -> None: # if this is a expression like @b.a where b is an object, get the type of b # so we can pass it the method hook in the plugins object_type: Optional[Type] = None - if fullname is None and isinstance(d, MemberExpr) and d.expr in self.type_map: - object_type = self.type_map[d.expr] + if fullname is None and isinstance(d, MemberExpr) and self.has_type(d.expr): + object_type = self.lookup_type(d.expr) fullname = self.expr_checker.method_fullname(object_type, d.name) self.check_for_untyped_decorator(e.func, dec, d) sig, t2 = self.expr_checker.check_call(dec, [temp], @@ -4545,8 +4556,6 @@ def find_type_equals_check(self, node: ComparisonExpr, expr_indices: List[int] expr_indices: The list of indices of expressions in ``node`` that are being compared """ - type_map = self.type_map - def is_type_call(expr: CallExpr) -> bool: """Is expr a call to type with one argument?""" return (refers_to_fullname(expr.callee, 'builtins.type') @@ -4565,7 +4574,7 @@ def is_type_call(expr: CallExpr) -> bool: if isinstance(expr, CallExpr) and is_type_call(expr): exprs_in_type_calls.append(expr.args[0]) else: - current_type = get_isinstance_type(expr, type_map) + current_type = self.get_isinstance_type(expr) if current_type is None: continue if type_being_compared is not None: @@ -4588,7 +4597,7 @@ def is_type_call(expr: CallExpr) -> bool: else_maps: List[TypeMap] = [] for expr in exprs_in_type_calls: current_if_type, current_else_type = self.conditional_types_with_intersection( - type_map[expr], + self.lookup_type(expr), type_being_compared, expr ) @@ -4633,12 +4642,11 @@ def find_isinstance_check(self, node: Expression Can return None, None in situations involving NoReturn. """ if_map, else_map = self.find_isinstance_check_helper(node) - new_if_map = self.propagate_up_typemap_info(self.type_map, if_map) - new_else_map = self.propagate_up_typemap_info(self.type_map, else_map) + new_if_map = self.propagate_up_typemap_info(if_map) + new_else_map = self.propagate_up_typemap_info(else_map) return new_if_map, new_else_map def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]: - type_map = self.type_map if is_true_literal(node): return {}, None if is_false_literal(node): @@ -4653,8 +4661,8 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM return conditional_types_to_typemaps( expr, *self.conditional_types_with_intersection( - type_map[expr], - get_isinstance_type(node.args[1], type_map), + self.lookup_type(expr), + self.get_isinstance_type(node.args[1]), expr ) ) @@ -4662,12 +4670,12 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: - return self.infer_issubclass_maps(node, expr, type_map) + return self.infer_issubclass_maps(node, expr) elif refers_to_fullname(node.callee, 'builtins.callable'): if len(node.args) != 1: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: - vartype = type_map[expr] + vartype = self.lookup_type(expr) return self.conditional_callable_type_map(expr, vartype) elif isinstance(node.callee, RefExpr): if node.callee.type_guard is not None: @@ -4691,14 +4699,14 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM operand_types = [] narrowable_operand_index_to_hash = {} for i, expr in enumerate(operands): - if expr not in type_map: + if not self.has_type(expr): return {}, {} - expr_type = type_map[expr] + expr_type = self.lookup_type(expr) operand_types.append(expr_type) if (literal(expr) == LITERAL_TYPE and not is_literal_none(expr) - and not is_literal_enum(type_map, expr)): + and not self.is_literal_enum(expr)): h = literal_hash(expr) if h is not None: narrowable_operand_index_to_hash[i] = h @@ -4871,7 +4879,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively - original_vartype = type_map[node] + original_vartype = self.lookup_type(node) self._check_for_truthy_type(original_vartype, node) vartype = try_expanding_sum_type_to_union(original_vartype, "builtins.bool") @@ -4890,7 +4898,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: return if_map, else_map def propagate_up_typemap_info(self, - existing_types: Mapping[Expression, Type], new_types: TypeMap) -> TypeMap: """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types. @@ -4924,7 +4931,7 @@ def propagate_up_typemap_info(self, output_map[expr] = expr_type # Next, try using this information to refine the parent types, if applicable. - new_mapping = self.refine_parent_types(existing_types, expr, expr_type) + new_mapping = self.refine_parent_types(expr, expr_type) for parent_expr, proposed_parent_type in new_mapping.items(): # We don't try inferring anything if we've already inferred something for # the parent expression. @@ -4935,7 +4942,6 @@ def propagate_up_typemap_info(self, return output_map def refine_parent_types(self, - existing_types: Mapping[Expression, Type], expr: Expression, expr_type: Type) -> Mapping[Expression, Type]: """Checks if the given expr is a 'lookup operation' into a union and iteratively refines @@ -4958,7 +4964,7 @@ def refine_parent_types(self, # operation against arbitrary types. if isinstance(expr, MemberExpr): parent_expr = expr.expr - parent_type = existing_types.get(parent_expr) + parent_type = self.lookup_type_or_none(parent_expr) member_name = expr.name def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: @@ -4981,9 +4987,9 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: return member_type elif isinstance(expr, IndexExpr): parent_expr = expr.base - parent_type = existing_types.get(parent_expr) + parent_type = self.lookup_type_or_none(parent_expr) - index_type = existing_types.get(expr.index) + index_type = self.lookup_type_or_none(expr.index) if index_type is None: return output @@ -5335,7 +5341,41 @@ def str_type(self) -> Instance: def store_type(self, node: Expression, typ: Type) -> None: """Store the type of a node in the type map.""" - self.type_map[node] = typ + self._type_maps[-1][node] = typ + + def has_type(self, node: Expression) -> bool: + for m in reversed(self._type_maps): + if node in m: + return True + return False + + def lookup_type_or_none(self, node: Expression) -> Optional[Type]: + for m in reversed(self._type_maps): + if node in m: + return m[node] + return None + + def lookup_type(self, node: Expression) -> Type: + for m in reversed(self._type_maps): + t = m.get(node) + if t is not None: + return t + raise KeyError(node) + + def store_types(self, d: Dict[Expression, Type]) -> None: + self._type_maps[-1].update(d) + + @contextmanager + def local_type_map(self) -> Iterator[Dict[Expression, Type]]: + """Store inferred types into a temporary type map (returned). + + This can be used to perform type checking "experiments" without + affecting exported types (which are used by mypyc). + """ + temp_type_map: Dict[Expression, Type] = {} + self._type_maps.append(temp_type_map) + yield temp_type_map + self._type_maps.pop() def in_checked_function(self) -> bool: """Should we type-check the current function? @@ -5579,11 +5619,10 @@ def push_type_map(self, type_map: 'TypeMap') -> None: def infer_issubclass_maps(self, node: CallExpr, expr: Expression, - type_map: Dict[Expression, Type] ) -> Tuple[TypeMap, TypeMap]: """Infer type restrictions for an expression in issubclass call.""" - vartype = type_map[expr] - type = get_isinstance_type(node.args[1], type_map) + vartype = self.lookup_type(expr) + type = self.get_isinstance_type(node.args[1]) if isinstance(vartype, TypeVarType): vartype = vartype.upper_bound vartype = get_proper_type(vartype) @@ -5683,6 +5722,75 @@ def is_writable_attribute(self, node: Node) -> bool: else: return False + def get_isinstance_type(self, expr: Expression) -> Optional[List[TypeRange]]: + if isinstance(expr, OpExpr) and expr.op == '|': + left = self.get_isinstance_type(expr.left) + right = self.get_isinstance_type(expr.right) + if left is None or right is None: + return None + return left + right + all_types = get_proper_types(flatten_types(self.lookup_type(expr))) + types: List[TypeRange] = [] + for typ in all_types: + if isinstance(typ, FunctionLike) and typ.is_type_obj(): + # Type variables may be present -- erase them, which is the best + # we can do (outside disallowing them here). + erased_type = erase_typevars(typ.items[0].ret_type) + types.append(TypeRange(erased_type, is_upper_bound=False)) + elif isinstance(typ, TypeType): + # Type[A] means "any type that is a subtype of A" rather than "precisely type A" + # we indicate this by setting is_upper_bound flag + types.append(TypeRange(typ.item, is_upper_bound=True)) + elif isinstance(typ, Instance) and typ.type.fullname == 'builtins.type': + object_type = Instance(typ.type.mro[-1], []) + types.append(TypeRange(object_type, is_upper_bound=True)) + elif isinstance(typ, AnyType): + types.append(TypeRange(typ, is_upper_bound=False)) + else: # we didn't see an actual type, but rather a variable with unknown value + return None + if not types: + # this can happen if someone has empty tuple as 2nd argument to isinstance + # strictly speaking, we should return UninhabitedType but for simplicity we will simply + # refuse to do any type inference for now + return None + return types + + def is_literal_enum(self, n: Expression) -> bool: + """Returns true if this expression (with the given type context) is an Enum literal. + + For example, if we had an enum: + + class Foo(Enum): + A = 1 + B = 2 + + ...and if the expression 'Foo' referred to that enum within the current type context, + then the expression 'Foo.A' would be a literal enum. However, if we did 'a = Foo.A', + then the variable 'a' would *not* be a literal enum. + + We occasionally special-case expressions like 'Foo.A' and treat them as a single primitive + unit for the same reasons we sometimes treat 'True', 'False', or 'None' as a single + primitive unit. + """ + if not isinstance(n, MemberExpr) or not isinstance(n.expr, NameExpr): + return False + + parent_type = self.lookup_type_or_none(n.expr) + member_type = self.lookup_type_or_none(n) + if member_type is None or parent_type is None: + return False + + parent_type = get_proper_type(parent_type) + member_type = get_proper_type(coerce_to_literal(member_type)) + if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType): + return False + + if not parent_type.is_type_obj(): + return False + + return (member_type.is_enum_literal() + and member_type.fallback.type == parent_type.type_object()) + @overload def conditional_types(current_type: Type, @@ -5785,42 +5893,6 @@ def is_false_literal(n: Expression) -> bool: or isinstance(n, IntExpr) and n.value == 0) -def is_literal_enum(type_map: Mapping[Expression, Type], n: Expression) -> bool: - """Returns true if this expression (with the given type context) is an Enum literal. - - For example, if we had an enum: - - class Foo(Enum): - A = 1 - B = 2 - - ...and if the expression 'Foo' referred to that enum within the current type context, - then the expression 'Foo.A' would be a literal enum. However, if we did 'a = Foo.A', - then the variable 'a' would *not* be a literal enum. - - We occasionally special-case expressions like 'Foo.A' and treat them as a single primitive - unit for the same reasons we sometimes treat 'True', 'False', or 'None' as a single - primitive unit. - """ - if not isinstance(n, MemberExpr) or not isinstance(n.expr, NameExpr): - return False - - parent_type = type_map.get(n.expr) - member_type = type_map.get(n) - if member_type is None or parent_type is None: - return False - - parent_type = get_proper_type(parent_type) - member_type = get_proper_type(coerce_to_literal(member_type)) - if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType): - return False - - if not parent_type.is_type_obj(): - return False - - return member_type.is_enum_literal() and member_type.fallback.type == parent_type.type_object() - - def is_literal_none(n: Expression) -> bool: """Returns true if this expression is the 'None' literal/keyword.""" return isinstance(n, NameExpr) and n.fullname == 'builtins.None' @@ -5986,41 +6058,6 @@ def flatten_types(t: Type) -> List[Type]: return [t] -def get_isinstance_type(expr: Expression, - type_map: Dict[Expression, Type]) -> Optional[List[TypeRange]]: - if isinstance(expr, OpExpr) and expr.op == '|': - left = get_isinstance_type(expr.left, type_map) - right = get_isinstance_type(expr.right, type_map) - if left is None or right is None: - return None - return left + right - all_types = get_proper_types(flatten_types(type_map[expr])) - types: List[TypeRange] = [] - for typ in all_types: - if isinstance(typ, FunctionLike) and typ.is_type_obj(): - # Type variables may be present -- erase them, which is the best - # we can do (outside disallowing them here). - erased_type = erase_typevars(typ.items[0].ret_type) - types.append(TypeRange(erased_type, is_upper_bound=False)) - elif isinstance(typ, TypeType): - # Type[A] means "any type that is a subtype of A" rather than "precisely type A" - # we indicate this by setting is_upper_bound flag - types.append(TypeRange(typ.item, is_upper_bound=True)) - elif isinstance(typ, Instance) and typ.type.fullname == 'builtins.type': - object_type = Instance(typ.type.mro[-1], []) - types.append(TypeRange(object_type, is_upper_bound=True)) - elif isinstance(typ, AnyType): - types.append(TypeRange(typ, is_upper_bound=False)) - else: # we didn't see an actual type, but rather a variable whose value is unknown to us - return None - if not types: - # this can happen if someone has empty tuple as 2nd argument to isinstance - # strictly speaking, we should return UninhabitedType but for simplicity we will simply - # refuse to do any type inference for now - return None - return types - - def expand_func(defn: FuncItem, map: Dict[TypeVarId, Type]) -> FuncItem: visitor = TypeTransformVisitor(map) ret = defn.accept(visitor) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 21fd7d81967f..bfbe961adc7a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -369,9 +369,9 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> # be invoked for these. if (fullname is None and isinstance(e.callee, MemberExpr) - and e.callee.expr in self.chk.type_map): + and self.chk.has_type(e.callee.expr)): member = e.callee.name - object_type = self.chk.type_map[e.callee.expr] + object_type = self.chk.lookup_type(e.callee.expr) ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type, member) if isinstance(e.callee, RefExpr) and len(e.args) == 2: @@ -401,8 +401,8 @@ def check_str_format_call(self, e: CallExpr) -> None: format_value = None if isinstance(e.callee.expr, (StrExpr, UnicodeExpr)): format_value = e.callee.expr.value - elif e.callee.expr in self.chk.type_map: - base_typ = try_getting_literal(self.chk.type_map[e.callee.expr]) + elif self.chk.has_type(e.callee.expr): + base_typ = try_getting_literal(self.chk.lookup_type(e.callee.expr)) if isinstance(base_typ, LiteralType) and isinstance(base_typ.value, str): format_value = base_typ.value if format_value is not None: @@ -442,7 +442,7 @@ def always_returns_none(self, node: Expression) -> bool: if self.defn_returns_none(node.node): return True if isinstance(node, MemberExpr) and node.node is None: # instance or class attribute - typ = get_proper_type(self.chk.type_map.get(node.expr)) + typ = get_proper_type(self.chk.lookup_type(node.expr)) if isinstance(typ, Instance): info = typ.type elif isinstance(typ, CallableType) and typ.is_type_obj(): @@ -478,7 +478,7 @@ def defn_returns_none(self, defn: Optional[SymbolNode]) -> bool: def check_runtime_protocol_test(self, e: CallExpr) -> None: for expr in mypy.checker.flatten(e.args[1]): - tp = get_proper_type(self.chk.type_map[expr]) + tp = get_proper_type(self.chk.lookup_type(expr)) if (isinstance(tp, CallableType) and tp.is_type_obj() and tp.type_object().is_protocol and not tp.type_object().runtime_protocol): @@ -486,7 +486,7 @@ def check_runtime_protocol_test(self, e: CallExpr) -> None: def check_protocol_issubclass(self, e: CallExpr) -> None: for expr in mypy.checker.flatten(e.args[1]): - tp = get_proper_type(self.chk.type_map[expr]) + tp = get_proper_type(self.chk.lookup_type(expr)) if (isinstance(tp, CallableType) and tp.is_type_obj() and tp.type_object().is_protocol): attr_members = non_method_protocol_members(tp.type_object()) @@ -1740,18 +1740,20 @@ def infer_overload_return_type(self, return_types: List[Type] = [] inferred_types: List[Type] = [] args_contain_any = any(map(has_any_type, arg_types)) + type_maps: List[Dict[Expression, Type]] = [] for typ in plausible_targets: assert self.msg is self.chk.msg with self.msg.filter_errors() as w: - ret_type, infer_type = self.check_call( - callee=typ, - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type) + with self.chk.local_type_map() as m: + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type) is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info so we can @@ -1761,6 +1763,7 @@ def infer_overload_return_type(self, matches.append(typ) return_types.append(ret_type) inferred_types.append(infer_type) + type_maps.append(m) if len(matches) == 0: # No match was found @@ -1769,8 +1772,10 @@ def infer_overload_return_type(self, # An argument of type or containing the type 'Any' caused ambiguity. # We try returning a precise type if we can. If not, we give up and just return 'Any'. if all_same_types(return_types): + self.chk.store_types(type_maps[0]) return return_types[0], inferred_types[0] elif all_same_types([erase_type(typ) for typ in return_types]): + self.chk.store_types(type_maps[0]) return erase_type(return_types[0]), erase_type(inferred_types[0]) else: return self.check_call(callee=AnyType(TypeOfAny.special_form), @@ -1782,6 +1787,7 @@ def infer_overload_return_type(self, object_type=object_type) else: # Success! No ambiguity; return the first match. + self.chk.store_types(type_maps[0]) return return_types[0], inferred_types[0] def overload_erased_call_targets(self, @@ -3546,10 +3552,10 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type: # Type context available. self.chk.return_types.append(inferred_type.ret_type) self.chk.check_func_item(e, type_override=type_override) - if e.expr() not in self.chk.type_map: + if not self.chk.has_type(e.expr()): # TODO: return expression must be accepted before exiting function scope. self.accept(e.expr(), allow_none_return=True) - ret_type = self.chk.type_map[e.expr()] + ret_type = self.chk.lookup_type(e.expr()) self.chk.return_types.pop() return replace_callable_return_type(inferred_type, ret_type) diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 20b3716ea513..60a0d35ede08 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -317,7 +317,7 @@ def check_specs_in_format_call(self, call: CallExpr, assert len(replacements) == len(specs) for spec, repl in zip(specs, replacements): repl = self.apply_field_accessors(spec, repl, ctx=call) - actual_type = repl.type if isinstance(repl, TempNode) else self.chk.type_map.get(repl) + actual_type = repl.type if isinstance(repl, TempNode) else self.chk.lookup_type(repl) assert actual_type is not None # Special case custom formatting. @@ -370,7 +370,7 @@ def perform_special_format_checks(self, spec: ConversionSpecifier, call: CallExp if spec.conv_type == 'c': if isinstance(repl, (StrExpr, BytesExpr)) and len(repl.value) != 1: self.msg.requires_int_or_char(call, format_call=True) - c_typ = get_proper_type(self.chk.type_map[repl]) + c_typ = get_proper_type(self.chk.lookup_type(repl)) if isinstance(c_typ, Instance) and c_typ.last_known_value: c_typ = c_typ.last_known_value if isinstance(c_typ, LiteralType) and isinstance(c_typ.value, str): @@ -442,7 +442,7 @@ def get_expr_by_position(self, pos: int, call: CallExpr) -> Optional[Expression] # Fall back to *args when present in call. star_arg = star_args[0] - varargs_type = get_proper_type(self.chk.type_map[star_arg]) + varargs_type = get_proper_type(self.chk.lookup_type(star_arg)) if (not isinstance(varargs_type, Instance) or not varargs_type.type.has_base('typing.Sequence')): # Error should be already reported. @@ -465,7 +465,7 @@ def get_expr_by_name(self, key: str, call: CallExpr) -> Optional[Expression]: if not star_args_2: return None star_arg_2 = star_args_2[0] - kwargs_type = get_proper_type(self.chk.type_map[star_arg_2]) + kwargs_type = get_proper_type(self.chk.lookup_type(star_arg_2)) if (not isinstance(kwargs_type, Instance) or not kwargs_type.type.has_base('typing.Mapping')): # Error should be already reported. diff --git a/mypyc/test-data/run-functions.test b/mypyc/test-data/run-functions.test index 66b56503f329..77e9c9ed32f7 100644 --- a/mypyc/test-data/run-functions.test +++ b/mypyc/test-data/run-functions.test @@ -1192,3 +1192,31 @@ def foo(): pass def test_decorator_name(): assert foo.__name__ == "foo" + +[case testLambdaArgToOverloaded] +from lib import sub + +def test_str_overload() -> None: + assert sub('x', lambda m: m) == 'x' + +def test_bytes_overload() -> None: + assert sub(b'x', lambda m: m) == b'x' + +[file lib.py] +from typing import overload, Callable, TypeVar, Generic + +T = TypeVar("T") + +class Match(Generic[T]): + def __init__(self, x: T) -> None: + self.x = x + + def group(self, n: int) -> T: + return self.x + +@overload +def sub(s: str, f: Callable[[str], str]) -> str: ... +@overload +def sub(s: bytes, f: Callable[[bytes], bytes]) -> bytes: ... +def sub(s, f): + return f(s) diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index bdefb49e3038..5cbdf38d1b4f 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -1182,6 +1182,43 @@ IntExpr(2) : Literal[1]? OpExpr(2) : builtins.str StrExpr(2) : Literal['%d']? +[case testExportOverloadArgType] +## LambdaExpr|NameExpr +from typing import List, overload, Callable +@overload +def f(x: int, f: Callable[[int], int]) -> None: ... +@overload +def f(x: str, f: Callable[[str], str]) -> None: ... +def f(x): ... +f( + 1, lambda x: x) +[builtins fixtures/list.pyi] +[out] +NameExpr(8) : Overload(def (x: builtins.int, f: def (builtins.int) -> builtins.int), def (x: builtins.str, f: def (builtins.str) -> builtins.str)) +LambdaExpr(9) : def (builtins.int) -> builtins.int +NameExpr(9) : builtins.int + +[case testExportOverloadArgTypeNested] +## LambdaExpr +from typing import overload, Callable +@overload +def f(x: int, f: Callable[[int], int]) -> int: ... +@overload +def f(x: str, f: Callable[[str], str]) -> str: ... +def f(x): ... +f( + f(1, lambda y: y), + lambda x: x) +f( + f('x', lambda y: y), + lambda x: x) +[builtins fixtures/list.pyi] +[out] +LambdaExpr(9) : def (builtins.int) -> builtins.int +LambdaExpr(10) : def (builtins.int) -> builtins.int +LambdaExpr(12) : def (builtins.str) -> builtins.str +LambdaExpr(13) : def (builtins.str) -> builtins.str + -- TODO -- -- test expressions @@ -1193,7 +1230,6 @@ StrExpr(2) : Literal['%d']? -- more complex lambda (multiple arguments etc.) -- list comprehension -- generator expression --- overloads -- other things -- type inference -- default argument value