From c91d6e275672470e1cd0892eaa493cbf028bb017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Wed, 23 Aug 2023 15:03:08 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fixed=20asyncio=20bugs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- openai/api_requestor.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 504f7c4411..571f388721 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -6,11 +6,10 @@ import threading import time import warnings -from contextlib import asynccontextmanager from json import JSONDecodeError from typing import ( + AsyncContextManager, AsyncGenerator, - AsyncIterator, Callable, Dict, Iterator, @@ -366,8 +365,9 @@ async def arequest( request_id: Optional[str] = None, request_timeout: Optional[Union[float, Tuple[float, float]]] = None, ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: - ctx = aiohttp_session() + ctx = AioHTTPSession() session = await ctx.__aenter__() + result = None try: result = await self.arequest_raw( method.lower(), @@ -381,6 +381,9 @@ async def arequest( ) resp, got_stream = await self._interpret_async_response(result, stream) except Exception: + # Close the request before exiting session context. + if result is not None: + result.release() await ctx.__aexit__(None, None, None) raise if got_stream: @@ -391,10 +394,15 @@ async def wrap_resp(): async for r in resp: yield r finally: + # Close the request before exiting session context. Important to do it here + # as if stream is not fully exhausted, we need to close the request nevertheless. + result.release() await ctx.__aexit__(None, None, None) return wrap_resp(), got_stream, self.api_key else: + # Close the request before exiting session context. + result.release() await ctx.__aexit__(None, None, None) return resp, got_stream, self.api_key @@ -768,11 +776,22 @@ def _interpret_response_line( return resp -@asynccontextmanager -async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]: - user_set_session = openai.aiosession.get() - if user_set_session: - yield user_set_session - else: - async with aiohttp.ClientSession() as session: - yield session +class AioHTTPSession(AsyncContextManager): + def __init__(self): + self._session = None + self._should_close_session = False + + async def __aenter__(self): + self._session = openai.aiosession.get() + if self._session is None: + self._session = await aiohttp.ClientSession().__aenter__() + self._should_close_session = True + + return self._session + + async def __aexit__(self, exc_type, exc_value, traceback): + if self._session is None: + raise RuntimeError("Session is not initialized") + + if self._should_close_session: + await self._session.__aexit__(exc_type, exc_value, traceback) \ No newline at end of file