Skip to content

Commit 02e4008

Browse files
Add support for fine-tuning and files using the Azure API. (#80)
* Add support for fine-tunning and files using the Azure API. * Small changes + version bumps * Version bump after merge * fix typo * adressed comments * Fixed 2 small issues that cause unit tests to fail. * Adressed comments * Version bump
1 parent 63cc289 commit 02e4008

13 files changed

+165
-50
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ __pycache__
66
build
77
*.egg
88
.vscode/settings.json
9-
.ipynb_checkpoints
9+
.ipynb_checkpoints
10+
.vscode/launch.json

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ search = openai.Engine(id="deployment-namme").search(documents=["White House", "
7777
print(search)
7878
```
7979

80-
Please note that for the moment, the Microsoft Azure endpoints can only be used for completion and search operations.
80+
Please note that for the moment, the Microsoft Azure endpoints can only be used for completion, search and fine-tuning operations.
8181

8282
### Command-line interface
8383

openai/api_requestor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def request(
122122

123123
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
124124
try:
125-
error_data = resp["error"]
125+
error_data = resp["error"] if self.api_type == ApiType.OPEN_AI else resp
126126
except (KeyError, TypeError):
127127
raise error.APIError(
128128
"Invalid response object from API: %r (HTTP response code "
@@ -333,6 +333,10 @@ def _interpret_response(
333333
def _interpret_response_line(
334334
self, rbody, rcode, rheaders, stream: bool
335335
) -> OpenAIResponse:
336+
# HTTP 204 response code does not have any content in the body.
337+
if rcode == 204:
338+
return OpenAIResponse(None, rheaders)
339+
336340
if rcode == 503:
337341
raise error.ServiceUnavailableError(
338342
"The server is overloaded or not ready yet.",

openai/api_resources/abstract/api_resource.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
class APIResource(OpenAIObject):
1010
api_prefix = ""
11-
azure_api_prefix = "openai/deployments"
11+
azure_api_prefix = "openai"
12+
azure_deployments_prefix = "deployments"
1213

1314
@classmethod
1415
def retrieve(cls, id, api_key=None, request_id=None, **params):
@@ -46,27 +47,34 @@ def instance_url(self, operation=None):
4647
"id",
4748
)
4849
api_version = self.api_version or openai.api_version
50+
extn = quote_plus(id)
4951

5052
if self.typed_api_type == ApiType.AZURE:
5153
if not api_version:
5254
raise error.InvalidRequestError(
5355
"An API version is required for the Azure API type."
5456
)
57+
5558
if not operation:
56-
raise error.InvalidRequestError(
57-
"The request needs an operation (eg: 'search') for the Azure OpenAI API type."
59+
base = self.class_url()
60+
return "/%s%s/%s?api-version=%s" % (
61+
self.azure_api_prefix,
62+
base,
63+
extn,
64+
api_version
5865
)
59-
extn = quote_plus(id)
60-
return "/%s/%s/%s?api-version=%s" % (
66+
67+
return "/%s/%s/%s/%s?api-version=%s" % (
6168
self.azure_api_prefix,
69+
self.azure_deployments_prefix,
6270
extn,
6371
operation,
64-
api_version,
72+
api_version
6573
)
6674

75+
6776
elif self.typed_api_type == ApiType.OPEN_AI:
6877
base = self.class_url()
69-
extn = quote_plus(id)
7078
return "%s/%s" % (base, extn)
7179

7280
else:
@@ -81,6 +89,7 @@ def _static_request(
8189
url_,
8290
api_key=None,
8391
api_base=None,
92+
api_type=None,
8493
request_id=None,
8594
api_version=None,
8695
organization=None,
@@ -91,10 +100,18 @@ def _static_request(
91100
api_version=api_version,
92101
organization=organization,
93102
api_base=api_base,
103+
api_type=api_type
94104
)
95105
response, _, api_key = requestor.request(
96106
method_, url_, params, request_id=request_id
97107
)
98108
return util.convert_to_openai_object(
99109
response, api_key, api_version, organization
100110
)
111+
112+
@classmethod
113+
def _get_api_type_and_version(cls, api_type: str, api_version: str):
114+
typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
115+
typed_api_version = api_version or openai.api_version
116+
return (typed_api_type, typed_api_version)
117+

openai/api_resources/abstract/createable_api_resource.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from openai import api_requestor, util
1+
from openai import api_requestor, util, error
22
from openai.api_resources.abstract.api_resource import APIResource
3+
from openai.util import ApiType
34

45

56
class CreateableAPIResource(APIResource):
@@ -10,6 +11,7 @@ def create(
1011
cls,
1112
api_key=None,
1213
api_base=None,
14+
api_type=None,
1315
request_id=None,
1416
api_version=None,
1517
organization=None,
@@ -18,10 +20,20 @@ def create(
1820
requestor = api_requestor.APIRequestor(
1921
api_key,
2022
api_base=api_base,
23+
api_type=api_type,
2124
api_version=api_version,
2225
organization=organization,
2326
)
24-
url = cls.class_url()
27+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
28+
29+
if typed_api_type == ApiType.AZURE:
30+
base = cls.class_url()
31+
url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
32+
elif typed_api_type == ApiType.OPEN_AI:
33+
url = cls.class_url()
34+
else:
35+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
36+
2537
response, _, api_key = requestor.request(
2638
"post", url, params, request_id=request_id
2739
)
Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
from urllib.parse import quote_plus
22

3+
from openai import error
34
from openai.api_resources.abstract.api_resource import APIResource
4-
5+
from openai.util import ApiType
56

67
class DeletableAPIResource(APIResource):
78
@classmethod
8-
def delete(cls, sid, **params):
9+
def delete(cls, sid, api_type=None, api_version=None, **params):
910
if isinstance(cls, APIResource):
1011
raise ValueError(".delete may only be called as a class method now.")
11-
url = "%s/%s" % (cls.class_url(), quote_plus(sid))
12-
return cls._static_request("delete", url, **params)
12+
13+
base = cls.class_url()
14+
extn = quote_plus(sid)
15+
16+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
17+
if typed_api_type == ApiType.AZURE:
18+
url = "/%s%s/%s?api-version=%s" % (cls.azure_api_prefix, base, extn, api_version)
19+
elif typed_api_type == ApiType.OPEN_AI:
20+
url = "%s/%s" % (base, extn)
21+
else:
22+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
23+
24+
return cls._static_request("delete", url, api_type=api_type, api_version=api_version, **params)

openai/api_resources/abstract/engine_api_resource.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
class EngineAPIResource(APIResource):
1616
engine_required = True
1717
plain_old_data = False
18-
azure_api_prefix = "openai/deployments"
1918

2019
def __init__(self, engine: Optional[str] = None, **kwargs):
2120
super().__init__(engine=engine, **kwargs)
@@ -30,12 +29,7 @@ def class_url(
3029
# Namespaces are separated in object names with periods (.) and in URLs
3130
# with forward slashes (/), so replace the former with the latter.
3231
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
33-
typed_api_type = (
34-
ApiType.from_str(api_type)
35-
if api_type
36-
else ApiType.from_str(openai.api_type)
37-
)
38-
api_version = api_version or openai.api_version
32+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
3933

4034
if typed_api_type == ApiType.AZURE:
4135
if not api_version:
@@ -47,11 +41,12 @@ def class_url(
4741
"You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
4842
)
4943
extn = quote_plus(engine)
50-
return "/%s/%s/%ss?api-version=%s" % (
44+
return "/%s/%s/%s/%ss?api-version=%s" % (
5145
cls.azure_api_prefix,
46+
cls.azure_deployments_prefix,
5247
extn,
5348
base,
54-
api_version,
49+
api_version
5550
)
5651

5752
elif typed_api_type == ApiType.OPEN_AI:
@@ -148,27 +143,29 @@ def instance_url(self):
148143
"id",
149144
)
150145

151-
params_connector = "?"
146+
extn = quote_plus(id)
147+
params_connector = '?'
148+
152149
if self.typed_api_type == ApiType.AZURE:
153150
api_version = self.api_version or openai.api_version
154151
if not api_version:
155152
raise error.InvalidRequestError(
156153
"An API version is required for the Azure API type."
157154
)
158-
extn = quote_plus(id)
159155
base = self.OBJECT_NAME.replace(".", "/")
160-
url = "/%s/%s/%ss/%s?api-version=%s" % (
156+
url = "/%s/%s/%s/%ss/%s?api-version=%s" % (
161157
self.azure_api_prefix,
158+
self.azure_deployments_prefix,
162159
self.engine,
163160
base,
164161
extn,
165-
api_version,
162+
api_version
166163
)
167-
params_connector = "&"
164+
params_connector = '&'
165+
168166

169167
elif self.typed_api_type == ApiType.OPEN_AI:
170168
base = self.class_url(self.engine, self.api_type, self.api_version)
171-
extn = quote_plus(id)
172169
url = "%s/%s" % (base, extn)
173170

174171
else:

openai/api_resources/abstract/listable_api_resource.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from openai import api_requestor, util
1+
from openai import api_requestor, util, error
22
from openai.api_resources.abstract.api_resource import APIResource
3+
from openai.util import ApiType
34

45

56
class ListableAPIResource(APIResource):
@@ -15,15 +16,27 @@ def list(
1516
api_version=None,
1617
organization=None,
1718
api_base=None,
19+
api_type=None,
1820
**params,
1921
):
2022
requestor = api_requestor.APIRequestor(
2123
api_key,
2224
api_base=api_base or cls.api_base(),
2325
api_version=api_version,
26+
api_type=api_type,
2427
organization=organization,
2528
)
26-
url = cls.class_url()
29+
30+
typed_api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
31+
32+
if typed_api_type == ApiType.AZURE:
33+
base = cls.class_url()
34+
url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
35+
elif typed_api_type == ApiType.OPEN_AI:
36+
url = cls.class_url()
37+
else:
38+
raise error.InvalidAPIType('Unsupported API type %s' % api_type)
39+
2740
response, _, api_key = requestor.request(
2841
"get", url, params, request_id=request_id
2942
)

0 commit comments

Comments
 (0)