|
| 1 | +import inspect |
| 2 | +import json |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +from fastapi import FastAPI |
| 6 | +from fastapi.routing import APIRoute |
| 7 | +from pydantic import BaseModel |
| 8 | +from starlette.responses import Response |
| 9 | + |
| 10 | +from src.app_config import app |
| 11 | +from src.models import SystemErr |
| 12 | +from src.utility.logger import get_logger |
| 13 | + |
| 14 | +logger = get_logger() |
| 15 | + |
| 16 | + |
| 17 | +def check_response( |
| 18 | + out_response: Any, |
| 19 | + responses: dict | None = None, |
| 20 | + route_path: str | None = None, |
| 21 | + application: FastAPI = app, |
| 22 | +) -> Response: |
| 23 | + """ |
| 24 | + Check if the type of out_response is included in one of: |
| 25 | + - the responses param |
| 26 | + - in the responses of the specified route |
| 27 | + - in the responses of the caller route |
| 28 | +
|
| 29 | + and, in case, returns a Response object that includes the HTTP response code and |
| 30 | + the JSON or the text corresponding to the out_response parameter. |
| 31 | + If the type is not accepted it returns a Response containing a SystemErr. |
| 32 | +
|
| 33 | + Args: |
| 34 | + out_response: (Any) The response that the FastAPI route wants to return as output. |
| 35 | + responses: (dict, optional) A dictionary used to specify the possible responses for the route. |
| 36 | + It is used to define the possible HTTP response codes that the API can return along with their details, |
| 37 | + including the data model to be returned. Defaults to None. |
| 38 | + route_path: (str, optional) The path of the FastAPI route. Defaults to None. |
| 39 | + application: (FastAPI, optional) The FastAPI application. Defaults to the main application. |
| 40 | +
|
| 41 | + Returns: |
| 42 | + starlette.responses.Response: A Response object that includes the HTTP response code and |
| 43 | + the JSON or the text corresponding to the out_response parameter. |
| 44 | + """ # noqa: E501 |
| 45 | + |
| 46 | + if responses is not None: |
| 47 | + return _check_response_type(responses, out_response) |
| 48 | + |
| 49 | + if route_path is not None: |
| 50 | + endpoint = _find_caller_endpoint_by_path( |
| 51 | + application=application, caller_path=route_path |
| 52 | + ) |
| 53 | + |
| 54 | + else: |
| 55 | + caller_function = _find_caller_function() |
| 56 | + |
| 57 | + if caller_function is None: |
| 58 | + logger.error("Check_responses: caller function not found") |
| 59 | + return Response( |
| 60 | + status_code=500, |
| 61 | + content=SystemErr( |
| 62 | + error="An unexpected error occurred while processing the request. " |
| 63 | + "If the issue still persists, contact the platform team for assistance!" # noqa: E501 |
| 64 | + ).json(), |
| 65 | + media_type="application/json", |
| 66 | + ) |
| 67 | + |
| 68 | + endpoint = _find_caller_endpoint_by_name( |
| 69 | + application=application, caller_name=caller_function |
| 70 | + ) |
| 71 | + |
| 72 | + responses = endpoint.responses if endpoint is not None else None |
| 73 | + |
| 74 | + if responses is None: |
| 75 | + logger.error( |
| 76 | + "Check_responses: endpoint not found in app.routes or responses parameter has no value " # noqa: E501 |
| 77 | + ) |
| 78 | + return Response( |
| 79 | + status_code=500, |
| 80 | + content=SystemErr( |
| 81 | + error="An unexpected error occurred while processing the request. " |
| 82 | + "If the issue still persists, contact the platform team for assistance!" # noqa: E501 |
| 83 | + ).json(), |
| 84 | + media_type="application/json", |
| 85 | + ) |
| 86 | + |
| 87 | + return _check_response_type(responses, out_response) |
| 88 | + |
| 89 | + |
| 90 | +def _check_response_type(responses: dict, out_response: Any) -> Response: |
| 91 | + """ |
| 92 | + Ensures that the type of the parameter 'out_response' is contained in the 'model' |
| 93 | + field of the 'responses' dictionary. |
| 94 | +
|
| 95 | + Args: |
| 96 | + responses: (dict) A dictionary used to specify the possible responses for the route. |
| 97 | + It is used to define the possible HTTP response codes that the API can return along with their details, |
| 98 | + including the data model to be returned. |
| 99 | + out_response: (Any) The response that the FastAPI route wants to return as output. |
| 100 | + Returns: |
| 101 | + starlette.responses.Response: A Response object that includes the HTTP response code and |
| 102 | + the JSON or the text corresponding to the out_response parameter. |
| 103 | + If the type of 'out_response' is not accepted as per the 'model' field in 'responses', |
| 104 | + it returns a Response containing a SystemErr. |
| 105 | + """ # noqa: E501 |
| 106 | + |
| 107 | + # set default response_code = 500, if correct response type is found it will be changed # noqa: E501 |
| 108 | + response_code = 500 |
| 109 | + correct_output_type = False |
| 110 | + for k in responses.keys(): |
| 111 | + endpoint_response = responses.get(k) |
| 112 | + if type(endpoint_response) == dict and endpoint_response is not None: |
| 113 | + if endpoint_response.get('model') == type(out_response): |
| 114 | + correct_output_type = True |
| 115 | + response_code = k |
| 116 | + break |
| 117 | + |
| 118 | + if not correct_output_type: |
| 119 | + logger.error("Check response type: response type indicated not allowed") |
| 120 | + return Response( |
| 121 | + status_code=500, |
| 122 | + content=SystemErr( |
| 123 | + error="An unexpected error occurred while processing the request. " |
| 124 | + "If the issue still persists, contact the platform team for assistance!" # noqa: E501 |
| 125 | + ).json(), |
| 126 | + media_type="application/json", |
| 127 | + ) |
| 128 | + |
| 129 | + if isinstance(out_response, BaseModel): |
| 130 | + content = json.dumps(out_response.dict()) |
| 131 | + media_type = "application/json" |
| 132 | + elif isinstance(out_response, list) and all( |
| 133 | + isinstance(item, BaseModel) for item in out_response |
| 134 | + ): # noqa: E501 |
| 135 | + out_response_dicts = [item.dict() for item in out_response] |
| 136 | + content = json.dumps(out_response_dicts) |
| 137 | + media_type = "application/json" |
| 138 | + else: |
| 139 | + content = str(out_response) |
| 140 | + media_type = "text/plain" |
| 141 | + |
| 142 | + return Response( |
| 143 | + status_code=int(response_code), content=content, media_type=media_type |
| 144 | + ) |
| 145 | + |
| 146 | + |
| 147 | +def _find_caller_function(n_back: int = 2) -> str | None: |
| 148 | + """ |
| 149 | + Returns the name of the caller function 'n_back' frames up the call stack. |
| 150 | +
|
| 151 | + This function inspects the call stack to find the name of the caller function |
| 152 | + 'n_back' frames up from the current function call. If 'n_back' is not specified, |
| 153 | + it defaults to 2. |
| 154 | +
|
| 155 | + Args: |
| 156 | + n_back (int, optional): The number of frames up the call stack to look for |
| 157 | + the caller function. Defaults to 2, which retrieves the immediate caller. |
| 158 | +
|
| 159 | + Returns: |
| 160 | + str | None: The name of the caller function as a string, or None if the caller |
| 161 | + function cannot be determined (e.g., when 'n_back' exceeds the call stack depth |
| 162 | + or when this function is executed at the highest level in the call stack). |
| 163 | + """ # noqa: E501 |
| 164 | + |
| 165 | + frame = inspect.currentframe() |
| 166 | + |
| 167 | + if frame is not None: |
| 168 | + for _ in range(n_back): |
| 169 | + if frame is None: |
| 170 | + return None |
| 171 | + frame = frame.f_back |
| 172 | + |
| 173 | + caller_function = frame.f_code.co_name if frame is not None else None |
| 174 | + return caller_function |
| 175 | + |
| 176 | + |
| 177 | +def _find_caller_endpoint_by_path( |
| 178 | + application: FastAPI, caller_path: str |
| 179 | +) -> APIRoute | None: |
| 180 | + """ |
| 181 | + Find and return the FastAPI endpoint (APIRoute) based on the provided caller_path. |
| 182 | +
|
| 183 | + This function iterates through all registered routes in the FastAPI application and searches for an APIRoute |
| 184 | + that matches the caller_path argument. |
| 185 | + If a match is found, the corresponding APIRoute object is returned. If no match is found, the function returns None. |
| 186 | +
|
| 187 | + Args: |
| 188 | + application (FastAPI): The FastAPI application instance to search for the endpoint. |
| 189 | + caller_path (str): The path of the route to search for. |
| 190 | +
|
| 191 | + Returns: |
| 192 | + APIRoute or None: If a matching route is found, it returns the corresponding APIRoute object. |
| 193 | + If no matching route is found, it returns None. |
| 194 | +
|
| 195 | + """ # noqa: E501 |
| 196 | + |
| 197 | + for route in application.routes: |
| 198 | + if isinstance(route, APIRoute) and route.path == caller_path: |
| 199 | + return route |
| 200 | + |
| 201 | + return None |
| 202 | + |
| 203 | + |
| 204 | +def _find_caller_endpoint_by_name( |
| 205 | + application: FastAPI, caller_name: str |
| 206 | +) -> APIRoute | None: |
| 207 | + """ |
| 208 | + Find and return the FastAPI endpoint (APIRoute) based on the provided caller_name. |
| 209 | +
|
| 210 | + This function iterates through all registered routes in the FastAPI application and searches for an APIRoute |
| 211 | + that matches the caller_name argument. |
| 212 | + If a match is found, the corresponding APIRoute object is returned. If no match is found, the function returns None. |
| 213 | +
|
| 214 | + Args: |
| 215 | + application (FastAPI): The FastAPI application instance to search for the endpoint. |
| 216 | + caller_name (str): The name of the route to search for. |
| 217 | +
|
| 218 | + Returns: |
| 219 | + APIRoute or None: If a matching route is found, it returns the corresponding APIRoute object. |
| 220 | + If no matching route is found, it returns None. |
| 221 | +
|
| 222 | + """ # noqa: E501 |
| 223 | + |
| 224 | + for route in application.routes: |
| 225 | + if isinstance(route, APIRoute) and route.name == caller_name: |
| 226 | + return route |
| 227 | + |
| 228 | + return None |
0 commit comments