@@ -21,6 +21,7 @@ import kotlinx.coroutines.Job
21
21
import kotlinx.coroutines.currentCoroutineContext
22
22
import kotlinx.coroutines.reactor.awaitSingle
23
23
import kotlinx.coroutines.reactor.mono
24
+ import kotlinx.coroutines.withContext
24
25
import org.springframework.core.io.Resource
25
26
import org.springframework.http.HttpMethod
26
27
import org.springframework.http.HttpStatusCode
@@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
72
73
@PublishedApi
73
74
internal val builder = RouterFunctions .route()
74
75
76
+ private var contextProvider: (suspend (ServerRequest ) -> CoroutineContext )? = null
77
+
75
78
/* *
76
79
* Return a composed request predicate that tests against both this predicate AND
77
80
* the [other] predicate (String processed as a path predicate). When evaluating the
@@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
510
513
*/
511
514
fun resources (lookupFunction : suspend (ServerRequest ) -> Resource ? ) {
512
515
builder.resources {
513
- mono(Dispatchers .Unconfined ) {
514
- lookupFunction.invoke(it)
515
- }
516
+ asMono(it, handler = lookupFunction)
516
517
}
517
518
}
518
519
@@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
534
535
*/
535
536
fun filter (filterFunction : suspend (ServerRequest , suspend (ServerRequest ) -> ServerResponse ) -> ServerResponse ) {
536
537
builder.filter { serverRequest, handlerFunction ->
537
- mono( Dispatchers . Unconfined ) {
538
+ asMono(serverRequest ) {
538
539
filterFunction(serverRequest) { handlerRequest ->
539
540
if (handlerFunction is CoroutineContextAwareHandlerFunction <* >) {
540
541
handlerFunction.handle(currentCoroutineContext().minusKey(Job .Key ), handlerRequest).awaitSingle()
@@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
578
579
*/
579
580
fun onError (predicate : (Throwable ) -> Boolean , responseProvider : suspend (Throwable , ServerRequest ) -> ServerResponse ) {
580
581
builder.onError(predicate) { throwable, request ->
581
- mono( Dispatchers . Unconfined ) { responseProvider.invoke(throwable, request) }
582
+ asMono(request ) { responseProvider.invoke(throwable, request) }
582
583
}
583
584
}
584
585
@@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
591
592
*/
592
593
inline fun <reified E : Throwable > onError (noinline responseProvider : suspend (Throwable , ServerRequest ) -> ServerResponse ) {
593
594
builder.onError({it is E }) { throwable, request ->
594
- mono( Dispatchers . Unconfined ) { responseProvider.invoke(throwable, request) }
595
+ asMono(request ) { responseProvider.invoke(throwable, request) }
595
596
}
596
597
}
597
598
@@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
619
620
builder.withAttributes(attributesConsumer)
620
621
}
621
622
623
+ /* *
624
+ * Allow to provide the default [CoroutineContext], potentially dynamically based on
625
+ * the incoming [ServerRequest].
626
+ * @param provider the [CoroutineContext] provider
627
+ * @since 6.1.0
628
+ */
629
+ fun context (provider : suspend (ServerRequest ) -> CoroutineContext ) {
630
+ if (this .contextProvider != null ) {
631
+ throw IllegalStateException (" The Coroutine context provider should be defined not more than once" )
632
+ }
633
+ this .contextProvider = provider
634
+ }
635
+
622
636
/* *
623
637
* Return a composed routing function created from all the registered routes.
624
638
*/
@@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
627
641
return builder.build()
628
642
}
629
643
630
- private fun <T : ServerResponse > asHandlerFunction (handler : suspend (ServerRequest ) -> T ) =
631
- CoroutineContextAwareHandlerFunction (handler)
644
+ @PublishedApi
645
+ internal fun <T > asMono (request : ServerRequest , context : CoroutineContext = Dispatchers .Unconfined , handler : suspend (ServerRequest ) -> T ): Mono <T > {
646
+ return mono(context) {
647
+ contextProvider?.let {
648
+ withContext(it.invoke(request)) {
649
+ handler.invoke(request)
650
+ }
651
+ } ? : run {
652
+ handler.invoke(request)
653
+ }
654
+ }
655
+ }
656
+
657
+ private fun asHandlerFunction (handler : suspend (ServerRequest ) -> ServerResponse ) = CoroutineContextAwareHandlerFunction { request ->
658
+ handler.invoke(request)
659
+ }
632
660
633
661
/* *
634
662
* @see ServerResponse.from
@@ -698,15 +726,15 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
698
726
fun status (status : Int ) = ServerResponse .status(status)
699
727
700
728
701
- private class CoroutineContextAwareHandlerFunction <T : ServerResponse >(
729
+ private inner class CoroutineContextAwareHandlerFunction <T : ServerResponse >(
702
730
private val handler : suspend (ServerRequest ) -> T
703
731
) : HandlerFunction<T> {
704
732
705
733
override fun handle (request : ServerRequest ): Mono <T > {
706
734
return handle(Dispatchers .Unconfined , request)
707
735
}
708
736
709
- fun handle (context : CoroutineContext , request : ServerRequest ) = mono( context) {
737
+ fun handle (context : CoroutineContext , request : ServerRequest ) = asMono(request, context) {
710
738
handler(request)
711
739
}
712
740
0 commit comments