Skip to content

Commit f8ff732

Browse files
afriedergvanrossum
authored andcommitted
Conditional type-check based on callable call (#2627)
Fixes #1973
1 parent 7692a6e commit f8ff732

File tree

4 files changed

+448
-1
lines changed

4 files changed

+448
-1
lines changed

mypy/checker.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2507,6 +2507,76 @@ def conditional_type_map(expr: Expression,
25072507
return {}, {}
25082508

25092509

2510+
def partition_by_callable(type: Optional[Type]) -> Tuple[List[Type], List[Type]]:
2511+
"""Takes in a type and partitions that type into callable subtypes and
2512+
uncallable subtypes.
2513+
2514+
Thus, given:
2515+
`callables, uncallables = partition_by_callable(type)`
2516+
2517+
If we assert `callable(type)` then `type` has type Union[*callables], and
2518+
If we assert `not callable(type)` then `type` has type Union[*uncallables]
2519+
2520+
Guaranteed to not return [], []"""
2521+
if isinstance(type, FunctionLike) or isinstance(type, TypeType):
2522+
return [type], []
2523+
2524+
if isinstance(type, AnyType):
2525+
return [type], [type]
2526+
2527+
if isinstance(type, UnionType):
2528+
callables = []
2529+
uncallables = []
2530+
for subtype in type.items:
2531+
subcallables, subuncallables = partition_by_callable(subtype)
2532+
callables.extend(subcallables)
2533+
uncallables.extend(subuncallables)
2534+
return callables, uncallables
2535+
2536+
if isinstance(type, TypeVarType):
2537+
return partition_by_callable(type.erase_to_union_or_bound())
2538+
2539+
if isinstance(type, Instance):
2540+
method = type.type.get_method('__call__')
2541+
if method:
2542+
callables, uncallables = partition_by_callable(method.type)
2543+
if len(callables) and not len(uncallables):
2544+
# Only consider the type callable if its __call__ method is
2545+
# definitely callable.
2546+
return [type], []
2547+
return [], [type]
2548+
2549+
return [], [type]
2550+
2551+
2552+
def conditional_callable_type_map(expr: Expression,
2553+
current_type: Optional[Type],
2554+
) -> Tuple[TypeMap, TypeMap]:
2555+
"""Takes in an expression and the current type of the expression.
2556+
2557+
Returns a 2-tuple: The first element is a map from the expression to
2558+
the restricted type if it were callable. The second element is a
2559+
map from the expression to the type it would hold if it weren't
2560+
callable."""
2561+
if not current_type:
2562+
return {}, {}
2563+
2564+
if isinstance(current_type, AnyType):
2565+
return {}, {}
2566+
2567+
callables, uncallables = partition_by_callable(current_type)
2568+
2569+
if len(callables) and len(uncallables):
2570+
callable_map = {expr: UnionType.make_union(callables)} if len(callables) else None
2571+
uncallable_map = {expr: UnionType.make_union(uncallables)} if len(uncallables) else None
2572+
return callable_map, uncallable_map
2573+
2574+
elif len(callables):
2575+
return {}, None
2576+
2577+
return None, {}
2578+
2579+
25102580
def is_true_literal(n: Expression) -> bool:
25112581
return (refers_to_fullname(n, 'builtins.True')
25122582
or isinstance(n, IntExpr) and n.value == 1)
@@ -2580,7 +2650,7 @@ def find_isinstance_check(node: Expression,
25802650
type_map: Dict[Expression, Type],
25812651
) -> Tuple[TypeMap, TypeMap]:
25822652
"""Find any isinstance checks (within a chain of ands). Includes
2583-
implicit and explicit checks for None.
2653+
implicit and explicit checks for None and calls to callable.
25842654
25852655
Return value is a map of variables to their types if the condition
25862656
is true and a map of variables to their types if the condition is false.
@@ -2601,6 +2671,11 @@ def find_isinstance_check(node: Expression,
26012671
vartype = type_map[expr]
26022672
type = get_isinstance_type(node.args[1], type_map)
26032673
return conditional_type_map(expr, vartype, type)
2674+
elif refers_to_fullname(node.callee, 'builtins.callable'):
2675+
expr = node.args[0]
2676+
if expr.literal == LITERAL_TYPE:
2677+
vartype = type_map[expr]
2678+
return conditional_callable_type_map(expr, vartype)
26042679
elif (isinstance(node, ComparisonExpr) and experiments.STRICT_OPTIONAL):
26052680
# Check for `x is None` and `x is not None`.
26062681
is_not = node.operators == ['is not']

mypy/test/testcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
]
3838
fast_parser_files = [
3939
'check-basic.test',
40+
'check-callable.test',
4041
'check-classes.test',
4142
'check-statements.test',
4243
'check-generics.test',

0 commit comments

Comments
 (0)