1
+ import asyncio
1
2
import json
2
3
import platform
3
4
import sys
4
5
import threading
5
6
import warnings
6
7
from json import JSONDecodeError
7
- from typing import Dict , Iterator , Optional , Tuple , Union , overload
8
+ from typing import AsyncGenerator , Dict , Iterator , Optional , Tuple , Union , overload
8
9
from urllib .parse import urlencode , urlsplit , urlunsplit
9
10
11
+ import aiohttp
10
12
import requests
11
13
12
14
if sys .version_info >= (3 , 8 ):
@@ -49,6 +51,20 @@ def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
49
51
)
50
52
51
53
54
+ def _aiohttp_proxies_arg (proxy ) -> Optional [str ]:
55
+ """Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
56
+ if proxy is None :
57
+ return None
58
+ elif isinstance (proxy , str ):
59
+ return proxy
60
+ elif isinstance (proxy , dict ):
61
+ return proxy ["https" ] if "https" in proxy else proxy ["http" ]
62
+ else :
63
+ raise ValueError (
64
+ "'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
65
+ )
66
+
67
+
52
68
def _make_session () -> requests .Session :
53
69
if not openai .verify_ssl_certs :
54
70
warnings .warn ("verify_ssl_certs is ignored; openai always verifies." )
@@ -63,18 +79,32 @@ def _make_session() -> requests.Session:
63
79
return s
64
80
65
81
82
def parse_stream_helper(line):
    """Parse one server-sent-events line into its data payload.

    Returns the payload string for a "data: ..." line, the decoded line
    itself for any other non-blank line, and None for blank lines and the
    "[DONE]" end-of-stream sentinel (a plain `return` here would cause a
    GeneratorExit in urllib3 and close the connection with a TCP reset).
    """
    if not line:
        return None
    if hasattr(line, "decode"):
        line = line.decode("utf-8")
    # Decode *before* comparing so the sentinel is caught for both bytes and
    # str input; the old bytes-only check let a str "data: [DONE]" leak
    # through. aiohttp lines also keep their trailing newline (requests
    # strips it), so strip before matching.
    line = line.strip()
    if not line:
        # Blank SSE separator line.
        return None
    if line == "data: [DONE]":
        return None
    if line.startswith("data: "):
        return line[len("data: "):]
    return line
94
+
95
+
66
96
def parse_stream(rbody):
    """Yield the data payload of each SSE line in *rbody*.

    Blank lines and the "[DONE]" sentinel (for which the helper returns
    None) are skipped.
    """
    for raw_line in rbody:
        payload = parse_stream_helper(raw_line)
        if payload is None:
            continue
        yield payload
101
+
102
+
103
async def parse_stream_async(rbody: aiohttp.StreamReader):
    """Async variant of parse_stream: yield SSE data payloads from *rbody*.

    Blank lines and the "[DONE]" sentinel (for which the helper returns
    None) are skipped.
    """
    async for raw_line in rbody:
        payload = parse_stream_helper(raw_line)
        if payload is None:
            continue
        yield payload
78
108
79
109
80
110
class APIRequestor :
@@ -186,6 +216,86 @@ def request(
186
216
resp , got_stream = self ._interpret_response (result , stream )
187
217
return resp , got_stream , self .api_key
188
218
219
    @overload
    async def arequest(
        self,
        method,
        url,
        params,
        headers,
        files,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        *,
        stream: Literal[True],
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: Literal[False] = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[OpenAIResponse, bool, str]:
        pass

    @overload
    async def arequest(
        self,
        method,
        url,
        params=...,
        headers=...,
        files=...,
        stream: bool = ...,
        request_id: Optional[str] = ...,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = ...,
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
        pass

    async def arequest(
        self,
        method,
        url,
        params=None,
        headers=None,
        files=None,
        stream: bool = False,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
        """Async counterpart of `request`: issue an API call via aiohttp.

        Delegates the HTTP round trip to `arequest_raw` and the parsing to
        `_interpret_async_response`, returning a `(resp, got_stream, api_key)`
        tuple where `resp` is either an OpenAIResponse or, when the server
        streams, an async generator of OpenAIResponse objects (see the
        `@overload` stubs above for the stream/non-stream typings).
        """
        result = await self.arequest_raw(
            method.lower(),
            url,
            params=params,
            supplied_headers=headers,
            files=files,
            request_id=request_id,
            request_timeout=request_timeout,
        )
        resp, got_stream = await self._interpret_async_response(result, stream)
        return resp, got_stream, self.api_key
298
+
189
299
def handle_error_response (self , rbody , rcode , resp , rheaders , stream_error = False ):
190
300
try :
191
301
error_data = resp ["error" ]
@@ -315,18 +425,15 @@ def _validate_headers(
315
425
316
426
return headers
317
427
318
- def request_raw (
428
+ def _prepare_request_raw (
319
429
self ,
320
- method ,
321
430
url ,
322
- * ,
323
- params = None ,
324
- supplied_headers : Dict [str , str ] = None ,
325
- files = None ,
326
- stream : bool = False ,
327
- request_id : Optional [str ] = None ,
328
- request_timeout : Optional [Union [float , Tuple [float , float ]]] = None ,
329
- ) -> requests .Response :
431
+ supplied_headers ,
432
+ method ,
433
+ params ,
434
+ files ,
435
+ request_id : Optional [str ],
436
+ ) -> Tuple [str , Dict [str , str ], Optional [bytes ]]:
330
437
abs_url = "%s%s" % (self .api_base , url )
331
438
headers = self ._validate_headers (supplied_headers )
332
439
@@ -355,6 +462,24 @@ def request_raw(
355
462
util .log_info ("Request to OpenAI API" , method = method , path = abs_url )
356
463
util .log_debug ("Post details" , data = data , api_version = self .api_version )
357
464
465
+ return abs_url , headers , data
466
+
467
+ def request_raw (
468
+ self ,
469
+ method ,
470
+ url ,
471
+ * ,
472
+ params = None ,
473
+ supplied_headers : Optional [Dict [str , str ]] = None ,
474
+ files = None ,
475
+ stream : bool = False ,
476
+ request_id : Optional [str ] = None ,
477
+ request_timeout : Optional [Union [float , Tuple [float , float ]]] = None ,
478
+ ) -> requests .Response :
479
+ abs_url , headers , data = self ._prepare_request_raw (
480
+ url , supplied_headers , method , params , files , request_id
481
+ )
482
+
358
483
if not hasattr (_thread_context , "session" ):
359
484
_thread_context .session = _make_session ()
360
485
try :
@@ -385,6 +510,71 @@ def request_raw(
385
510
)
386
511
return result
387
512
513
    async def arequest_raw(
        self,
        method,
        url,
        *,
        params=None,
        supplied_headers: Optional[Dict[str, str]] = None,
        files=None,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> aiohttp.ClientResponse:
        """Perform the raw async HTTP request and return the aiohttp response.

        Raises:
            error.Timeout: on a connect/read timeout.
            error.APIConnectionError: on any other aiohttp client error.
        """
        abs_url, headers, data = self._prepare_request_raw(
            url, supplied_headers, method, params, files, request_id
        )

        # A (connect, total) tuple mirrors the requests-style two-part
        # timeout; a bare number (or None -> TIMEOUT_SECS) caps the total.
        if isinstance(request_timeout, tuple):
            timeout = aiohttp.ClientTimeout(
                connect=request_timeout[0],
                total=request_timeout[1],
            )
        else:
            timeout = aiohttp.ClientTimeout(
                total=request_timeout if request_timeout else TIMEOUT_SECS
            )
        # Caller-supplied session (contextvar), if any; otherwise a throwaway
        # session is created per request below.
        user_set_session = openai.aiosession.get()

        if files:
            # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
            # For now we use the private `requests` method that is known to have worked so far.
            data, content_type = requests.models.RequestEncodingMixin._encode_files(  # type: ignore
                files, data
            )
            headers["Content-Type"] = content_type
        request_kwargs = {
            "method": method,
            "url": abs_url,
            "headers": headers,
            "data": data,
            "proxy": _aiohttp_proxies_arg(openai.proxy),
            "timeout": timeout,
        }
        try:
            if user_set_session:
                result = await user_set_session.request(**request_kwargs)
            else:
                # NOTE(review): the temporary session is closed when this
                # `async with` exits, i.e. before the caller reads a streamed
                # body — confirm streaming works through this path.
                async with aiohttp.ClientSession() as session:
                    result = await session.request(**request_kwargs)
            util.log_info(
                "OpenAI API response",
                path=abs_url,
                response_code=result.status,
                processing_ms=result.headers.get("OpenAI-Processing-Ms"),
                request_id=result.headers.get("X-Request-Id"),
            )
            # Don't read the whole stream for debug logging unless necessary.
            if openai.log == "debug":
                # NOTE(review): `result.content` is aiohttp's StreamReader
                # object, not the body bytes — this logs its repr.
                util.log_debug(
                    "API response body", body=result.content, headers=result.headers
                )
            return result
        except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
            raise error.Timeout("Request timed out") from e
        except aiohttp.ClientError as e:
            raise error.APIConnectionError("Error communicating with OpenAI") from e
577
+
388
578
def _interpret_response (
389
579
self , result : requests .Response , stream : bool
390
580
) -> Tuple [Union [OpenAIResponse , Iterator [OpenAIResponse ]], bool ]:
@@ -404,6 +594,29 @@ def _interpret_response(
404
594
False ,
405
595
)
406
596
597
+ async def _interpret_async_response (
598
+ self , result : aiohttp .ClientResponse , stream : bool
599
+ ) -> Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool ]:
600
+ """Returns the response(s) and a bool indicating whether it is a stream."""
601
+ if stream and "text/event-stream" in result .headers .get ("Content-Type" , "" ):
602
+ return (
603
+ self ._interpret_response_line (
604
+ line , result .status , result .headers , stream = True
605
+ )
606
+ async for line in parse_stream_async (result .content )
607
+ ), True
608
+ else :
609
+ try :
610
+ await result .read ()
611
+ except aiohttp .ClientError as e :
612
+ util .log_warn (e , body = result .content )
613
+ return (
614
+ self ._interpret_response_line (
615
+ await result .read (), result .status , result .headers , stream = False
616
+ ),
617
+ False ,
618
+ )
619
+
407
620
def _interpret_response_line (
408
621
self , rbody , rcode , rheaders , stream : bool
409
622
) -> OpenAIResponse :
0 commit comments