From 3c07c1626c309a8f9e75896b5f9e9f700f2b2526 Mon Sep 17 00:00:00 2001 From: Damien Deville Date: Thu, 5 Jan 2023 22:00:49 -0800 Subject: [PATCH] Fix API requestor hanging when not using a global session --- openai/api_requestor.py | 54 ++++++++++++++++++++++++------------ openai/api_resources/file.py | 21 +++++++------- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/openai/api_requestor.py b/openai/api_requestor.py index b10730216d..1961e1d093 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -4,8 +4,18 @@ import sys import threading import warnings +from contextlib import asynccontextmanager from json import JSONDecodeError -from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload +from typing import ( + AsyncGenerator, + AsyncIterator, + Dict, + Iterator, + Optional, + Tuple, + Union, + overload, +) from urllib.parse import urlencode, urlsplit, urlunsplit import aiohttp @@ -284,17 +294,19 @@ async def arequest( request_id: Optional[str] = None, request_timeout: Optional[Union[float, Tuple[float, float]]] = None, ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]: - result = await self.arequest_raw( - method.lower(), - url, - params=params, - supplied_headers=headers, - files=files, - request_id=request_id, - request_timeout=request_timeout, - ) - resp, got_stream = await self._interpret_async_response(result, stream) - return resp, got_stream, self.api_key + async with aiohttp_session() as session: + result = await self.arequest_raw( + method.lower(), + url, + session, + params=params, + supplied_headers=headers, + files=files, + request_id=request_id, + request_timeout=request_timeout, + ) + resp, got_stream = await self._interpret_async_response(result, stream) + return resp, got_stream, self.api_key def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): try: @@ -514,6 +526,7 @@ async def arequest_raw( self, method, url, + session, *, params=None, supplied_headers: Optional[Dict[str, str]] = None, @@ -534,7 +547,6 @@ async def arequest_raw( timeout = aiohttp.ClientTimeout( total=request_timeout if request_timeout else TIMEOUT_SECS ) - user_set_session = openai.aiosession.get() if files: # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here. @@ -552,11 +564,7 @@ async def arequest_raw( "timeout": timeout, } try: - if user_set_session: - result = await user_set_session.request(**request_kwargs) - else: - async with aiohttp.ClientSession() as session: - result = await session.request(**request_kwargs) + result = await session.request(**request_kwargs) util.log_info( "OpenAI API response", path=abs_url, @@ -648,3 +656,13 @@ def _interpret_response_line( rbody, rcode, resp.data, rheaders, stream_error=stream_error ) return resp + + +@asynccontextmanager +async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]: + user_set_session = openai.aiosession.get() + if user_set_session: + yield user_set_session + else: + async with aiohttp.ClientSession() as session: + yield session diff --git a/openai/api_resources/file.py b/openai/api_resources/file.py index 3654dd2d2e..80b989ada1 100644 --- a/openai/api_resources/file.py +++ b/openai/api_resources/file.py @@ -192,16 +192,17 @@ async def adownload( id, api_key, api_base, api_type, api_version, organization ) - result = await requestor.arequest_raw("get", url) - if not 200 <= result.status < 300: - raise requestor.handle_error_response( - result.content, - result.status, - json.loads(cast(bytes, result.content)), - result.headers, - stream_error=False, - ) - return result.content + async with api_requestor.aiohttp_session() as session: + result = await requestor.arequest_raw("get", url, session) + if not 200 <= result.status < 300: + raise requestor.handle_error_response( + result.content, + result.status, + json.loads(cast(bytes, result.content)), + result.headers, + stream_error=False, + ) + return result.content @classmethod def __find_matching_files(cls, name, all_files, purpose):