@@ -168,9 +168,24 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
168168import logging
169169import re
170170from functools import wraps
171- from typing import TYPE_CHECKING , Collection , Dict , List , Optional , Pattern , Type
171+ from typing import (
172+ TYPE_CHECKING ,
173+ Any ,
174+ Callable ,
175+ Collection ,
176+ Dict ,
177+ Generator ,
178+ Iterable ,
179+ List ,
180+ Optional ,
181+ Pattern ,
182+ Type ,
183+ TypeVar ,
184+ Union ,
185+ )
172186
173187import attr
188+ from typing_extensions import ParamSpec
174189
175190from twisted .internet import defer
176191from twisted .web .http import Request
@@ -256,7 +271,7 @@ class _WrappedRustReporter(BaseReporter):
256271 def set_process (self , * args , ** kwargs ):
257272 return self ._reporter .set_process (* args , ** kwargs )
258273
259- def report_span (self , span ) :
274+ def report_span (self , span : "opentracing.Span" ) -> None :
260275 try :
261276 return self ._reporter .report_span (span )
262277 except Exception :
@@ -307,15 +322,19 @@ class SynapseBaggage:
307322Sentinel = object ()
308323
309324
310- def only_if_tracing (func ):
325+ P = ParamSpec ("P" )
326+ R = TypeVar ("R" )
327+
328+
329+ def only_if_tracing (func : Callable [P , R ]) -> Callable [P , Optional [R ]]:
311330 """Executes the function only if we're tracing. Otherwise returns None."""
312331
313332 @wraps (func )
314- def _only_if_tracing_inner (* args , ** kwargs ) :
333+ def _only_if_tracing_inner (* args : P . args , ** kwargs : P . kwargs ) -> Optional [ R ] :
315334 if opentracing :
316335 return func (* args , ** kwargs )
317336 else :
318- return
337+ return None
319338
320339 return _only_if_tracing_inner
321340
@@ -356,17 +375,10 @@ def ensure_active_span_inner_2(*args, **kwargs):
356375 return ensure_active_span_inner_1
357376
358377
359- @contextlib .contextmanager
360- def noop_context_manager (* args , ** kwargs ):
361- """Does exactly what it says on the tin"""
362- # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
363- yield
364-
365-
366378# Setup
367379
368380
369- def init_tracer (hs : "HomeServer" ):
381+ def init_tracer (hs : "HomeServer" ) -> None :
370382 """Set the whitelists and initialise the JaegerClient tracer"""
371383 global opentracing
372384 if not hs .config .tracing .opentracer_enabled :
@@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"):
408420
409421
410422@only_if_tracing
411- def set_homeserver_whitelist (homeserver_whitelist ) :
423+ def set_homeserver_whitelist (homeserver_whitelist : Iterable [ str ]) -> None :
412424 """Sets the homeserver whitelist
413425
414426 Args:
415- homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers
427+ homeserver_whitelist: regexes specifying whitelisted homeservers
416428 """
417429 global _homeserver_whitelist
418430 if homeserver_whitelist :
@@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist):
423435
424436
425437@only_if_tracing
426- def whitelisted_homeserver (destination ) :
438+ def whitelisted_homeserver (destination : str ) -> bool :
427439 """Checks if a destination matches the whitelist
428440
429441 Args:
430- destination (str)
442+ destination
431443 """
432444
433445 if _homeserver_whitelist :
434- return _homeserver_whitelist .match (destination )
446+ return _homeserver_whitelist .match (destination ) is not None
435447 return False
436448
437449
@@ -457,11 +469,11 @@ def start_active_span(
457469 Args:
458470 See opentracing.tracer
459471 Returns:
460- scope (Scope) or noop_context_manager
472+ scope (Scope) or contextlib.nullcontext
461473 """
462474
463475 if opentracing is None :
464- return noop_context_manager () # type: ignore[unreachable]
476+ return contextlib . nullcontext () # type: ignore[unreachable]
465477
466478 if tracer is None :
467479 # use the global tracer by default
@@ -505,7 +517,7 @@ def start_active_span_follows_from(
505517 tracer: override the opentracing tracer. By default the global tracer is used.
506518 """
507519 if opentracing is None :
508- return noop_context_manager () # type: ignore[unreachable]
520+ return contextlib . nullcontext () # type: ignore[unreachable]
509521
510522 references = [opentracing .follows_from (context ) for context in contexts ]
511523 scope = start_active_span (
@@ -525,27 +537,27 @@ def start_active_span_follows_from(
525537
526538
527539def start_active_span_from_edu (
528- edu_content ,
529- operation_name ,
530- references : Optional [list ] = None ,
531- tags = None ,
532- start_time = None ,
533- ignore_active_span = False ,
534- finish_on_close = True ,
535- ):
540+ edu_content : Dict [ str , Any ] ,
541+ operation_name : str ,
542+ references : Optional [List [ "opentracing.Reference" ] ] = None ,
543+ tags : Optional [ Dict ] = None ,
544+ start_time : Optional [ float ] = None ,
545+ ignore_active_span : bool = False ,
546+ finish_on_close : bool = True ,
547+ ) -> "opentracing.Scope" :
536548 """
537549 Extracts a span context from an edu and uses it to start a new active span
538550
539551 Args:
540- edu_content (dict): and edu_content with a `context` field whose value is
552+ edu_content: an edu_content with a `context` field whose value is
541553 canonical json for a dict which contains opentracing information.
542554
543555 For the other args see opentracing.tracer
544556 """
545557 references = references or []
546558
547559 if opentracing is None :
548- return noop_context_manager () # type: ignore[unreachable]
560+ return contextlib . nullcontext () # type: ignore[unreachable]
549561
550562 carrier = json_decoder .decode (edu_content .get ("context" , "{}" )).get (
551563 "opentracing" , {}
@@ -578,27 +590,27 @@ def start_active_span_from_edu(
578590
579591# Opentracing setters for tags, logs, etc
580592@only_if_tracing
581- def active_span ():
593+ def active_span () -> Optional [ "opentracing.Span" ] :
582594 """Get the currently active span, if any"""
583595 return opentracing .tracer .active_span
584596
585597
586598@ensure_active_span ("set a tag" )
587- def set_tag (key , value ) :
599+ def set_tag (key : str , value : Union [ str , bool , int , float ]) -> None :
588600 """Sets a tag on the active span"""
589601 assert opentracing .tracer .active_span is not None
590602 opentracing .tracer .active_span .set_tag (key , value )
591603
592604
593605@ensure_active_span ("log" )
594- def log_kv (key_values , timestamp = None ):
606+ def log_kv (key_values : Dict [ str , Any ], timestamp : Optional [ float ] = None ) -> None :
595607 """Log to the active span"""
596608 assert opentracing .tracer .active_span is not None
597609 opentracing .tracer .active_span .log_kv (key_values , timestamp )
598610
599611
600612@ensure_active_span ("set the traces operation name" )
601- def set_operation_name (operation_name ) :
613+ def set_operation_name (operation_name : str ) -> None :
602614 """Sets the operation name of the active span"""
603615 assert opentracing .tracer .active_span is not None
604616 opentracing .tracer .active_span .set_operation_name (operation_name )
@@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None:
624636 span .set_baggage_item (SynapseBaggage .FORCE_TRACING , "1" )
625637
626638
627- def is_context_forced_tracing (span_context ) -> bool :
639+ def is_context_forced_tracing (
640+ span_context : Optional ["opentracing.SpanContext" ],
641+ ) -> bool :
628642 """Check if sampling has been force for the given span context."""
629643 if span_context is None :
630644 return False
@@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None:
696710
697711
698712@ensure_active_span ("get the active span context as a dict" , ret = {})
699- def get_active_span_text_map (destination = None ):
713+ def get_active_span_text_map (destination : Optional [ str ] = None ) -> Dict [ str , str ] :
700714 """
701715 Gets a span context as a dict. This can be used instead of manually
702716 injecting a span into an empty carrier.
703717
704718 Args:
705- destination (str) : the name of the remote server.
719+ destination: the name of the remote server.
706720
707721 Returns:
708722 dict: the active span's context if opentracing is enabled, otherwise empty.
@@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None):
721735
722736
723737@ensure_active_span ("get the span context as a string." , ret = {})
724- def active_span_context_as_string ():
738+ def active_span_context_as_string () -> str :
725739 """
726740 Returns:
727741 The active span context encoded as a string.
@@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon
750764
751765
752766@only_if_tracing
753- def span_context_from_string (carrier ) :
767+ def span_context_from_string (carrier : str ) -> Optional [ "opentracing.SpanContext" ] :
754768 """
755769 Returns:
756770 The active span context decoded from a string.
757771 """
758- carrier = json_decoder .decode (carrier )
759- return opentracing .tracer .extract (opentracing .Format .TEXT_MAP , carrier )
772+ payload : Dict [ str , str ] = json_decoder .decode (carrier )
773+ return opentracing .tracer .extract (opentracing .Format .TEXT_MAP , payload )
760774
761775
762776@only_if_tracing
763- def extract_text_map (carrier ) :
777+ def extract_text_map (carrier : Dict [ str , str ]) -> Optional [ "opentracing.SpanContext" ] :
764778 """
765779 Wrapper method for opentracing's tracer.extract for TEXT_MAP.
766780 Args:
767- carrier (dict) : a dict possibly containing a span context.
781+ carrier: a dict possibly containing a span context.
768782
769783 Returns:
770784 The active span context extracted from carrier.
@@ -843,7 +857,7 @@ def err_back(result):
843857 return decorator
844858
845859
846- def tag_args (func ) :
860+ def tag_args (func : Callable [ P , R ]) -> Callable [ P , R ] :
847861 """
848862 Tags all of the args to the active span.
849863 """
@@ -852,19 +866,21 @@ def tag_args(func):
852866 return func
853867
854868 @wraps (func )
855- def _tag_args_inner (* args , ** kwargs ) :
869+ def _tag_args_inner (* args : P . args , ** kwargs : P . kwargs ) -> R :
856870 argspec = inspect .getfullargspec (func )
857871 for i , arg in enumerate (argspec .args [1 :]):
858- set_tag ("ARG_" + arg , args [i ])
859- set_tag ("args" , args [len (argspec .args ) :])
872+ set_tag ("ARG_" + arg , args [i ]) # type: ignore[index]
873+ set_tag ("args" , args [len (argspec .args ) :]) # type: ignore[index]
860874 set_tag ("kwargs" , kwargs )
861875 return func (* args , ** kwargs )
862876
863877 return _tag_args_inner
864878
865879
866880@contextlib .contextmanager
867- def trace_servlet (request : "SynapseRequest" , extract_context : bool = False ):
881+ def trace_servlet (
882+ request : "SynapseRequest" , extract_context : bool = False
883+ ) -> Generator [None , None , None ]:
868884 """Returns a context manager which traces a request. It starts a span
869885 with some servlet specific tags such as the request metrics name and
870886 request information.
0 commit comments