@@ -7511,6 +7511,15 @@ def manager():
7511
7511
self .assertIsInstance (cm , typing .ContextManager )
7512
7512
self .assertNotIsInstance (42 , typing .ContextManager )
7513
7513
7514
+ def test_contextmanager_type_params (self ):
7515
+ cm1 = typing .ContextManager [int ]
7516
+ self .assertEqual (get_args (cm1 ), (int , bool | None ))
7517
+ cm2 = typing .ContextManager [int , None ]
7518
+ self .assertEqual (get_args (cm2 ), (int , types .NoneType ))
7519
+
7520
+ type gen_cm [T1 , T2 ] = typing .ContextManager [T1 , T2 ]
7521
+ self .assertEqual (get_args (gen_cm .__value__ [int , None ]), (int , types .NoneType ))
7522
+
7514
7523
def test_async_contextmanager (self ):
7515
7524
class NotACM :
7516
7525
pass
@@ -7522,11 +7531,17 @@ def manager():
7522
7531
7523
7532
cm = manager ()
7524
7533
self .assertNotIsInstance (cm , typing .AsyncContextManager )
7525
- self .assertEqual (typing .AsyncContextManager [int ].__args__ , (int ,))
7534
+ self .assertEqual (typing .AsyncContextManager [int ].__args__ , (int , bool | None ))
7526
7535
with self .assertRaises (TypeError ):
7527
7536
isinstance (42 , typing .AsyncContextManager [int ])
7528
7537
with self .assertRaises (TypeError ):
7529
- typing .AsyncContextManager [int , str ]
7538
+ typing .AsyncContextManager [int , str , float ]
7539
+
7540
+ def test_asynccontextmanager_type_params (self ):
7541
+ cm1 = typing .AsyncContextManager [int ]
7542
+ self .assertEqual (get_args (cm1 ), (int , bool | None ))
7543
+ cm2 = typing .AsyncContextManager [int , None ]
7544
+ self .assertEqual (get_args (cm2 ), (int , types .NoneType ))
7530
7545
7531
7546
7532
7547
class TypeTests (BaseTestCase ):
@@ -9953,7 +9968,7 @@ def test_special_attrs(self):
9953
9968
typing .ValuesView : 'ValuesView' ,
9954
9969
# Subscribed ABC classes
9955
9970
typing .AbstractSet [Any ]: 'AbstractSet' ,
9956
- typing .AsyncContextManager [Any ]: 'AsyncContextManager' ,
9971
+ typing .AsyncContextManager [Any , Any ]: 'AsyncContextManager' ,
9957
9972
typing .AsyncGenerator [Any , Any ]: 'AsyncGenerator' ,
9958
9973
typing .AsyncIterable [Any ]: 'AsyncIterable' ,
9959
9974
typing .AsyncIterator [Any ]: 'AsyncIterator' ,
@@ -9963,7 +9978,7 @@ def test_special_attrs(self):
9963
9978
typing .ChainMap [Any , Any ]: 'ChainMap' ,
9964
9979
typing .Collection [Any ]: 'Collection' ,
9965
9980
typing .Container [Any ]: 'Container' ,
9966
- typing .ContextManager [Any ]: 'ContextManager' ,
9981
+ typing .ContextManager [Any , Any ]: 'ContextManager' ,
9967
9982
typing .Coroutine [Any , Any , Any ]: 'Coroutine' ,
9968
9983
typing .Counter [Any ]: 'Counter' ,
9969
9984
typing .DefaultDict [Any , Any ]: 'DefaultDict' ,
0 commit comments