Skip to content

Commit 86c7de8

Browse files
authored
Feature/support header parameters (#148)
* Added support for Header parameters (#117) Co-authored-by: Ethan Mann <[email protected]>
1 parent c0d101b commit 86c7de8

File tree

19 files changed

+125
-26
lines changed

19 files changed

+125
-26
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77

88
## 0.5.3 - Unrelease
9+
### Additions
10+
- Added support for header parameters (#117)
11+
912
### Fixes
1013
- JSON bodies will now be assigned correctly in generated clients(#139 & #147). Thanks @pawamoy!
1114

end_to_end_tests/fastapi_app/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66
from typing import Any, Dict, List, Union
77

8-
from fastapi import APIRouter, FastAPI, File, Query, UploadFile
8+
from fastapi import APIRouter, FastAPI, File, Header, Query, UploadFile
99
from pydantic import BaseModel
1010

1111
app = FastAPI(title="My Test API", description="An API for testing openapi-python-client",)
@@ -55,7 +55,7 @@ def get_list(an_enum_value: List[AnEnum] = Query(...), some_date: Union[date, da
5555

5656

5757
@test_router.post("/upload")
58-
async def upload_file(some_file: UploadFile = File(...)):
58+
async def upload_file(some_file: UploadFile = File(...), keep_alive: bool = Header(None)):
5959
""" Upload a file """
6060
data = await some_file.read()
6161
return (some_file.filename, some_file.content_type, data)

end_to_end_tests/fastapi_app/openapi.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,17 @@
102102
"summary": "Upload File",
103103
"description": "Upload a file ",
104104
"operationId": "upload_file_tests_upload_post",
105+
"parameters": [
106+
{
107+
"required": false,
108+
"schema": {
109+
"title": "Keep-Alive",
110+
"type": "boolean"
111+
},
112+
"name": "keep-alive",
113+
"in": "header"
114+
}
115+
],
105116
"requestBody": {
106117
"content": {
107118
"multipart/form-data": {

end_to_end_tests/golden-master/my_test_api_client/api/default.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ def ping_ping_get(*, client: Client,) -> bool:
1212
""" A quick check to see if the system is running """
1313
url = "{}/ping".format(client.base_url)
1414

15-
response = httpx.get(url=url, headers=client.get_headers(),)
15+
headers: Dict[str, Any] = client.get_headers()
16+
17+
response = httpx.get(url=url, headers=headers,)
1618

1719
if response.status_code == 200:
1820
return bool(response.text)

end_to_end_tests/golden-master/my_test_api_client/api/tests.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def get_user_list(
2121
""" Get a list of things """
2222
url = "{}/tests/".format(client.base_url)
2323

24+
headers: Dict[str, Any] = client.get_headers()
25+
2426
json_an_enum_value = []
2527
for an_enum_value_item_data in an_enum_value:
2628
an_enum_value_item = an_enum_value_item_data.value
@@ -38,7 +40,7 @@ def get_user_list(
3840
"some_date": json_some_date,
3941
}
4042

41-
response = httpx.get(url=url, headers=client.get_headers(), params=params,)
43+
response = httpx.get(url=url, headers=headers, params=params,)
4244

4345
if response.status_code == 200:
4446
return [AModel.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]
@@ -49,15 +51,19 @@ def get_user_list(
4951

5052

5153
def upload_file_tests_upload_post(
52-
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost,
54+
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Optional[bool] = None,
5355
) -> Union[
5456
None, HTTPValidationError,
5557
]:
5658

5759
""" Upload a file """
5860
url = "{}/tests/upload".format(client.base_url)
5961

60-
response = httpx.post(url=url, headers=client.get_headers(), files=multipart_data.to_dict(),)
62+
headers: Dict[str, Any] = client.get_headers()
63+
if keep_alive is not None:
64+
headers["keep-alive"] = keep_alive
65+
66+
response = httpx.post(url=url, headers=headers, files=multipart_data.to_dict(),)
6167

6268
if response.status_code == 200:
6369
return None
@@ -76,9 +82,11 @@ def json_body_tests_json_body_post(
7682
""" Try sending a JSON body """
7783
url = "{}/tests/json_body".format(client.base_url)
7884

85+
headers: Dict[str, Any] = client.get_headers()
86+
7987
json_json_body = json_body.to_dict()
8088

81-
response = httpx.post(url=url, headers=client.get_headers(), json=json_json_body,)
89+
response = httpx.post(url=url, headers=headers, json=json_json_body,)
8290

8391
if response.status_code == 200:
8492
return None

end_to_end_tests/golden-master/my_test_api_client/async_api/default.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ async def ping_ping_get(*, client: Client,) -> bool:
1212
""" A quick check to see if the system is running """
1313
url = "{}/ping".format(client.base_url,)
1414

15+
headers: Dict[str, Any] = client.get_headers()
16+
1517
async with httpx.AsyncClient() as _client:
16-
response = await _client.get(url=url, headers=client.get_headers(),)
18+
response = await _client.get(url=url, headers=headers,)
1719

1820
if response.status_code == 200:
1921
return bool(response.text)

end_to_end_tests/golden-master/my_test_api_client/async_api/tests.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ async def get_user_list(
2121
""" Get a list of things """
2222
url = "{}/tests/".format(client.base_url,)
2323

24+
headers: Dict[str, Any] = client.get_headers()
25+
2426
json_an_enum_value = []
2527
for an_enum_value_item_data in an_enum_value:
2628
an_enum_value_item = an_enum_value_item_data.value
@@ -39,7 +41,7 @@ async def get_user_list(
3941
}
4042

4143
async with httpx.AsyncClient() as _client:
42-
response = await _client.get(url=url, headers=client.get_headers(), params=params,)
44+
response = await _client.get(url=url, headers=headers, params=params,)
4345

4446
if response.status_code == 200:
4547
return [AModel.from_dict(item) for item in cast(List[Dict[str, Any]], response.json())]
@@ -50,16 +52,20 @@ async def get_user_list(
5052

5153

5254
async def upload_file_tests_upload_post(
53-
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost,
55+
*, client: Client, multipart_data: BodyUploadFileTestsUploadPost, keep_alive: Optional[bool] = None,
5456
) -> Union[
5557
None, HTTPValidationError,
5658
]:
5759

5860
""" Upload a file """
5961
url = "{}/tests/upload".format(client.base_url,)
6062

63+
headers: Dict[str, Any] = client.get_headers()
64+
if keep_alive is not None:
65+
headers["keep-alive"] = keep_alive
66+
6167
async with httpx.AsyncClient() as _client:
62-
response = await _client.post(url=url, headers=client.get_headers(), files=multipart_data.to_dict(),)
68+
response = await _client.post(url=url, headers=headers, files=multipart_data.to_dict(),)
6369

6470
if response.status_code == 200:
6571
return None
@@ -78,10 +84,12 @@ async def json_body_tests_json_body_post(
7884
""" Try sending a JSON body """
7985
url = "{}/tests/json_body".format(client.base_url,)
8086

87+
headers: Dict[str, Any] = client.get_headers()
88+
8189
json_json_body = json_body.to_dict()
8290

8391
async with httpx.AsyncClient() as _client:
84-
response = await _client.post(url=url, headers=client.get_headers(), json=json_json_body,)
92+
response = await _client.post(url=url, headers=headers, json=json_json_body,)
8593

8694
if response.status_code == 200:
8795
return None

openapi_python_client/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _get_document(*, url: Optional[str], path: Optional[Path]) -> Union[Dict[str
8383

8484

8585
class Project:
86-
TEMPLATE_FILTERS = {"snakecase": utils.snake_case}
86+
TEMPLATE_FILTERS = {"snakecase": utils.snake_case, "spinalcase": utils.spinal_case}
8787
project_name_override: Optional[str] = None
8888
package_name_override: Optional[str] = None
8989

openapi_python_client/parser/openapi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class ParameterLocation(str, Enum):
1919

2020
QUERY = "query"
2121
PATH = "path"
22+
HEADER = "header"
2223

2324

2425
def import_string_from_reference(reference: Reference, prefix: str = "") -> str:
@@ -78,6 +79,7 @@ class Endpoint:
7879
relative_imports: Set[str] = field(default_factory=set)
7980
query_parameters: List[Property] = field(default_factory=list)
8081
path_parameters: List[Property] = field(default_factory=list)
82+
header_parameters: List[Property] = field(default_factory=list)
8183
responses: List[Response] = field(default_factory=list)
8284
form_body_reference: Optional[Reference] = None
8385
json_body: Optional[Property] = None
@@ -164,6 +166,8 @@ def _add_parameters(endpoint: Endpoint, data: oai.Operation) -> Union[Endpoint,
164166
endpoint.query_parameters.append(prop)
165167
elif param.param_in == ParameterLocation.PATH:
166168
endpoint.path_parameters.append(prop)
169+
elif param.param_in == ParameterLocation.HEADER:
170+
endpoint.header_parameters.append(prop)
167171
else:
168172
return ParseError(data=param, detail="Parameter must be declared in path or query")
169173
return endpoint

openapi_python_client/templates/async_endpoint_module.pyi

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ from ..errors import ApiResponseError
1111
{% endfor %}
1212
{% for endpoint in collection.endpoints %}
1313

14-
{% from "endpoint_macros.pyi" import query_params, json_body, return_type %}
14+
{% from "endpoint_macros.pyi" import header_params, query_params, json_body, return_type %}
1515

1616
async def {{ endpoint.name | snakecase }}(
1717
*,
@@ -41,6 +41,9 @@ async def {{ endpoint.name | snakecase }}(
4141
{% for parameter in endpoint.query_parameters %}
4242
{{ parameter.to_string() }},
4343
{% endfor %}
44+
{% for parameter in endpoint.header_parameters %}
45+
{{ parameter.to_string() }},
46+
{% endfor %}
4447
{{ return_type(endpoint) }}
4548
""" {{ endpoint.description }} """
4649
url = "{}{{ endpoint.path }}".format(
@@ -50,13 +53,16 @@ async def {{ endpoint.name | snakecase }}(
5053
{% endfor %}
5154
)
5255

56+
headers: Dict[str, Any] = client.get_headers()
57+
{{ header_params(endpoint) | indent(4) }}
58+
5359
{{ query_params(endpoint) | indent(4) }}
5460
{{ json_body(endpoint) | indent(4) }}
5561

5662
async with httpx.AsyncClient() as _client:
5763
response = await _client.{{ endpoint.method }}(
5864
url=url,
59-
headers=client.get_headers(),
65+
headers=headers,
6066
{% if endpoint.form_body_reference %}
6167
data=asdict(form_data),
6268
{% endif %}

openapi_python_client/templates/endpoint_macros.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
{% macro header_params(endpoint) %}
2+
{% if endpoint.header_parameters %}
3+
{% for parameter in endpoint.header_parameters %}
4+
{% if parameter.required %}
5+
headers["{{ parameter.python_name | spinalcase}}"] = {{ parameter.python_name }}
6+
{% else %}
7+
if {{ parameter.python_name }} is not None:
8+
headers["{{ parameter.python_name | spinalcase}}"] = {{ parameter.python_name }}
9+
{% endif %}
10+
{% endfor %}
11+
{% endif %}
12+
{% endmacro %}
13+
114
{% macro query_params(endpoint) %}
215
{% if endpoint.query_parameters %}
316
{% for property in endpoint.query_parameters %}

openapi_python_client/templates/endpoint_module.pyi

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ from ..errors import ApiResponseError
1111
{% endfor %}
1212
{% for endpoint in collection.endpoints %}
1313

14-
{% from "endpoint_macros.pyi" import query_params, json_body, return_type %}
14+
{% from "endpoint_macros.pyi" import header_params, query_params, json_body, return_type %}
1515

1616
def {{ endpoint.name | snakecase }}(
1717
*,
@@ -41,6 +41,9 @@ def {{ endpoint.name | snakecase }}(
4141
{% for parameter in endpoint.query_parameters %}
4242
{{ parameter.to_string() }},
4343
{% endfor %}
44+
{% for parameter in endpoint.header_parameters %}
45+
{{ parameter.to_string() }},
46+
{% endfor %}
4447
{{ return_type(endpoint) }}
4548
""" {{ endpoint.description }} """
4649
url = "{}{{ endpoint.path }}".format(
@@ -50,14 +53,17 @@ def {{ endpoint.name | snakecase }}(
5053
{%- endfor -%}
5154
)
5255

56+
headers: Dict[str, Any] = client.get_headers()
57+
{{ header_params(endpoint) | indent(4) }}
58+
5359
{{ query_params(endpoint) | indent(4) }}
5460

5561
{{ json_body(endpoint) | indent(4) }}
5662

5763

5864
response = httpx.{{ endpoint.method }}(
5965
url=url,
60-
headers=client.get_headers(),
66+
headers=headers,
6167
{% if endpoint.form_body_reference %}
6268
data=asdict(form_data),
6369
{% endif %}

openapi_python_client/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ def snake_case(value: str) -> str:
1111

1212
def pascal_case(value: str) -> str:
1313
return stringcase.pascalcase(value)
14+
15+
16+
def spinal_case(value: str) -> str:
17+
return stringcase.spinalcase(value)

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ isort .\
5858
openapi = "python -m end_to_end_tests.fastapi_app"
5959
gm = "python -m end_to_end_tests.regen_golden_master"
6060
e2e = "pytest openapi_python_client end_to_end_tests"
61+
oge = """
62+
task openapi\
63+
&& task gm\
64+
&& task e2e\
65+
"""
6166

6267
[tool.black]
6368
line-length = 120

tests/test_openapi_parser/test_openapi.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -452,9 +452,15 @@ def test__add_parameters_happy(self, mocker):
452452
query_prop = mocker.MagicMock(autospec=Property)
453453
query_prop_import = mocker.MagicMock()
454454
query_prop.get_imports = mocker.MagicMock(return_value={query_prop_import})
455-
property_from_data = mocker.patch(f"{MODULE_NAME}.property_from_data", side_effect=[path_prop, query_prop])
455+
header_prop = mocker.MagicMock(autospec=Property)
456+
header_prop_import = mocker.MagicMock()
457+
header_prop.get_imports = mocker.MagicMock(return_value={header_prop_import})
458+
property_from_data = mocker.patch(
459+
f"{MODULE_NAME}.property_from_data", side_effect=[path_prop, query_prop, header_prop]
460+
)
456461
path_schema = mocker.MagicMock()
457462
query_schema = mocker.MagicMock()
463+
header_schema = mocker.MagicMock()
458464
data = oai.Operation.construct(
459465
parameters=[
460466
oai.Parameter.construct(
@@ -463,6 +469,9 @@ def test__add_parameters_happy(self, mocker):
463469
oai.Parameter.construct(
464470
name="query_prop_name", required=False, param_schema=query_schema, param_in="query"
465471
),
472+
oai.Parameter.construct(
473+
name="header_prop_name", required=False, param_schema=header_schema, param_in="header"
474+
),
466475
oai.Reference.construct(), # Should be ignored
467476
oai.Parameter.construct(), # Should be ignored
468477
]
@@ -474,17 +483,16 @@ def test__add_parameters_happy(self, mocker):
474483
[
475484
mocker.call(name="path_prop_name", required=True, data=path_schema),
476485
mocker.call(name="query_prop_name", required=False, data=query_schema),
486+
mocker.call(name="header_prop_name", required=False, data=header_schema),
477487
]
478488
)
479489
path_prop.get_imports.assert_called_once_with(prefix="..models")
480490
query_prop.get_imports.assert_called_once_with(prefix="..models")
481-
assert endpoint.relative_imports == {
482-
"import_3",
483-
path_prop_import,
484-
query_prop_import,
485-
}
491+
header_prop.get_imports.assert_called_once_with(prefix="..models")
492+
assert endpoint.relative_imports == {"import_3", path_prop_import, query_prop_import, header_prop_import}
486493
assert endpoint.path_parameters == [path_prop]
487494
assert endpoint.query_parameters == [query_prop]
495+
assert endpoint.header_parameters == [header_prop]
488496

489497
def test_from_data_bad_params(self, mocker):
490498
from openapi_python_client.parser.openapi import Endpoint

0 commit comments

Comments
 (0)