@@ -1742,7 +1742,7 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
1742
1742
1743
1743
def visit_func_expr (self , e : FuncExpr ) -> Type :
1744
1744
"""Type check lambda expression."""
1745
- inferred_type = self .infer_lambda_type_using_context (e )
1745
+ inferred_type , type_override = self .infer_lambda_type_using_context (e )
1746
1746
if not inferred_type :
1747
1747
# No useful type context.
1748
1748
ret_type = self .accept (e .expr ())
@@ -1752,7 +1752,7 @@ def visit_func_expr(self, e: FuncExpr) -> Type:
1752
1752
return callable_type (e , fallback , ret_type )
1753
1753
else :
1754
1754
# Type context available.
1755
- self .chk .check_func_item (e , type_override = inferred_type )
1755
+ self .chk .check_func_item (e , type_override = type_override )
1756
1756
if e .expr () not in self .chk .type_map :
1757
1757
self .accept (e .expr ())
1758
1758
ret_type = self .chk .type_map [e .expr ()]
@@ -1763,10 +1763,12 @@ def visit_func_expr(self, e: FuncExpr) -> Type:
1763
1763
return inferred_type
1764
1764
return replace_callable_return_type (inferred_type , ret_type )
1765
1765
1766
- def infer_lambda_type_using_context (self , e : FuncExpr ) -> Optional [CallableType ]:
1766
+ def infer_lambda_type_using_context (self , e : FuncExpr ) -> Tuple [Optional [CallableType ],
1767
+ Optional [CallableType ]]:
1767
1768
"""Try to infer lambda expression type using context.
1768
1769
1769
1770
Return None if could not infer type.
1771
+ The second item in the return type is the type_override parameter for check_func_item.
1770
1772
"""
1771
1773
# TODO also accept 'Any' context
1772
1774
ctx = self .chk .type_context [- 1 ]
@@ -1777,7 +1779,7 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> Optional[CallableType]
1777
1779
ctx = callables [0 ]
1778
1780
1779
1781
if not ctx or not isinstance (ctx , CallableType ):
1780
- return None
1782
+ return None , None
1781
1783
1782
1784
# The context may have function type variables in it. We replace them
1783
1785
# since these are the type variables we are ultimately trying to infer;
@@ -1799,13 +1801,13 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> Optional[CallableType]
1799
1801
1800
1802
if ARG_STAR in arg_kinds or ARG_STAR2 in arg_kinds :
1801
1803
# TODO treat this case appropriately
1802
- return None
1804
+ return callable_ctx , None
1803
1805
if callable_ctx .arg_kinds != arg_kinds :
1804
1806
# Incompatible context; cannot use it to infer types.
1805
1807
self .chk .fail (messages .CANNOT_INFER_LAMBDA_TYPE , e )
1806
- return None
1808
+ return None , None
1807
1809
1808
- return callable_ctx
1810
+ return callable_ctx , callable_ctx
1809
1811
1810
1812
def visit_super_expr (self , e : SuperExpr ) -> Type :
1811
1813
"""Type check a super expression (non-lvalue)."""
0 commit comments