Skip to content

Commit 0abf641

Browse files
Add async support (#146)
* Add async support * Fix aiohttp requests * Fix some syntax errors * Close aiohttp session properly * This is due to a lack of an async __del__ method * Fix code per review * Fix async tests and some mypy errors * Run black * Add todo for multipart form generation * Fix more mypy * Fix exception type * Don't yield twice Co-authored-by: Damien Deville <[email protected]>
1 parent ec4943f commit 0abf641

30 files changed

+1288
-74
lines changed

README.md

+26
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,32 @@ image_resp = openai.Image.create(prompt="two dogs playing chess, oil painting",
211211

212212
```
213213

214+
## Async API
215+
216+
Async support is available in the API by prepending `a` to a network-bound method:
217+
218+
```python
219+
import openai
220+
openai.api_key = "sk-..." # supply your API key however you choose
221+
222+
async def create_completion():
223+
completion_resp = await openai.Completion.acreate(prompt="This is a test", engine="davinci")
224+
225+
```
226+
227+
To make async requests more efficient, you can pass in your own
228+
``aiohttp.ClientSession``, but you must manually close the client session at the end
229+
of your program/event loop:
230+
231+
```python
232+
import openai
233+
from aiohttp import ClientSession
234+
235+
openai.aiosession.set(ClientSession())
236+
# At the end of your program, close the http session
237+
await openai.aiosession.get().close()
238+
```
239+
214240
See the [usage guide](https://beta.openai.com/docs/guides/images) for more details.
215241

216242
## Requirements

openai/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
# Originally forked from the MIT-licensed Stripe Python bindings.
44

55
import os
6-
from typing import Optional
6+
from contextvars import ContextVar
7+
from typing import Optional, TYPE_CHECKING
78

89
from openai.api_resources import (
910
Answer,
@@ -24,6 +25,9 @@
2425
)
2526
from openai.error import APIError, InvalidRequestError, OpenAIError
2627

28+
if TYPE_CHECKING:
29+
from aiohttp import ClientSession
30+
2731
api_key = os.environ.get("OPENAI_API_KEY")
2832
# Path of a file with an API key, whose contents can change. Supersedes
2933
# `api_key` if set. The main use case is volume-mounted Kubernetes secrets,
@@ -44,6 +48,11 @@
4448
debug = False
4549
log = None # Set to either 'debug' or 'info', controls console logging
4650

51+
aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
52+
"aiohttp-session", default=None
53+
) # Acts as a global aiohttp ClientSession that reuses connections.
54+
# This is user-supplied; otherwise, a session is remade for each request.
55+
4756
__all__ = [
4857
"APIError",
4958
"Answer",

openai/api_requestor.py

+234-21
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import asyncio
12
import json
23
import platform
34
import sys
45
import threading
56
import warnings
67
from json import JSONDecodeError
7-
from typing import Dict, Iterator, Optional, Tuple, Union, overload
8+
from typing import AsyncGenerator, Dict, Iterator, Optional, Tuple, Union, overload
89
from urllib.parse import urlencode, urlsplit, urlunsplit
910

11+
import aiohttp
1012
import requests
1113

1214
if sys.version_info >= (3, 8):
@@ -49,6 +51,20 @@ def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
4951
)
5052

5153

54+
def _aiohttp_proxies_arg(proxy) -> Optional[str]:
55+
"""Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
56+
if proxy is None:
57+
return None
58+
elif isinstance(proxy, str):
59+
return proxy
60+
elif isinstance(proxy, dict):
61+
return proxy["https"] if "https" in proxy else proxy["http"]
62+
else:
63+
raise ValueError(
64+
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
65+
)
66+
67+
5268
def _make_session() -> requests.Session:
5369
if not openai.verify_ssl_certs:
5470
warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
@@ -63,18 +79,32 @@ def _make_session() -> requests.Session:
6379
return s
6480

6581

82+
def parse_stream_helper(line):
    """Parse one line of a server-sent-events response body.

    Returns the payload of a ``data:`` line, or None for anything that must
    not be yielded: empty keep-alive lines, the terminal ``data: [DONE]``
    sentinel, and non-data SSE lines (comments, event names).
    """
    if line:
        if line == b"data: [DONE]":
            # return here will cause GeneratorExit exception in urllib3
            # and it will close http connection with TCP Reset
            return None
        if hasattr(line, "decode"):
            line = line.decode("utf-8")
        if line.startswith("data: "):
            line = line[len("data: ") :]
            return line
        # Non-"data:" SSE lines (e.g. ": keep-alive" comments) carry no
        # payload; passing them through would make downstream JSON parsing
        # of the stream fail.
        return None
    return None
94+
95+
6696
def parse_stream(rbody):
    """Yield the parsed payload of each line in an SSE response body.

    Lines for which parse_stream_helper returns None (blanks, [DONE], etc.)
    are skipped.
    """
    for raw_line in rbody:
        payload = parse_stream_helper(raw_line)
        if payload is None:
            continue
        yield payload
101+
102+
103+
async def parse_stream_async(rbody: aiohttp.StreamReader):
    """Async counterpart of parse_stream: yield parsed payloads from *rbody*.

    Lines for which parse_stream_helper returns None are skipped.
    """
    async for raw_line in rbody:
        payload = parse_stream_helper(raw_line)
        if payload is None:
            continue
        yield payload
78108

79109

80110
class APIRequestor:
@@ -186,6 +216,86 @@ def request(
186216
resp, got_stream = self._interpret_response(result, stream)
187217
return resp, got_stream, self.api_key
188218

219+
    # The @overload stubs below only refine the static return type on the
    # value of `stream`: Literal[True] -> an AsyncGenerator of responses,
    # Literal[False]/default -> a single OpenAIResponse. Only the final
    # (unadorned) definition executes at runtime.
    @overload
    async def arequest(
        self,
        method,
        url,
        params,
        headers,
        files,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        *,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: Literal[False] = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[OpenAIResponse, bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: bool = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
        pass

    async def arequest(
        self,
        method,
        url,
        params=None,
        headers=None,
        files=None,
        stream: bool = False,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
        """Async counterpart of `request`: perform an API call via aiohttp.

        Returns a (response, got_stream, api_key) tuple. When `stream` is
        True and the server replies with a text/event-stream body (see
        `_interpret_async_response`), the first element is an async
        generator of OpenAIResponse objects; otherwise it is a single
        OpenAIResponse.
        """
        result = await self.arequest_raw(
            method.lower(),
            url,
            params=params,
            supplied_headers=headers,
            files=files,
            request_id=request_id,
            request_timeout=request_timeout,
        )
        resp, got_stream = await self._interpret_async_response(result, stream)
        return resp, got_stream, self.api_key
298+
189299
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
190300
try:
191301
error_data = resp["error"]
@@ -315,18 +425,15 @@ def _validate_headers(
315425

316426
return headers
317427

318-
def request_raw(
428+
def _prepare_request_raw(
319429
self,
320-
method,
321430
url,
322-
*,
323-
params=None,
324-
supplied_headers: Dict[str, str] = None,
325-
files=None,
326-
stream: bool = False,
327-
request_id: Optional[str] = None,
328-
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
329-
) -> requests.Response:
431+
supplied_headers,
432+
method,
433+
params,
434+
files,
435+
request_id: Optional[str],
436+
) -> Tuple[str, Dict[str, str], Optional[bytes]]:
330437
abs_url = "%s%s" % (self.api_base, url)
331438
headers = self._validate_headers(supplied_headers)
332439

@@ -355,6 +462,24 @@ def request_raw(
355462
util.log_info("Request to OpenAI API", method=method, path=abs_url)
356463
util.log_debug("Post details", data=data, api_version=self.api_version)
357464

465+
return abs_url, headers, data
466+
467+
def request_raw(
468+
self,
469+
method,
470+
url,
471+
*,
472+
params=None,
473+
supplied_headers: Optional[Dict[str, str]] = None,
474+
files=None,
475+
stream: bool = False,
476+
request_id: Optional[str] = None,
477+
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
478+
) -> requests.Response:
479+
abs_url, headers, data = self._prepare_request_raw(
480+
url, supplied_headers, method, params, files, request_id
481+
)
482+
358483
if not hasattr(_thread_context, "session"):
359484
_thread_context.session = _make_session()
360485
try:
@@ -385,6 +510,71 @@ def request_raw(
385510
)
386511
return result
387512

513+
    async def arequest_raw(
        self,
        method,
        url,
        *,
        params=None,
        supplied_headers: Optional[Dict[str, str]] = None,
        files=None,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> aiohttp.ClientResponse:
        """Send one HTTP request with aiohttp and return the raw ClientResponse.

        Uses the user-supplied session from `openai.aiosession` when one is
        set; otherwise a throwaway ClientSession is created per request.
        Raises error.Timeout on timeouts and error.APIConnectionError on
        other aiohttp client failures.
        """
        abs_url, headers, data = self._prepare_request_raw(
            url, supplied_headers, method, params, files, request_id
        )

        # request_timeout may be a single total-seconds value or a
        # (connect, total) tuple, mirroring the sync `requests` path.
        if isinstance(request_timeout, tuple):
            timeout = aiohttp.ClientTimeout(
                connect=request_timeout[0],
                total=request_timeout[1],
            )
        else:
            timeout = aiohttp.ClientTimeout(
                total=request_timeout if request_timeout else TIMEOUT_SECS
            )
        user_set_session = openai.aiosession.get()

        if files:
            # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
            # For now we use the private `requests` method that is known to have worked so far.
            data, content_type = requests.models.RequestEncodingMixin._encode_files(  # type: ignore
                files, data
            )
            headers["Content-Type"] = content_type
        request_kwargs = {
            "method": method,
            "url": abs_url,
            "headers": headers,
            "data": data,
            "proxy": _aiohttp_proxies_arg(openai.proxy),
            "timeout": timeout,
        }
        try:
            if user_set_session:
                result = await user_set_session.request(**request_kwargs)
            else:
                # NOTE(review): this ephemeral session exits its context (and
                # closes its connections) before the caller consumes the
                # response body in `_interpret_async_response` — confirm the
                # body remains readable, especially for streamed responses.
                async with aiohttp.ClientSession() as session:
                    result = await session.request(**request_kwargs)
            util.log_info(
                "OpenAI API response",
                path=abs_url,
                response_code=result.status,
                processing_ms=result.headers.get("OpenAI-Processing-Ms"),
                request_id=result.headers.get("X-Request-Id"),
            )
            # Don't read the whole stream for debug logging unless necessary.
            if openai.log == "debug":
                util.log_debug(
                    "API response body", body=result.content, headers=result.headers
                )
            return result
        except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
            raise error.Timeout("Request timed out") from e
        except aiohttp.ClientError as e:
            raise error.APIConnectionError("Error communicating with OpenAI") from e
577+
388578
def _interpret_response(
389579
self, result: requests.Response, stream: bool
390580
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
@@ -404,6 +594,29 @@ def _interpret_response(
404594
False,
405595
)
406596

597+
async def _interpret_async_response(
598+
self, result: aiohttp.ClientResponse, stream: bool
599+
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
600+
"""Returns the response(s) and a bool indicating whether it is a stream."""
601+
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
602+
return (
603+
self._interpret_response_line(
604+
line, result.status, result.headers, stream=True
605+
)
606+
async for line in parse_stream_async(result.content)
607+
), True
608+
else:
609+
try:
610+
await result.read()
611+
except aiohttp.ClientError as e:
612+
util.log_warn(e, body=result.content)
613+
return (
614+
self._interpret_response_line(
615+
await result.read(), result.status, result.headers, stream=False
616+
),
617+
False,
618+
)
619+
407620
def _interpret_response_line(
408621
self, rbody, rcode, rheaders, stream: bool
409622
) -> OpenAIResponse:

0 commit comments

Comments
 (0)