@@ -28,14 +28,21 @@ def _no_slots_copy(dct):
28
28
dict_copy .pop (slot , None )
29
29
return dict_copy
30
30
31
-
32
- def _check_generic (cls , parameters ):
33
- if not cls .__parameters__ :
31
+ def _check_generic (cls , parameters , elen ):
32
+ """Check correct count for parameters of a generic cls (internal helper).
33
+ This gives a nice error message in case of count mismatch.
34
+ """
35
+ if not elen :
34
36
raise TypeError (f"{ cls } is not a generic class" )
35
37
alen = len (parameters )
36
- elen = len (cls .__parameters__ )
37
38
if alen != elen :
38
- raise TypeError (f"Too { 'many' if alen > elen else 'few' } arguments for { cls } ;"
39
+ if hasattr (cls , "__parameters__" ):
40
+ num_tv_tuples = sum (
41
+ isinstance (p , TypeVarTuple ) for p in cls .__parameters__
42
+ )
43
+ if num_tv_tuples > 0 and alen >= elen - num_tv_tuples :
44
+ return
45
+ raise TypeError (f"Too { 'many' if alen > elen else 'few' } parameters for { cls } ;"
39
46
f" actual { alen } , expected { elen } " )
40
47
41
48
@@ -602,7 +609,7 @@ def __class_getitem__(cls, params):
602
609
"Parameters to Protocol[...] must all be unique" )
603
610
else :
604
611
# Subscripting a regular Generic subclass.
605
- _check_generic (cls , params )
612
+ _check_generic (cls , params , len ( cls . __parameters__ ) )
606
613
return typing ._GenericAlias (cls , params )
607
614
608
615
def __init_subclass__ (cls , * args , ** kwargs ):
@@ -883,7 +890,7 @@ def __getitem__(self, params):
883
890
elif self .__origin__ in (typing .Generic , Protocol ):
884
891
raise TypeError (f"Cannot subscript already-subscripted { repr (self )} " )
885
892
else :
886
- _check_generic (self , params )
893
+ _check_generic (self , params , len ( self . __parameters__ ) )
887
894
tvars = _type_vars (params )
888
895
args = params
889
896
@@ -2296,7 +2303,9 @@ def Unpack(self, parameters):
2296
2303
Shape = TypeVarTuple('Shape')
2297
2304
Batch = NewType('Batch', int)
2298
2305
2299
- def add_batch_axis(x: Array[Unpack[Shape]]) -> Array[Batch, Unpack[Shape]]: ...
2306
+ def add_batch_axis(
2307
+ x: Array[Unpack[Shape]]
2308
+ ) -> Array[Batch, Unpack[Shape]]: ...
2300
2309
2301
2310
"""
2302
2311
item = typing ._type_check (parameters , f'{ self ._name } accepts only single type' )
@@ -2317,22 +2326,6 @@ def _collect_type_vars(types):
2317
2326
return tuple (tvars )
2318
2327
2319
2328
typing ._collect_type_vars = _collect_type_vars
2320
-
2321
- def _check_generic (cls , parameters , elen ):
2322
- """Check correct count for parameters of a generic cls (internal helper).
2323
- This gives a nice error message in case of count mismatch.
2324
- """
2325
- if not elen :
2326
- raise TypeError (f"{ cls } is not a generic class" )
2327
- alen = len (parameters )
2328
- if alen != elen :
2329
- if hasattr (cls , "__parameters__" ):
2330
- num_tv_tuples = sum (isinstance (p , TypeVarTuple ) for p in cls .__parameters__ )
2331
- if num_tv_tuples > 0 and alen >= elen - num_tv_tuples :
2332
- return
2333
- raise TypeError (f"Too { 'many' if alen > elen else 'few' } parameters for { cls } ;"
2334
- f" actual { alen } , expected { elen } " )
2335
-
2336
2329
typing ._check_generic = _check_generic
2337
2330
2338
2331
elif sys .version_info [:2 ] >= (3 , 7 ):
@@ -2355,7 +2348,9 @@ def __getitem__(self, parameters):
2355
2348
Shape = TypeVarTuple('Shape')
2356
2349
Batch = NewType('Batch', int)
2357
2350
2358
- def add_batch_axis(x: Array[Unpack[Shape]]) -> Array[Batch, Unpack[Shape]]: ...
2351
+ def add_batch_axis(
2352
+ x: Array[Unpack[Shape]]
2353
+ ) -> Array[Batch, Unpack[Shape]]: ...
2359
2354
2360
2355
""" )
2361
2356
else :
@@ -2366,7 +2361,9 @@ class _Unpack(typing._FinalTypingBase, _root=True):
2366
2361
Shape = TypeVarTuple('Shape')
2367
2362
Batch = NewType('Batch', int)
2368
2363
2369
- def add_batch_axis(x: Array[Unpack[Shape]]) -> Array[Batch, Unpack[Shape]]: ...
2364
+ def add_batch_axis(
2365
+ x: Array[Unpack[Shape]]
2366
+ ) -> Array[Batch, Unpack[Shape]]: ...
2370
2367
2371
2368
"""
2372
2369
__slots__ = ('__type__' ,)
@@ -2407,7 +2404,6 @@ def __eq__(self, other):
2407
2404
Unpack = _Unpack (_root = True )
2408
2405
2409
2406
2410
- # Inherits from list as a workaround for Callable checks in Python < 3.9.2.
2411
2407
class TypeVarTuple (typing ._Final , _root = True ):
2412
2408
"""Type variable tuple.
2413
2409
0 commit comments