diff --git a/src/functions_framework/__init__.py b/src/functions_framework/__init__.py index ece4f446..ba1dbdd4 100644 --- a/src/functions_framework/__init__.py +++ b/src/functions_framework/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. import functools -import inspect import io import json import logging @@ -22,7 +21,6 @@ import sys import types -from inspect import signature from typing import Callable, Type import cloudevents.exceptions as cloud_exceptions @@ -41,6 +39,12 @@ ) from google.cloud.functions.context import Context +try: + from pydantic import BaseModel +except ModuleNotFoundError: + BaseModel = types.NoneType + + _FUNCTION_STATUS_HEADER_FIELD = "X-Google-Status" _CRASH = "crash" @@ -146,7 +150,10 @@ def _typed_event_func_wrapper(function, request, inputType: Type): def view_func(path): try: data = request.get_json() - input = inputType.from_dict(data) + if issubclass(inputType, BaseModel): + input = inputType(**data) + else: + input = inputType.from_dict(data) response = function(input) if response is None: return "", 200 diff --git a/src/functions_framework/_typed_event.py b/src/functions_framework/_typed_event.py index 40e715ae..6ef84d98 100644 --- a/src/functions_framework/_typed_event.py +++ b/src/functions_framework/_typed_event.py @@ -12,14 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import inspect +import types + +from inspect import signature as _signature +if sys.version_info.major == 3 and sys.version_info.minor < 10: + signature = _signature +else: + from functools import partial + signature = partial(_signature, eval_str=True) -from inspect import signature from functions_framework import _function_registry from functions_framework.exceptions import FunctionsFrameworkException +try: + from pydantic import BaseModel +except ModuleNotFoundError: + BaseModel = types.NoneType + """Registers user function in the REGISTRY_MAP and the INPUT_TYPE_MAP. Also performs some validity checks for the input type of the function @@ -96,10 +109,13 @@ def _select_input_type(decorator_type, annotation_type): def _validate_input_type(input_type): - if not ( - hasattr(input_type, "from_dict") and callable(getattr(input_type, "from_dict")) - ): - raise AttributeError( - "The type {decorator_type} does not have the required method called " - " 'from_dict'.".format(decorator_type=input_type) - ) + if BaseModel and issubclass(input_type, BaseModel): + # Pydantic model - we are good + return + if (hasattr(input_type, "from_dict") and callable(getattr(input_type, "from_dict"))): + # Use our customer from/to_doct protocol - we are good + return + raise AttributeError( + "The type {decorator_type} is neither Pydantic model no has the required method called " + " 'from_dict'.".format(decorator_type=input_type) + ) diff --git a/tests/test_functions/typed_events/pydantic_event.py b/tests/test_functions/typed_events/pydantic_event.py new file mode 100644 index 00000000..bbec0d51 --- /dev/null +++ b/tests/test_functions/typed_events/pydantic_event.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import functions_framework +from pydantic import BaseModel + + +class TestType(BaseModel): + name: str + age: int + + +@functions_framework.typed +def function_typed_pydantic(testType: TestType): + valid_event = testType.name == "jane" and testType.age == 20 + if not valid_event: + raise Exception("Received invalid input") + return testType.model_dump() + + diff --git a/tests/test_typed_event_functions.py b/tests/test_typed_event_functions.py index 3b8d5da1..b1eb4bde 100644 --- a/tests/test_typed_event_functions.py +++ b/tests/test_typed_event_functions.py @@ -121,3 +121,12 @@ def test_missing_parameter_typed_decorator(): def test_missing_to_dict_typed_decorator(typed_decorator_missing_to_dict): resp = typed_decorator_missing_to_dict.post("/", json={"name": "john", "age": 10}) assert resp.status_code == 500 + + +def test_typed_decorator_pydantic(): + source = TEST_FUNCTIONS_DIR / "typed_events" / "pydantic_event.py" + target = "function_typed_pydantic" + client = create_app(target, source).test_client() + resp = client.post("/", json={"name": "jane", "age": 20}) + assert resp.status_code == 200 + assert resp.json["name"] == "jane" and resp.json["age"] == 20 diff --git a/tox.ini b/tox.ini index e8c555b5..1efe5ecd 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,7 @@ deps = pytest-cov pytest-integration pretend + pydantic setenv = PYTESTARGS = --cov=functions_framework --cov-branch --cov-report term-missing --cov-fail-under=100 windows-latest: PYTESTARGS =