6
6
import threading
7
7
import time
8
8
import warnings
9
- from contextlib import asynccontextmanager
10
9
from json import JSONDecodeError
11
10
from typing import (
11
+ AsyncContextManager ,
12
12
AsyncGenerator ,
13
- AsyncIterator ,
14
13
Callable ,
15
14
Dict ,
16
15
Iterator ,
@@ -368,8 +367,9 @@ async def arequest(
368
367
request_id : Optional [str ] = None ,
369
368
request_timeout : Optional [Union [float , Tuple [float , float ]]] = None ,
370
369
) -> Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool , str ]:
371
- ctx = aiohttp_session ()
370
+ ctx = AioHTTPSession ()
372
371
session = await ctx .__aenter__ ()
372
+ result = None
373
373
try :
374
374
result = await self .arequest_raw (
375
375
method .lower (),
@@ -383,6 +383,9 @@ async def arequest(
383
383
)
384
384
resp , got_stream = await self ._interpret_async_response (result , stream )
385
385
except Exception :
386
+ # Close the request before exiting session context.
387
+ if result is not None :
388
+ result .release ()
386
389
await ctx .__aexit__ (None , None , None )
387
390
raise
388
391
if got_stream :
@@ -393,10 +396,15 @@ async def wrap_resp():
393
396
async for r in resp :
394
397
yield r
395
398
finally :
399
+ # Close the request before exiting session context. Important to do it here
400
+ # as if stream is not fully exhausted, we need to close the request nevertheless.
401
+ result .release ()
396
402
await ctx .__aexit__ (None , None , None )
397
403
398
404
return wrap_resp (), got_stream , self .api_key
399
405
else :
406
+ # Close the request before exiting session context.
407
+ result .release ()
400
408
await ctx .__aexit__ (None , None , None )
401
409
return resp , got_stream , self .api_key
402
410
@@ -770,11 +778,22 @@ def _interpret_response_line(
770
778
return resp
771
779
772
780
773
- @asynccontextmanager
774
- async def aiohttp_session () -> AsyncIterator [aiohttp .ClientSession ]:
775
- user_set_session = openai .aiosession .get ()
776
- if user_set_session :
777
- yield user_set_session
778
- else :
779
- async with aiohttp .ClientSession () as session :
780
- yield session
781
+ class AioHTTPSession (AsyncContextManager ):
782
+ def __init__ (self ):
783
+ self ._session = None
784
+ self ._should_close_session = False
785
+
786
+ async def __aenter__ (self ):
787
+ self ._session = openai .aiosession .get ()
788
+ if self ._session is None :
789
+ self ._session = await aiohttp .ClientSession ().__aenter__ ()
790
+ self ._should_close_session = True
791
+
792
+ return self ._session
793
+
794
+ async def __aexit__ (self , exc_type , exc_value , traceback ):
795
+ if self ._session is None :
796
+ raise RuntimeError ("Session is not initialized" )
797
+
798
+ if self ._should_close_session :
799
+ await self ._session .__aexit__ (exc_type , exc_value , traceback )
0 commit comments