Skip to content

feat(event_handler): Ensure Bedrock Agents resolver works with Pydantic v2 #5156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Callable

from typing_extensions import override
Expand All @@ -10,10 +11,12 @@
ProxyEventType,
ResponseBuilder,
)
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION

if TYPE_CHECKING:
from re import Match

from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent

Expand Down Expand Up @@ -273,3 +276,107 @@ def _convert_matches_into_route_keys(self, match: Match) -> dict[str, str]:
if match.groupdict() and self.current_event.parameters:
parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters}
return parameters

@override
def get_openapi_json_schema(
self,
*,
title: str = "Powertools API",
version: str = DEFAULT_API_VERSION,
openapi_version: str = DEFAULT_OPENAPI_VERSION,
summary: str | None = None,
description: str | None = None,
tags: list[Tag | str] | None = None,
servers: list[Server] | None = None,
terms_of_service: str | None = None,
contact: Contact | None = None,
license_info: License | None = None,
security_schemes: dict[str, SecurityScheme] | None = None,
security: list[dict[str, list[str]]] | None = None,
) -> str:
"""
Returns the OpenAPI schema as a JSON serializable dict.
Since Bedrock Agents only support OpenAPI 3.0.0, we convert OpenAPI 3.1.0 schemas
and enforce 3.0.0 compatibility for seamless integration.

Parameters
----------
title: str
The title of the application.
version: str
The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API
openapi_version: str, default = "3.0.0"
The version of the OpenAPI Specification (which the document uses).
summary: str, optional
A short summary of what the application does.
description: str, optional
A verbose explanation of the application behavior.
tags: list[Tag, str], optional
A list of tags used by the specification with additional metadata.
servers: list[Server], optional
An array of Server Objects, which provide connectivity information to a target server.
terms_of_service: str, optional
A URL to the Terms of Service for the API. MUST be in the format of a URL.
contact: Contact, optional
The contact information for the exposed API.
license_info: License, optional
The license information for the exposed API.
security_schemes: dict[str, SecurityScheme]], optional
A declaration of the security schemes available to be used in the specification.
security: list[dict[str, list[str]]], optional
A declaration of which security mechanisms are applied globally across the API.

Returns
-------
str
The OpenAPI schema as a JSON serializable dict.
"""
from aws_lambda_powertools.event_handler.openapi.compat import model_json

schema = super().get_openapi_schema(
title=title,
version=version,
openapi_version=openapi_version,
summary=summary,
description=description,
tags=tags,
servers=servers,
terms_of_service=terms_of_service,
contact=contact,
license_info=license_info,
security_schemes=security_schemes,
security=security,
)
schema.openapi = "3.0.3"

# Transform OpenAPI 3.1 into 3.0
def inner(yaml_dict):
if isinstance(yaml_dict, dict):
if "anyOf" in yaml_dict and isinstance((anyOf := yaml_dict["anyOf"]), list):
for i, item in enumerate(anyOf):
if isinstance(item, dict) and item.get("type") == "null":
anyOf.pop(i)
yaml_dict["nullable"] = True
if "examples" in yaml_dict:
examples = yaml_dict["examples"]
del yaml_dict["examples"]
if isinstance(examples, list) and len(examples):
yaml_dict["example"] = examples[0]
for value in yaml_dict.values():
inner(value)
elif isinstance(yaml_dict, list):
for item in yaml_dict:
inner(item)

model = json.loads(
model_json(
schema,
by_alias=True,
exclude_none=True,
indent=2,
),
)

inner(model)

return json.dumps(model)
21 changes: 20 additions & 1 deletion tests/functional/event_handler/test_bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from typing import Any, Dict
from typing import Any, Dict, Optional

import pytest
from typing_extensions import Annotated

from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types
Expand Down Expand Up @@ -181,3 +182,21 @@ def send_reminders(
# THEN return the correct result
body = result["response"]["responseBody"]["application/json"]["body"]
assert json.loads(body) is True


@pytest.mark.usefixtures("pydanticv2_only")
def test_openapi_schema_for_pydanticv2(openapi30_schema):
# GIVEN BedrockAgentResolver is initialized with enable_validation=True
app = BedrockAgentResolver(enable_validation=True)

# WHEN we have a simple handler
@app.get("/", description="Testing")
def handler() -> Optional[Dict]:
pass

# WHEN we get the schema
schema = json.loads(app.get_openapi_json_schema())

# THEN the schema must be a valid 3.0.3 version
assert openapi30_schema(schema)
assert schema.get("openapi") == "3.0.3"
Loading