@@ -652,6 +652,8 @@ def __call__(self, func: Union[Type[T2], "Callable[P, Awaitable[Any]]", "Callabl
652652 return self .decorate_class (func )
653653 elif inspect .iscoroutinefunction (func ):
654654 return self .decorate_coroutine (func )
655+ elif inspect .isgeneratorfunction (func ):
656+ return self .decorate_generator_function (func ) # type: ignore
655657 return self .decorate_callable (func ) # type: ignore
656658
657659 def decorate_class (self , klass : Type [T2 ]) -> Type [T2 ]:
@@ -914,20 +916,33 @@ def stop(self) -> None:
914916 def decorate_coroutine (self , coroutine : "Callable[P, Awaitable[T]]" ) -> "Callable[P, Awaitable[T]]" :
915917 return wrap_coroutine (self , coroutine )
916918
919+ def _call_with_time_factory (self , time_factory : Union [StepTickTimeFactory , TickingDateTimeFactory , FrozenDateTimeFactory ], func : "Callable[P, T]" , * args : "P.args" , ** kwargs : "P.kwargs" ) -> T :
920+ if self .as_arg and self .as_kwarg :
921+ assert False , "You can't specify both as_arg and as_kwarg at the same time. Pick one."
922+ if self .as_arg :
923+ result = func (time_factory , * args , ** kwargs ) # type: ignore
924+ if self .as_kwarg :
925+ kwargs [self .as_kwarg ] = time_factory
926+ result = func (* args , ** kwargs )
927+ else :
928+ result = func (* args , ** kwargs )
929+ return result
930+
931+ def decorate_generator_function (self , func : "Callable[P, Iterator[T]]" ) -> "Callable[P, Iterator[T]]" :
932+
933+ @functools .wraps (func )
934+ def wrapper (* args : "P.args" , ** kwargs : "P.kwargs" ) -> Iterator [T ]:
935+ with self as time_factory :
936+ yield from self ._call_with_time_factory (time_factory , func , * args , ** kwargs )
937+
938+ return wrapper
939+
917940 def decorate_callable (self , func : "Callable[P, T]" ) -> "Callable[P, T]" :
941+
918942 @functools .wraps (func )
919943 def wrapper (* args : "P.args" , ** kwargs : "P.kwargs" ) -> T :
920944 with self as time_factory :
921- if self .as_arg and self .as_kwarg :
922- assert False , "You can't specify both as_arg and as_kwarg at the same time. Pick one."
923- elif self .as_arg :
924- result = func (time_factory , * args , ** kwargs ) # type: ignore
925- elif self .as_kwarg :
926- kwargs [self .as_kwarg ] = time_factory
927- result = func (* args , ** kwargs )
928- else :
929- result = func (* args , ** kwargs )
930- return result
945+ return self ._call_with_time_factory (time_factory , func , * args , ** kwargs )
931946
932947 return wrapper
933948
0 commit comments