@@ -2507,6 +2507,76 @@ def conditional_type_map(expr: Expression,
2507
2507
return {}, {}
2508
2508
2509
2509
2510
+ def partition_by_callable (type : Optional [Type ]) -> Tuple [List [Type ], List [Type ]]:
2511
+ """Takes in a type and partitions that type into callable subtypes and
2512
+ uncallable subtypes.
2513
+
2514
+ Thus, given:
2515
+ `callables, uncallables = partition_by_callable(type)`
2516
+
2517
+ If we assert `callable(type)` then `type` has type Union[*callables], and
2518
+ If we assert `not callable(type)` then `type` has type Union[*uncallables]
2519
+
2520
+ Guaranteed to not return [], []"""
2521
+ if isinstance (type , FunctionLike ) or isinstance (type , TypeType ):
2522
+ return [type ], []
2523
+
2524
+ if isinstance (type , AnyType ):
2525
+ return [type ], [type ]
2526
+
2527
+ if isinstance (type , UnionType ):
2528
+ callables = []
2529
+ uncallables = []
2530
+ for subtype in type .items :
2531
+ subcallables , subuncallables = partition_by_callable (subtype )
2532
+ callables .extend (subcallables )
2533
+ uncallables .extend (subuncallables )
2534
+ return callables , uncallables
2535
+
2536
+ if isinstance (type , TypeVarType ):
2537
+ return partition_by_callable (type .erase_to_union_or_bound ())
2538
+
2539
+ if isinstance (type , Instance ):
2540
+ method = type .type .get_method ('__call__' )
2541
+ if method :
2542
+ callables , uncallables = partition_by_callable (method .type )
2543
+ if len (callables ) and not len (uncallables ):
2544
+ # Only consider the type callable if its __call__ method is
2545
+ # definitely callable.
2546
+ return [type ], []
2547
+ return [], [type ]
2548
+
2549
+ return [], [type ]
2550
+
2551
+
2552
+ def conditional_callable_type_map (expr : Expression ,
2553
+ current_type : Optional [Type ],
2554
+ ) -> Tuple [TypeMap , TypeMap ]:
2555
+ """Takes in an expression and the current type of the expression.
2556
+
2557
+ Returns a 2-tuple: The first element is a map from the expression to
2558
+ the restricted type if it were callable. The second element is a
2559
+ map from the expression to the type it would hold if it weren't
2560
+ callable."""
2561
+ if not current_type :
2562
+ return {}, {}
2563
+
2564
+ if isinstance (current_type , AnyType ):
2565
+ return {}, {}
2566
+
2567
+ callables , uncallables = partition_by_callable (current_type )
2568
+
2569
+ if len (callables ) and len (uncallables ):
2570
+ callable_map = {expr : UnionType .make_union (callables )} if len (callables ) else None
2571
+ uncallable_map = {expr : UnionType .make_union (uncallables )} if len (uncallables ) else None
2572
+ return callable_map , uncallable_map
2573
+
2574
+ elif len (callables ):
2575
+ return {}, None
2576
+
2577
+ return None , {}
2578
+
2579
+
2510
2580
def is_true_literal (n : Expression ) -> bool :
2511
2581
return (refers_to_fullname (n , 'builtins.True' )
2512
2582
or isinstance (n , IntExpr ) and n .value == 1 )
@@ -2580,7 +2650,7 @@ def find_isinstance_check(node: Expression,
2580
2650
type_map : Dict [Expression , Type ],
2581
2651
) -> Tuple [TypeMap , TypeMap ]:
2582
2652
"""Find any isinstance checks (within a chain of ands). Includes
2583
- implicit and explicit checks for None.
2653
+ implicit and explicit checks for None and calls to callable .
2584
2654
2585
2655
Return value is a map of variables to their types if the condition
2586
2656
is true and a map of variables to their types if the condition is false.
@@ -2601,6 +2671,11 @@ def find_isinstance_check(node: Expression,
2601
2671
vartype = type_map [expr ]
2602
2672
type = get_isinstance_type (node .args [1 ], type_map )
2603
2673
return conditional_type_map (expr , vartype , type )
2674
+ elif refers_to_fullname (node .callee , 'builtins.callable' ):
2675
+ expr = node .args [0 ]
2676
+ if expr .literal == LITERAL_TYPE :
2677
+ vartype = type_map [expr ]
2678
+ return conditional_callable_type_map (expr , vartype )
2604
2679
elif (isinstance (node , ComparisonExpr ) and experiments .STRICT_OPTIONAL ):
2605
2680
# Check for `x is None` and `x is not None`.
2606
2681
is_not = node .operators == ['is not' ]
0 commit comments