Skip to content

Commit bc74c0d

Browse files
authored
fix(postgrest): fix execute type definition (#1262)
1 parent a62c4dc commit bc74c0d

File tree

3 files changed

+128
-46
lines changed

3 files changed

+128
-46
lines changed

src/postgrest/src/postgrest/_async/request_builder.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Generic, Optional, TypeVar, Union
3+
from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload
44

55
from httpx import AsyncClient, BasicAuth, Headers, QueryParams, Response
66
from pydantic import ValidationError
@@ -32,7 +32,7 @@ class AsyncQueryRequestBuilder:
3232
def __init__(self, request: ReqConfig):
3333
self.request = request
3434

35-
async def execute(self) -> APIResponse | str:
35+
async def execute(self) -> APIResponse:
3636
"""Execute the query.
3737
3838
.. tip::
@@ -47,17 +47,6 @@ async def execute(self) -> APIResponse | str:
4747
r = await self.request.send()
4848
try:
4949
if r.is_success:
50-
if self.request.http_method != "HEAD":
51-
body = r.text
52-
if self.request.headers.get("Accept") == "text/csv":
53-
return body
54-
if self.request.headers.get(
55-
"Accept"
56-
) and "application/vnd.pgrst.plan" in self.request.headers.get(
57-
"Accept"
58-
):
59-
if "+json" not in self.request.headers.get("Accept"):
60-
return body
6150
return APIResponse.from_http_request_response(r)
6251
else:
6352
json_obj = model_validate_json(APIErrorFromJSON, r.content)
@@ -95,6 +84,22 @@ async def execute(self) -> SingleAPIResponse:
9584
raise APIError(generate_default_error_message(r))
9685

9786

87+
class AsyncExplainRequestBuilder:
88+
def __init__(self, request: ReqConfig):
89+
self.request = request
90+
91+
async def execute(self) -> str:
92+
r = await self.request.send()
93+
try:
94+
if r.is_success:
95+
return r.text
96+
else:
97+
json_obj = model_validate_json(APIErrorFromJSON, r.content)
98+
raise APIError(dict(json_obj))
99+
except ValidationError as e:
100+
raise APIError(generate_default_error_message(r))
101+
102+
98103
class AsyncMaybeSingleRequestBuilder:
99104
def __init__(self, request: ReqConfig):
100105
self.request = request
@@ -176,6 +181,52 @@ def csv(self) -> AsyncSingleRequestBuilder:
176181
self.request.headers["Accept"] = "text/csv"
177182
return AsyncSingleRequestBuilder(self.request)
178183

184+
@overload
185+
def explain(
186+
self,
187+
analyze: bool = False,
188+
verbose: bool = False,
189+
settings: bool = False,
190+
buffers: bool = False,
191+
wal: bool = False,
192+
format: Literal["text"] = "text",
193+
) -> AsyncExplainRequestBuilder: ...
194+
195+
@overload
196+
def explain(
197+
self,
198+
analyze: bool = False,
199+
verbose: bool = False,
200+
settings: bool = False,
201+
buffers: bool = False,
202+
wal: bool = False,
203+
*,
204+
format: Literal["json"],
205+
) -> AsyncSingleRequestBuilder: ...
206+
207+
def explain(
208+
self,
209+
analyze: bool = False,
210+
verbose: bool = False,
211+
settings: bool = False,
212+
buffers: bool = False,
213+
wal: bool = False,
214+
format: Literal["text", "json"] = "text",
215+
) -> AsyncExplainRequestBuilder | AsyncSingleRequestBuilder:
216+
options = [
217+
key
218+
for key, value in locals().items()
219+
if key not in ["self", "format"] and value
220+
]
221+
options_str = "|".join(options)
222+
self.request.headers["Accept"] = (
223+
f"application/vnd.pgrst.plan+{format}; options={options_str}"
224+
)
225+
if format == "text":
226+
return AsyncExplainRequestBuilder(self.request)
227+
else:
228+
return AsyncSingleRequestBuilder(self.request)
229+
179230

180231
class AsyncRequestBuilder: #
181232
def __init__(

src/postgrest/src/postgrest/_sync/request_builder.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Generic, Optional, TypeVar, Union
3+
from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload
44

55
from httpx import BasicAuth, Client, Headers, QueryParams, Response
66
from pydantic import ValidationError
@@ -32,7 +32,7 @@ class SyncQueryRequestBuilder:
3232
def __init__(self, request: ReqConfig):
3333
self.request = request
3434

35-
def execute(self) -> APIResponse | str:
35+
def execute(self) -> APIResponse:
3636
"""Execute the query.
3737
3838
.. tip::
@@ -47,17 +47,6 @@ def execute(self) -> APIResponse | str:
4747
r = self.request.send()
4848
try:
4949
if r.is_success:
50-
if self.request.http_method != "HEAD":
51-
body = r.text
52-
if self.request.headers.get("Accept") == "text/csv":
53-
return body
54-
if self.request.headers.get(
55-
"Accept"
56-
) and "application/vnd.pgrst.plan" in self.request.headers.get(
57-
"Accept"
58-
):
59-
if "+json" not in self.request.headers.get("Accept"):
60-
return body
6150
return APIResponse.from_http_request_response(r)
6251
else:
6352
json_obj = model_validate_json(APIErrorFromJSON, r.content)
@@ -95,6 +84,22 @@ def execute(self) -> SingleAPIResponse:
9584
raise APIError(generate_default_error_message(r))
9685

9786

87+
class SyncExplainRequestBuilder:
88+
def __init__(self, request: ReqConfig):
89+
self.request = request
90+
91+
def execute(self) -> str:
92+
r = self.request.send()
93+
try:
94+
if r.is_success:
95+
return r.text
96+
else:
97+
json_obj = model_validate_json(APIErrorFromJSON, r.content)
98+
raise APIError(dict(json_obj))
99+
except ValidationError as e:
100+
raise APIError(generate_default_error_message(r))
101+
102+
98103
class SyncMaybeSingleRequestBuilder:
99104
def __init__(self, request: ReqConfig):
100105
self.request = request
@@ -176,6 +181,52 @@ def csv(self) -> SyncSingleRequestBuilder:
176181
self.request.headers["Accept"] = "text/csv"
177182
return SyncSingleRequestBuilder(self.request)
178183

184+
@overload
185+
def explain(
186+
self,
187+
analyze: bool = False,
188+
verbose: bool = False,
189+
settings: bool = False,
190+
buffers: bool = False,
191+
wal: bool = False,
192+
format: Literal["text"] = "text",
193+
) -> SyncExplainRequestBuilder: ...
194+
195+
@overload
196+
def explain(
197+
self,
198+
analyze: bool = False,
199+
verbose: bool = False,
200+
settings: bool = False,
201+
buffers: bool = False,
202+
wal: bool = False,
203+
*,
204+
format: Literal["json"],
205+
) -> SyncSingleRequestBuilder: ...
206+
207+
def explain(
208+
self,
209+
analyze: bool = False,
210+
verbose: bool = False,
211+
settings: bool = False,
212+
buffers: bool = False,
213+
wal: bool = False,
214+
format: Literal["text", "json"] = "text",
215+
) -> SyncExplainRequestBuilder | SyncSingleRequestBuilder:
216+
options = [
217+
key
218+
for key, value in locals().items()
219+
if key not in ["self", "format"] and value
220+
]
221+
options_str = "|".join(options)
222+
self.request.headers["Accept"] = (
223+
f"application/vnd.pgrst.plan+{format}; options={options_str}"
224+
)
225+
if format == "text":
226+
return SyncExplainRequestBuilder(self.request)
227+
else:
228+
return SyncSingleRequestBuilder(self.request)
229+
179230

180231
class SyncRequestBuilder: #
181232
def __init__(

src/postgrest/src/postgrest/base_request_builder.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -557,26 +557,6 @@ def max_affected(self: Self, value: int) -> Self:
557557

558558

559559
class BaseSelectRequestBuilder(BaseFilterRequestBuilder[C]):
560-
def explain(
561-
self: Self,
562-
analyze: bool = False,
563-
verbose: bool = False,
564-
settings: bool = False,
565-
buffers: bool = False,
566-
wal: bool = False,
567-
format: Literal["text", "json"] = "text",
568-
) -> Self:
569-
options = [
570-
key
571-
for key, value in locals().items()
572-
if key not in ["self", "format"] and value
573-
]
574-
options_str = "|".join(options)
575-
self.request.headers["Accept"] = (
576-
f"application/vnd.pgrst.plan+{format}; options={options_str}"
577-
)
578-
return self
579-
580560
def order(
581561
self: Self,
582562
column: str,

0 commit comments

Comments
 (0)