Skip to content

Commit 3839223

Browse files
committed
Add context function to CoRouterFunctionDsl
This new function allows to customize the CoroutineContext potentially dynamically based on the incoming ServerRequest. Closes gh-27010
1 parent 64ff37f commit 3839223

File tree

2 files changed

+101
-15
lines changed

2 files changed

+101
-15
lines changed

spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import kotlinx.coroutines.Job
2121
import kotlinx.coroutines.currentCoroutineContext
2222
import kotlinx.coroutines.reactor.awaitSingle
2323
import kotlinx.coroutines.reactor.mono
24+
import kotlinx.coroutines.withContext
2425
import org.springframework.core.io.Resource
2526
import org.springframework.http.HttpMethod
2627
import org.springframework.http.HttpStatusCode
@@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
7273
@PublishedApi
7374
internal val builder = RouterFunctions.route()
7475

76+
private var contextProvider: (suspend (ServerRequest) -> CoroutineContext)? = null
77+
7578
/**
7679
* Return a composed request predicate that tests against both this predicate AND
7780
* the [other] predicate (String processed as a path predicate). When evaluating the
@@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
510513
*/
511514
fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) {
512515
builder.resources {
513-
mono(Dispatchers.Unconfined) {
514-
lookupFunction.invoke(it)
515-
}
516+
asMono(it, handler = lookupFunction)
516517
}
517518
}
518519

@@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
534535
*/
535536
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
536537
builder.filter { serverRequest, handlerFunction ->
537-
mono(Dispatchers.Unconfined) {
538+
asMono(serverRequest) {
538539
filterFunction(serverRequest) { handlerRequest ->
539540
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
540541
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
@@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
578579
*/
579580
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
580581
builder.onError(predicate) { throwable, request ->
581-
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
582+
asMono(request) { responseProvider.invoke(throwable, request) }
582583
}
583584
}
584585

@@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
591592
*/
592593
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
593594
builder.onError({it is E}) { throwable, request ->
594-
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
595+
asMono(request) { responseProvider.invoke(throwable, request) }
595596
}
596597
}
597598

@@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
619620
builder.withAttributes(attributesConsumer)
620621
}
621622

623+
/**
624+
* Allow to provide the default [CoroutineContext], potentially dynamically based on
625+
* the incoming [ServerRequest].
626+
* @param provider the [CoroutineContext] provider
627+
* @since 6.1.0
628+
*/
629+
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
630+
if (this.contextProvider != null) {
631+
throw IllegalStateException("The Coroutine context provider should be defined not more than once")
632+
}
633+
this.contextProvider = provider
634+
}
635+
622636
/**
623637
* Return a composed routing function created from all the registered routes.
624638
*/
@@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
627641
return builder.build()
628642
}
629643

630-
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
631-
CoroutineContextAwareHandlerFunction(handler)
644+
@PublishedApi
645+
internal fun <T> asMono(request: ServerRequest, context: CoroutineContext = Dispatchers.Unconfined, handler: suspend (ServerRequest) -> T): Mono<T> {
646+
return mono(context) {
647+
contextProvider?.let {
648+
withContext(it.invoke(request)) {
649+
handler.invoke(request)
650+
}
651+
} ?: run {
652+
handler.invoke(request)
653+
}
654+
}
655+
}
656+
657+
private fun asHandlerFunction(handler: suspend (ServerRequest) -> ServerResponse) = CoroutineContextAwareHandlerFunction { request ->
658+
handler.invoke(request)
659+
}
632660

633661
/**
634662
* @see ServerResponse.from
@@ -698,15 +726,15 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
698726
fun status(status: Int) = ServerResponse.status(status)
699727

700728

701-
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
729+
private inner class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
702730
private val handler: suspend (ServerRequest) -> T
703731
) : HandlerFunction<T> {
704732

705733
override fun handle(request: ServerRequest): Mono<T> {
706734
return handle(Dispatchers.Unconfined, request)
707735
}
708736

709-
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) {
737+
fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {
710738
handler(request)
711739
}
712740

spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,8 @@
1616

1717
package org.springframework.web.reactive.function.server
1818

19-
import kotlinx.coroutines.CoroutineName
20-
import kotlinx.coroutines.currentCoroutineContext
21-
import kotlinx.coroutines.withContext
22-
import org.assertj.core.api.Assertions.assertThat
23-
import org.assertj.core.api.Assertions.assertThatExceptionOfType
19+
import kotlinx.coroutines.*
20+
import org.assertj.core.api.Assertions.*
2421
import org.junit.jupiter.api.Test
2522
import org.springframework.core.io.ClassPathResource
2623
import org.springframework.http.HttpHeaders.ACCEPT
@@ -179,6 +176,48 @@ class CoRouterFunctionDslTests {
179176
.verifyComplete()
180177
}
181178

179+
@Test
180+
fun contextProvider() {
181+
val mockRequest = get("https://example.com/")
182+
.header("Custom-Header", "foo")
183+
.build()
184+
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
185+
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
186+
.expectNextMatches { response ->
187+
response.headers().getFirst("context")!!.contains("foo")
188+
}
189+
.verifyComplete()
190+
}
191+
192+
@Test
193+
fun contextProviderAndFilter() {
194+
val mockRequest = get("https://example.com/")
195+
.header("Custom-Header", "bar")
196+
.build()
197+
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
198+
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
199+
.expectNextMatches { response ->
200+
response.headers().getFirst("context")!!.let {
201+
it.contains("bar") && it.contains("Dispatchers.Default")
202+
}
203+
}
204+
.verifyComplete()
205+
}
206+
207+
@Test
208+
fun multipleContextProviders() {
209+
assertThatIllegalStateException().isThrownBy {
210+
coRouter {
211+
context {
212+
CoroutineName("foo")
213+
}
214+
context {
215+
Dispatchers.Default
216+
}
217+
}
218+
}
219+
}
220+
182221
@Test
183222
fun attributes() {
184223
val visitor = AttributesTestVisitor()
@@ -251,6 +290,25 @@ class CoRouterFunctionDslTests {
251290
}
252291
}
253292

293+
private val routerWithContextProvider = coRouter {
294+
context {
295+
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
296+
}
297+
GET("/") {
298+
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
299+
}
300+
filter { request, next ->
301+
if (request.headers().firstHeader("Custom-Header") == "bar") {
302+
withContext(currentCoroutineContext() + Dispatchers.Default) {
303+
next.invoke(request)
304+
}
305+
}
306+
else {
307+
next.invoke(request)
308+
}
309+
}
310+
}
311+
254312
private val otherRouter = router {
255313
"/other" {
256314
ok().build()

0 commit comments

Comments
 (0)