@@ -743,7 +743,9 @@ def incompatible_argument_note(
743
743
context : Context ,
744
744
code : ErrorCode | None ,
745
745
) -> None :
746
- if isinstance (original_caller_type , (Instance , TupleType , TypedDictType )):
746
+ if isinstance (
747
+ original_caller_type , (Instance , TupleType , TypedDictType , TypeType , CallableType )
748
+ ):
747
749
if isinstance (callee_type , Instance ) and callee_type .type .is_protocol :
748
750
self .report_protocol_problems (
749
751
original_caller_type , callee_type , context , code = code
@@ -1791,7 +1793,7 @@ def impossible_intersection(
1791
1793
1792
1794
def report_protocol_problems (
1793
1795
self ,
1794
- subtype : Instance | TupleType | TypedDictType ,
1796
+ subtype : Instance | TupleType | TypedDictType | TypeType | CallableType ,
1795
1797
supertype : Instance ,
1796
1798
context : Context ,
1797
1799
* ,
@@ -1811,15 +1813,15 @@ def report_protocol_problems(
1811
1813
exclusions : dict [type , list [str ]] = {
1812
1814
TypedDictType : ["typing.Mapping" ],
1813
1815
TupleType : ["typing.Iterable" , "typing.Sequence" ],
1814
- Instance : [],
1815
1816
}
1816
- if supertype .type .fullname in exclusions [ type (subtype )] :
1817
+ if supertype .type .fullname in exclusions . get ( type (subtype ), []) :
1817
1818
return
1818
1819
if any (isinstance (tp , UninhabitedType ) for tp in get_proper_types (supertype .args )):
1819
1820
# We don't want to add notes for failed inference (e.g. Iterable[<nothing>]).
1820
1821
# This will be only confusing a user even more.
1821
1822
return
1822
1823
1824
+ class_obj = False
1823
1825
if isinstance (subtype , TupleType ):
1824
1826
if not isinstance (subtype .partial_fallback , Instance ):
1825
1827
return
@@ -1828,6 +1830,21 @@ def report_protocol_problems(
1828
1830
if not isinstance (subtype .fallback , Instance ):
1829
1831
return
1830
1832
subtype = subtype .fallback
1833
+ elif isinstance (subtype , TypeType ):
1834
+ if not isinstance (subtype .item , Instance ):
1835
+ return
1836
+ class_obj = True
1837
+ subtype = subtype .item
1838
+ elif isinstance (subtype , CallableType ):
1839
+ if not subtype .is_type_obj ():
1840
+ return
1841
+ ret_type = get_proper_type (subtype .ret_type )
1842
+ if isinstance (ret_type , TupleType ):
1843
+ ret_type = ret_type .partial_fallback
1844
+ if not isinstance (ret_type , Instance ):
1845
+ return
1846
+ class_obj = True
1847
+ subtype = ret_type
1831
1848
1832
1849
# Report missing members
1833
1850
missing = get_missing_protocol_members (subtype , supertype )
@@ -1836,20 +1853,29 @@ def report_protocol_problems(
1836
1853
and len (missing ) < len (supertype .type .protocol_members )
1837
1854
and len (missing ) <= MAX_ITEMS
1838
1855
):
1839
- self .note (
1840
- '"{}" is missing following "{}" protocol member{}:' .format (
1841
- subtype .type .name , supertype .type .name , plural_s (missing )
1842
- ),
1843
- context ,
1844
- code = code ,
1845
- )
1846
- self .note (", " .join (missing ), context , offset = OFFSET , code = code )
1856
+ if missing == ["__call__" ] and class_obj :
1857
+ self .note (
1858
+ '"{}" has constructor incompatible with "__call__" of "{}"' .format (
1859
+ subtype .type .name , supertype .type .name
1860
+ ),
1861
+ context ,
1862
+ code = code ,
1863
+ )
1864
+ else :
1865
+ self .note (
1866
+ '"{}" is missing following "{}" protocol member{}:' .format (
1867
+ subtype .type .name , supertype .type .name , plural_s (missing )
1868
+ ),
1869
+ context ,
1870
+ code = code ,
1871
+ )
1872
+ self .note (", " .join (missing ), context , offset = OFFSET , code = code )
1847
1873
elif len (missing ) > MAX_ITEMS or len (missing ) == len (supertype .type .protocol_members ):
1848
1874
# This is an obviously wrong type: too many missing members
1849
1875
return
1850
1876
1851
1877
# Report member type conflicts
1852
- conflict_types = get_conflict_protocol_types (subtype , supertype )
1878
+ conflict_types = get_conflict_protocol_types (subtype , supertype , class_obj = class_obj )
1853
1879
if conflict_types and (
1854
1880
not is_subtype (subtype , erase_type (supertype ))
1855
1881
or not subtype .type .defn .type_vars
@@ -1875,29 +1901,43 @@ def report_protocol_problems(
1875
1901
else :
1876
1902
self .note ("Expected:" , context , offset = OFFSET , code = code )
1877
1903
if isinstance (exp , CallableType ):
1878
- self .note (pretty_callable (exp ), context , offset = 2 * OFFSET , code = code )
1904
+ self .note (
1905
+ pretty_callable (exp , skip_self = class_obj ),
1906
+ context ,
1907
+ offset = 2 * OFFSET ,
1908
+ code = code ,
1909
+ )
1879
1910
else :
1880
1911
assert isinstance (exp , Overloaded )
1881
- self .pretty_overload (exp , context , 2 * OFFSET , code = code )
1912
+ self .pretty_overload (
1913
+ exp , context , 2 * OFFSET , code = code , skip_self = class_obj
1914
+ )
1882
1915
self .note ("Got:" , context , offset = OFFSET , code = code )
1883
1916
if isinstance (got , CallableType ):
1884
- self .note (pretty_callable (got ), context , offset = 2 * OFFSET , code = code )
1917
+ self .note (
1918
+ pretty_callable (got , skip_self = class_obj ),
1919
+ context ,
1920
+ offset = 2 * OFFSET ,
1921
+ code = code ,
1922
+ )
1885
1923
else :
1886
1924
assert isinstance (got , Overloaded )
1887
- self .pretty_overload (got , context , 2 * OFFSET , code = code )
1925
+ self .pretty_overload (
1926
+ got , context , 2 * OFFSET , code = code , skip_self = class_obj
1927
+ )
1888
1928
self .print_more (conflict_types , context , OFFSET , MAX_ITEMS , code = code )
1889
1929
1890
1930
# Report flag conflicts (i.e. settable vs read-only etc.)
1891
- conflict_flags = get_bad_protocol_flags (subtype , supertype )
1931
+ conflict_flags = get_bad_protocol_flags (subtype , supertype , class_obj = class_obj )
1892
1932
for name , subflags , superflags in conflict_flags [:MAX_ITEMS ]:
1893
- if IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags :
1933
+ if not class_obj and IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags :
1894
1934
self .note (
1895
1935
"Protocol member {}.{} expected instance variable,"
1896
1936
" got class variable" .format (supertype .type .name , name ),
1897
1937
context ,
1898
1938
code = code ,
1899
1939
)
1900
- if IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags :
1940
+ if not class_obj and IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags :
1901
1941
self .note (
1902
1942
"Protocol member {}.{} expected class variable,"
1903
1943
" got instance variable" .format (supertype .type .name , name ),
@@ -1919,6 +1959,13 @@ def report_protocol_problems(
1919
1959
context ,
1920
1960
code = code ,
1921
1961
)
1962
+ if class_obj and IS_SETTABLE in superflags and IS_CLASSVAR not in subflags :
1963
+ self .note (
1964
+ "Only class variables allowed for class object access on protocols,"
1965
+ ' {} is an instance variable of "{}"' .format (name , subtype .type .name ),
1966
+ context ,
1967
+ code = code ,
1968
+ )
1922
1969
self .print_more (conflict_flags , context , OFFSET , MAX_ITEMS , code = code )
1923
1970
1924
1971
def pretty_overload (
@@ -1930,6 +1977,7 @@ def pretty_overload(
1930
1977
add_class_or_static_decorator : bool = False ,
1931
1978
allow_dups : bool = False ,
1932
1979
code : ErrorCode | None = None ,
1980
+ skip_self : bool = False ,
1933
1981
) -> None :
1934
1982
for item in tp .items :
1935
1983
self .note ("@overload" , context , offset = offset , allow_dups = allow_dups , code = code )
@@ -1940,7 +1988,11 @@ def pretty_overload(
1940
1988
self .note (decorator , context , offset = offset , allow_dups = allow_dups , code = code )
1941
1989
1942
1990
self .note (
1943
- pretty_callable (item ), context , offset = offset , allow_dups = allow_dups , code = code
1991
+ pretty_callable (item , skip_self = skip_self ),
1992
+ context ,
1993
+ offset = offset ,
1994
+ allow_dups = allow_dups ,
1995
+ code = code ,
1944
1996
)
1945
1997
1946
1998
def print_more (
@@ -2373,10 +2425,14 @@ def pretty_class_or_static_decorator(tp: CallableType) -> str | None:
2373
2425
return None
2374
2426
2375
2427
2376
- def pretty_callable (tp : CallableType ) -> str :
2428
+ def pretty_callable (tp : CallableType , skip_self : bool = False ) -> str :
2377
2429
"""Return a nice easily-readable representation of a callable type.
2378
2430
For example:
2379
2431
def [T <: int] f(self, x: int, y: T) -> None
2432
+
2433
+ If skip_self is True, print an actual callable type, as it would appear
2434
+ when bound on an instance/class, rather than how it would appear in the
2435
+ defining statement.
2380
2436
"""
2381
2437
s = ""
2382
2438
asterisk = False
@@ -2420,7 +2476,11 @@ def [T <: int] f(self, x: int, y: T) -> None
2420
2476
and hasattr (tp .definition , "arguments" )
2421
2477
):
2422
2478
definition_arg_names = [arg .variable .name for arg in tp .definition .arguments ]
2423
- if len (definition_arg_names ) > len (tp .arg_names ) and definition_arg_names [0 ]:
2479
+ if (
2480
+ len (definition_arg_names ) > len (tp .arg_names )
2481
+ and definition_arg_names [0 ]
2482
+ and not skip_self
2483
+ ):
2424
2484
if s :
2425
2485
s = ", " + s
2426
2486
s = definition_arg_names [0 ] + s
@@ -2487,7 +2547,9 @@ def get_missing_protocol_members(left: Instance, right: Instance) -> list[str]:
2487
2547
return missing
2488
2548
2489
2549
2490
- def get_conflict_protocol_types (left : Instance , right : Instance ) -> list [tuple [str , Type , Type ]]:
2550
+ def get_conflict_protocol_types (
2551
+ left : Instance , right : Instance , class_obj : bool = False
2552
+ ) -> list [tuple [str , Type , Type ]]:
2491
2553
"""Find members that are defined in 'left' but have incompatible types.
2492
2554
Return them as a list of ('member', 'got', 'expected').
2493
2555
"""
@@ -2498,7 +2560,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[s
2498
2560
continue
2499
2561
supertype = find_member (member , right , left )
2500
2562
assert supertype is not None
2501
- subtype = find_member (member , left , left )
2563
+ subtype = find_member (member , left , left , class_obj = class_obj )
2502
2564
if not subtype :
2503
2565
continue
2504
2566
is_compat = is_subtype (subtype , supertype , ignore_pos_arg_names = True )
@@ -2510,7 +2572,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[s
2510
2572
2511
2573
2512
2574
def get_bad_protocol_flags (
2513
- left : Instance , right : Instance
2575
+ left : Instance , right : Instance , class_obj : bool = False
2514
2576
) -> list [tuple [str , set [int ], set [int ]]]:
2515
2577
"""Return all incompatible attribute flags for members that are present in both
2516
2578
'left' and 'right'.
@@ -2536,6 +2598,9 @@ def get_bad_protocol_flags(
2536
2598
and IS_SETTABLE not in subflags
2537
2599
or IS_CLASS_OR_STATIC in superflags
2538
2600
and IS_CLASS_OR_STATIC not in subflags
2601
+ or class_obj
2602
+ and IS_SETTABLE in superflags
2603
+ and IS_CLASSVAR not in subflags
2539
2604
):
2540
2605
bad_flags .append ((name , subflags , superflags ))
2541
2606
return bad_flags
0 commit comments