|
| 1 | +import os |
| 2 | +import re |
| 3 | +import sys |
| 4 | +from typing import TYPE_CHECKING |
| 5 | + |
| 6 | + |
| 7 | +if TYPE_CHECKING: |
| 8 | + from typing import Callable |
| 9 | + from typing import Dict |
| 10 | + from typing import Iterable |
| 11 | + from typing import List |
| 12 | + from typing import Tuple |
| 13 | + from typing import Union |
| 14 | + |
| 15 | + from ddtrace import Span |
| 16 | + |
| 17 | +import graphql |
| 18 | +from graphql import MiddlewareManager |
| 19 | +from graphql.error import GraphQLError |
| 20 | +from graphql.execution import ExecutionResult |
| 21 | +from graphql.language.source import Source |
| 22 | + |
| 23 | +from ddtrace import config |
| 24 | +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY |
| 25 | +from ddtrace.constants import ERROR_MSG |
| 26 | +from ddtrace.constants import ERROR_TYPE |
| 27 | +from ddtrace.constants import SPAN_MEASURED_KEY |
| 28 | +from ddtrace.internal.compat import stringify |
| 29 | +from ddtrace.internal.utils import ArgumentError |
| 30 | +from ddtrace.internal.utils import get_argument_value |
| 31 | +from ddtrace.internal.utils import set_argument_value |
| 32 | +from ddtrace.internal.utils.formats import asbool |
| 33 | +from ddtrace.internal.utils.version import parse_version |
| 34 | +from ddtrace.internal.wrapping import unwrap |
| 35 | +from ddtrace.internal.wrapping import wrap |
| 36 | +from ddtrace.pin import Pin |
| 37 | + |
| 38 | +from .. import trace_utils |
| 39 | +from ...ext import SpanTypes |
| 40 | + |
| 41 | + |
| 42 | +_graphql_version = parse_version(getattr(graphql, "__version__")) |
| 43 | + |
| 44 | +if _graphql_version < (3, 0): |
| 45 | + from graphql.language.ast import Document |
| 46 | +else: |
| 47 | + from graphql.language.ast import DocumentNode as Document |
| 48 | + |
| 49 | + |
| 50 | +config._add( |
| 51 | + "graphql", |
| 52 | + dict( |
| 53 | + _default_service="graphql", |
| 54 | + resolvers_enabled=asbool(os.getenv("DD_TRACE_GRAPHQL_RESOLVERS_ENABLED", default=False)), |
| 55 | + ), |
| 56 | +) |
| 57 | + |
| 58 | + |
| 59 | +def patch(): |
| 60 | + if getattr(graphql, "_datadog_patch", False): |
| 61 | + return |
| 62 | + setattr(graphql, "_datadog_patch", True) |
| 63 | + Pin().onto(graphql) |
| 64 | + |
| 65 | + for module_str, func_name, wrapper in _get_patching_candidates(): |
| 66 | + _update_patching(wrap, module_str, func_name, wrapper) |
| 67 | + |
| 68 | + |
| 69 | +def unpatch(): |
| 70 | + if not getattr(graphql, "_datadog_patch", False) or _graphql_version < (2, 0): |
| 71 | + return |
| 72 | + |
| 73 | + for module_str, func_name, wrapper in _get_patching_candidates(): |
| 74 | + _update_patching(unwrap, module_str, func_name, wrapper) |
| 75 | + |
| 76 | + setattr(graphql, "_datadog_patch", False) |
| 77 | + |
| 78 | + |
| 79 | +def _get_patching_candidates(): |
| 80 | + if _graphql_version < (3, 0): |
| 81 | + return [ |
| 82 | + ("graphql.graphql", "execute_graphql", _traced_query), |
| 83 | + ("graphql.language.parser", "parse", _traced_parse), |
| 84 | + ("graphql.validation.validation", "validate", _traced_validate), |
| 85 | + ("graphql.execution.executor", "execute", _traced_execute), |
| 86 | + ] |
| 87 | + return [ |
| 88 | + ("graphql.graphql", "graphql_impl", _traced_query), |
| 89 | + ("graphql.language.parser", "parse", _traced_parse), |
| 90 | + ("graphql.validation.validate", "validate", _traced_validate), |
| 91 | + ("graphql.execution.execute", "execute", _traced_execute), |
| 92 | + ] |
| 93 | + |
| 94 | + |
| 95 | +def _update_patching(operation, module_str, func_name, wrapper): |
| 96 | + module = sys.modules[module_str] |
| 97 | + func = getattr(module, func_name) |
| 98 | + operation(func, wrapper) |
| 99 | + |
| 100 | + |
| 101 | +def _traced_parse(func, args, kwargs): |
| 102 | + pin = Pin.get_from(graphql) |
| 103 | + if not pin or not pin.enabled(): |
| 104 | + return func(*args, **kwargs) |
| 105 | + |
| 106 | + # If graphql.parse() is called outside graphql.graphql(), graphql.parse will |
| 107 | + # be a top level span. Therefore we must explicitly set the service name. |
| 108 | + with pin.tracer.trace( |
| 109 | + name="graphql.parse", |
| 110 | + service=trace_utils.int_service(pin, config.graphql), |
| 111 | + span_type=SpanTypes.GRAPHQL, |
| 112 | + ): |
| 113 | + return func(*args, **kwargs) |
| 114 | + |
| 115 | + |
| 116 | +def _traced_validate(func, args, kwargs): |
| 117 | + pin = Pin.get_from(graphql) |
| 118 | + if not pin or not pin.enabled(): |
| 119 | + return func(*args, **kwargs) |
| 120 | + |
| 121 | + # If graphql.validate() is called outside graphql.graphql(), graphql.validate will |
| 122 | + # be a top level span. Therefore we must explicitly set the service name. |
| 123 | + with pin.tracer.trace( |
| 124 | + name="graphql.validate", |
| 125 | + service=trace_utils.int_service(pin, config.graphql), |
| 126 | + span_type=SpanTypes.GRAPHQL, |
| 127 | + ) as span: |
| 128 | + errors = func(*args, **kwargs) |
| 129 | + _set_span_errors(errors, span) |
| 130 | + return errors |
| 131 | + |
| 132 | + |
| 133 | +def _traced_execute(func, args, kwargs): |
| 134 | + pin = Pin.get_from(graphql) |
| 135 | + if not pin or not pin.enabled(): |
| 136 | + return func(*args, **kwargs) |
| 137 | + |
| 138 | + if config.graphql.resolvers_enabled: |
| 139 | + # patch resolvers |
| 140 | + args, kwargs = _inject_trace_middleware_to_args(_resolver_middleware, args, kwargs) |
| 141 | + |
| 142 | + # set resource name |
| 143 | + if _graphql_version < (3, 0): |
| 144 | + document = get_argument_value(args, kwargs, 1, "document_ast") |
| 145 | + else: |
| 146 | + document = get_argument_value(args, kwargs, 1, "document") |
| 147 | + resource = _get_source_str(document) |
| 148 | + |
| 149 | + with pin.tracer.trace( |
| 150 | + name="graphql.execute", |
| 151 | + resource=resource, |
| 152 | + service=trace_utils.int_service(pin, config.graphql), |
| 153 | + span_type=SpanTypes.GRAPHQL, |
| 154 | + ) as span: |
| 155 | + result = func(*args, **kwargs) |
| 156 | + if isinstance(result, ExecutionResult): |
| 157 | + # set error tags if the result contains a list of GraphqlErrors, skip if it's a promise |
| 158 | + _set_span_errors(result.errors, span) |
| 159 | + return result |
| 160 | + |
| 161 | + |
| 162 | +def _traced_query(func, args, kwargs): |
| 163 | + pin = Pin.get_from(graphql) |
| 164 | + if not pin or not pin.enabled(): |
| 165 | + return func(*args, **kwargs) |
| 166 | + |
| 167 | + # set resource name |
| 168 | + source = get_argument_value(args, kwargs, 1, "source") |
| 169 | + resource = _get_source_str(source) |
| 170 | + |
| 171 | + with pin.tracer.trace( |
| 172 | + name="graphql.query", |
| 173 | + resource=resource, |
| 174 | + service=trace_utils.int_service(pin, config.graphql), |
| 175 | + span_type=SpanTypes.GRAPHQL, |
| 176 | + ) as span: |
| 177 | + # mark span as measured and set sample rate |
| 178 | + span.set_tag(SPAN_MEASURED_KEY) |
| 179 | + sample_rate = config.graphql.get_analytics_sample_rate() |
| 180 | + if sample_rate is not None: |
| 181 | + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, sample_rate) |
| 182 | + |
| 183 | + result = func(*args, **kwargs) |
| 184 | + if isinstance(result, ExecutionResult): |
| 185 | + # set error tags if the result contains a list of GraphqlErrors, skip if it's a promise |
| 186 | + # If the wrapped validate and execute functions return a list of errors we will duplicate |
| 187 | + # the span errors here. |
| 188 | + _set_span_errors(result.errors, span) |
| 189 | + return result |
| 190 | + |
| 191 | + |
| 192 | +def _resolver_middleware(next_middleware, root, info, **args): |
| 193 | + """ |
| 194 | + trace middleware which wraps the resolvers of graphql fields. |
| 195 | + Note - graphql middlewares can not be a partial. It must be a class or a function. |
| 196 | + """ |
| 197 | + pin = Pin.get_from(graphql) |
| 198 | + if not pin or not pin.enabled(): |
| 199 | + return next_middleware(root, info, **args) |
| 200 | + |
| 201 | + with pin.tracer.trace( |
| 202 | + name="graphql.resolve", |
| 203 | + resource=info.field_name, |
| 204 | + span_type=SpanTypes.GRAPHQL, |
| 205 | + ): |
| 206 | + return next_middleware(root, info, **args) |
| 207 | + |
| 208 | + |
| 209 | +def _inject_trace_middleware_to_args(trace_middleware, args, kwargs): |
| 210 | + # type: (Callable, Tuple, Dict) -> Tuple[Tuple, Dict] |
| 211 | + """ |
| 212 | + Adds a trace middleware to graphql.execute(..., middleware, ...) |
| 213 | + """ |
| 214 | + middlewares_arg = 8 |
| 215 | + if _graphql_version >= (3, 2): |
| 216 | + # middleware is the 10th argument graphql.execute(..) version 3.2+ |
| 217 | + middlewares_arg = 9 |
| 218 | + |
| 219 | + # get middlewares from args or kwargs |
| 220 | + try: |
| 221 | + middlewares = get_argument_value(args, kwargs, middlewares_arg, "middleware") or [] |
| 222 | + if isinstance(middlewares, MiddlewareManager): |
| 223 | + # First we must get the middlewares iterable from the MiddlewareManager then append |
| 224 | + # trace_middleware. For the trace_middleware to be called a new MiddlewareManager will |
| 225 | + # need to initialized. This is handled in graphql.execute(): |
| 226 | + # https://github.com/graphql-python/graphql-core/blob/v3.2.1/src/graphql/execution/execute.py#L254 |
| 227 | + middlewares = middlewares.middlewares # type: Iterable |
| 228 | + except ArgumentError: |
| 229 | + middlewares = [] |
| 230 | + |
| 231 | + # Note - graphql middlewares are called in reverse order |
| 232 | + # add trace_middleware to the end of the list to wrap the execution of resolver and all middlewares |
| 233 | + middlewares = list(middlewares) + [trace_middleware] |
| 234 | + |
| 235 | + # update args and kwargs to contain trace_middleware |
| 236 | + args, kwargs = set_argument_value(args, kwargs, middlewares_arg, "middleware", middlewares) |
| 237 | + return args, kwargs |
| 238 | + |
| 239 | + |
| 240 | +def _get_source_str(obj): |
| 241 | + # type: (Union[str, Source, Document]) -> str |
| 242 | + """ |
| 243 | + Parses graphql Documents and Source objects to retrieve |
| 244 | + the graphql source input for a request. |
| 245 | + """ |
| 246 | + if isinstance(obj, str): |
| 247 | + source_str = obj |
| 248 | + elif isinstance(obj, Source): |
| 249 | + source_str = obj.body |
| 250 | + elif isinstance(obj, Document): |
| 251 | + source_str = obj.loc.source.body |
| 252 | + else: |
| 253 | + source_str = "" |
| 254 | + # remove new lines, tabs and extra whitespace from source_str |
| 255 | + return re.sub(r"\s+", " ", source_str).strip() |
| 256 | + |
| 257 | + |
| 258 | +def _set_span_errors(errors, span): |
| 259 | + # type: (List[GraphQLError], Span) -> None |
| 260 | + if not errors: |
| 261 | + # do nothing if the list of graphql errors is empty |
| 262 | + return |
| 263 | + |
| 264 | + span.error = 1 |
| 265 | + exc_type_str = "%s.%s" % (GraphQLError.__module__, GraphQLError.__name__) |
| 266 | + span._set_str_tag(ERROR_TYPE, exc_type_str) |
| 267 | + error_msgs = "\n".join([stringify(error) for error in errors]) |
| 268 | + # Since we do not support adding and visualizing multiple tracebacks to one span |
| 269 | + # we will not set the error.stack tag on graphql spans. Setting only one traceback |
| 270 | + # could be misleading and might obfuscate errors. |
| 271 | + span._set_str_tag(ERROR_MSG, error_msgs) |
0 commit comments