Skip to content

Commit 0c93bc6

Browse files
hynky1999megamanics
authored andcommitted
🐛 fixed asyncio bugs (openai#584)
1 parent 3c1c3e7 commit 0c93bc6

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

openai/api_requestor.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
import threading
77
import time
88
import warnings
9-
from contextlib import asynccontextmanager
109
from json import JSONDecodeError
1110
from typing import (
11+
AsyncContextManager,
1212
AsyncGenerator,
13-
AsyncIterator,
1413
Callable,
1514
Dict,
1615
Iterator,
@@ -368,8 +367,9 @@ async def arequest(
368367
request_id: Optional[str] = None,
369368
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
370369
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
371-
ctx = aiohttp_session()
370+
ctx = AioHTTPSession()
372371
session = await ctx.__aenter__()
372+
result = None
373373
try:
374374
result = await self.arequest_raw(
375375
method.lower(),
@@ -383,6 +383,9 @@ async def arequest(
383383
)
384384
resp, got_stream = await self._interpret_async_response(result, stream)
385385
except Exception:
386+
# Close the request before exiting session context.
387+
if result is not None:
388+
result.release()
386389
await ctx.__aexit__(None, None, None)
387390
raise
388391
if got_stream:
@@ -393,10 +396,15 @@ async def wrap_resp():
393396
async for r in resp:
394397
yield r
395398
finally:
399+
# Close the request before exiting session context. Important to do it here
400+
# as if stream is not fully exhausted, we need to close the request nevertheless.
401+
result.release()
396402
await ctx.__aexit__(None, None, None)
397403

398404
return wrap_resp(), got_stream, self.api_key
399405
else:
406+
# Close the request before exiting session context.
407+
result.release()
400408
await ctx.__aexit__(None, None, None)
401409
return resp, got_stream, self.api_key
402410

@@ -770,11 +778,22 @@ def _interpret_response_line(
770778
return resp
771779

772780

773-
@asynccontextmanager
774-
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
775-
user_set_session = openai.aiosession.get()
776-
if user_set_session:
777-
yield user_set_session
778-
else:
779-
async with aiohttp.ClientSession() as session:
780-
yield session
781+
class AioHTTPSession(AsyncContextManager):
782+
def __init__(self):
783+
self._session = None
784+
self._should_close_session = False
785+
786+
async def __aenter__(self):
787+
self._session = openai.aiosession.get()
788+
if self._session is None:
789+
self._session = await aiohttp.ClientSession().__aenter__()
790+
self._should_close_session = True
791+
792+
return self._session
793+
794+
async def __aexit__(self, exc_type, exc_value, traceback):
795+
if self._session is None:
796+
raise RuntimeError("Session is not initialized")
797+
798+
if self._should_close_session:
799+
await self._session.__aexit__(exc_type, exc_value, traceback)

0 commit comments

Comments
 (0)