diff --git a/README.md b/README.md index ab98ec4403..2a1bb52f31 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,10 @@ completion = openai.Completion.create(engine="ada", prompt="Hello world") print(completion.choices[0].text) ``` + +### Params +All endpoints have a `.create` method that supports a `request_timeout` param. This param takes a `Union[float, Tuple[float, float]]` and will raise an `openai.error.Timeout` error if the request exceeds that time in seconds (See: https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts). + ### Microsoft Azure Endpoints In order to use the library with Microsoft Azure endpoints, you need to set the api_type, api_base and api_version in addition to the api_key. The api_type must be set to 'azure' and the others correspond to the properties of your endpoint. diff --git a/openai/api_requestor.py b/openai/api_requestor.py index 954afc05e0..5970510434 100644 --- a/openai/api_requestor.py +++ b/openai/api_requestor.py @@ -3,10 +3,11 @@ import threading import warnings from json import JSONDecodeError -from typing import Dict, Iterator, Optional, Tuple, Union +from typing import Dict, Iterator, Optional, Tuple, Union, overload from urllib.parse import urlencode, urlsplit, urlunsplit import requests +from typing_extensions import Literal import openai from openai import error, util, version @@ -99,6 +100,63 @@ def format_app_info(cls, info): str += " (%s)" % (info["url"],) return str + @overload + def request( + self, + method, + url, + params, + headers, + files, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Iterator[OpenAIResponse], bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + *, + stream: Literal[True], + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Iterator[OpenAIResponse], 
bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: Literal[False] = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[OpenAIResponse, bool, str]: + pass + + @overload + def request( + self, + method, + url, + params=..., + headers=..., + files=..., + stream: bool = ..., + request_id: Optional[str] = ..., + request_timeout: Optional[Union[float, Tuple[float, float]]] = ..., + ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: + pass + def request( self, method, @@ -106,8 +164,9 @@ def request( params=None, headers=None, files=None, - stream=False, + stream: bool = False, request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: result = self.request_raw( method.lower(), @@ -117,6 +176,7 @@ def request( files=files, stream=stream, request_id=request_id, + request_timeout=request_timeout, ) resp, got_stream = self._interpret_response(result, stream) return resp, got_stream, self.api_key @@ -179,7 +239,11 @@ def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False return error.APIError(message, rbody, rcode, resp, rheaders) else: return error.APIError( - error_data.get("message"), rbody, rcode, resp, rheaders + f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", + rbody, + rcode, + resp, + rheaders, ) def request_headers( @@ -256,6 +320,7 @@ def request_raw( files=None, stream: bool = False, request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, ) -> requests.Response: abs_url = "%s%s" % (self.api_base, url) headers = self._validate_headers(supplied_headers) @@ -295,8 +360,10 @@ def request_raw( data=data, files=files, stream=stream, - timeout=TIMEOUT_SECS, + timeout=request_timeout if request_timeout 
else TIMEOUT_SECS, ) + except requests.exceptions.Timeout as e: + raise error.Timeout("Request timed out") from e except requests.exceptions.RequestException as e: raise error.APIConnectionError("Error communicating with OpenAI") from e util.log_info( @@ -304,6 +371,7 @@ def request_raw( path=abs_url, response_code=result.status_code, processing_ms=result.headers.get("OpenAI-Processing-Ms"), + request_id=result.headers.get("X-Request-Id"), ) # Don't read the whole stream for debug logging unless necessary. if openai.log == "debug": diff --git a/openai/api_resources/abstract/api_resource.py b/openai/api_resources/abstract/api_resource.py index dd13fcbf0e..aa7cfe88e1 100644 --- a/openai/api_resources/abstract/api_resource.py +++ b/openai/api_resources/abstract/api_resource.py @@ -13,14 +13,21 @@ class APIResource(OpenAIObject): azure_deployments_prefix = "deployments" @classmethod - def retrieve(cls, id, api_key=None, request_id=None, **params): + def retrieve( + cls, id, api_key=None, request_id=None, request_timeout=None, **params + ): instance = cls(id, api_key, **params) - instance.refresh(request_id=request_id) + instance.refresh(request_id=request_id, request_timeout=request_timeout) return instance - def refresh(self, request_id=None): + def refresh(self, request_id=None, request_timeout=None): self.refresh_from( - self.request("get", self.instance_url(), request_id=request_id) + self.request( + "get", + self.instance_url(), + request_id=request_id, + request_timeout=request_timeout, + ) ) return self diff --git a/openai/api_resources/abstract/engine_api_resource.py b/openai/api_resources/abstract/engine_api_resource.py index 9d4eda1a45..152313c202 100644 --- a/openai/api_resources/abstract/engine_api_resource.py +++ b/openai/api_resources/abstract/engine_api_resource.py @@ -77,7 +77,7 @@ def create( timeout = params.pop("timeout", None) stream = params.get("stream", False) headers = params.pop("headers", None) - + request_timeout = 
params.pop("request_timeout", None) typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0] if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): if deployment_id is None and engine is None: @@ -119,6 +119,7 @@ def create( headers=headers, stream=stream, request_id=request_id, + request_timeout=request_timeout, ) if stream: diff --git a/openai/cli.py b/openai/cli.py index 09ede89131..dbde7e19cf 100644 --- a/openai/cli.py +++ b/openai/cli.py @@ -9,7 +9,6 @@ import requests import openai -import openai.wandb_logger from openai.upload_progress import BufferReader from openai.validators import ( apply_necessary_remediation, @@ -542,6 +541,8 @@ def prepare_data(cls, args): class WandbLogger: @classmethod def sync(cls, args): + import openai.wandb_logger + resp = openai.wandb_logger.WandbLogger.sync( id=args.id, n_fine_tunes=args.n_fine_tunes, diff --git a/openai/error.py b/openai/error.py index 47f9aab6bc..d22e71c902 100644 --- a/openai/error.py +++ b/openai/error.py @@ -76,6 +76,10 @@ class TryAgain(OpenAIError): pass +class Timeout(OpenAIError): + pass + + class APIConnectionError(OpenAIError): def __init__( self, diff --git a/openai/openai_object.py b/openai/openai_object.py index 58e458dfed..5bfa29e45f 100644 --- a/openai/openai_object.py +++ b/openai/openai_object.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Optional +from typing import Optional, Tuple, Union import openai from openai import api_requestor, util @@ -165,6 +165,7 @@ def request( stream=False, plain_old_data=False, request_id: Optional[str] = None, + request_timeout: Optional[Union[float, Tuple[float, float]]] = None, ): if params is None: params = self._retrieve_params @@ -182,6 +183,7 @@ def request( stream=stream, headers=headers, request_id=request_id, + request_timeout=request_timeout, ) if stream: diff --git a/openai/tests/test_endpoints.py b/openai/tests/test_endpoints.py index f590328eec..e5c466add1 100644 --- 
a/openai/tests/test_endpoints.py +++ b/openai/tests/test_endpoints.py @@ -1,7 +1,10 @@ import io import json +import pytest + import openai +from openai import error # FILE TESTS @@ -34,3 +37,24 @@ def test_completions_model(): result = openai.Completion.create(prompt="This was a test", n=5, model="ada") assert len(result.choices) == 5 assert result.model.startswith("ada") + + +def test_timeout_raises_error(): + # A query that should take awhile to return + with pytest.raises(error.Timeout): + openai.Completion.create( + prompt="test" * 1000, + n=10, + model="ada", + max_tokens=100, + request_timeout=0.01, + ) + + +def test_timeout_does_not_error(): + # A query that should be fast + openai.Completion.create( + prompt="test", + model="ada", + request_timeout=10, + ) diff --git a/openai/tests/test_exceptions.py b/openai/tests/test_exceptions.py index e97b4cb386..7760cdc5f6 100644 --- a/openai/tests/test_exceptions.py +++ b/openai/tests/test_exceptions.py @@ -21,6 +21,7 @@ openai.error.SignatureVerificationError("message", "sig_header?"), openai.error.APIConnectionError("message!", should_retry=True), openai.error.TryAgain(), + openai.error.Timeout(), openai.error.APIError( message="message", code=400, diff --git a/openai/version.py b/openai/version.py index dae64cb9e3..68c325024b 100644 --- a/openai/version.py +++ b/openai/version.py @@ -1 +1 @@ -VERSION = "0.22.1" +VERSION = "0.23.0" diff --git a/setup.py b/setup.py index 707e8f299b..03bd6c3b6a 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,8 @@ "pandas>=1.2.3", # Needed for CLI fine-tuning data preparation tool "pandas-stubs>=1.1.0.11", # Needed for type hints for mypy "openpyxl>=3.0.7", # Needed for CLI fine-tuning data preparation tool xlsx format - "numpy>=1.22.0", # To address a vuln in <1.21.6 + "numpy", + "typing_extensions", # Needed for type hints for mypy ], extras_require={ "dev": ["black~=21.6b0", "pytest==6.*"],