Skip to content

Commit 0c51b90

Browse files
committed
feat: capture call site coroutine context into call options
1 parent 54e048d commit 0c51b90

File tree

3 files changed

+192
-13
lines changed

3 files changed

+192
-13
lines changed

compiler/src/main/java/io/grpc/kotlin/generator/GrpcClientStubGenerator.kt

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import com.squareup.kotlinpoet.AnnotationSpec
2323
import com.squareup.kotlinpoet.CodeBlock
2424
import com.squareup.kotlinpoet.FunSpec
2525
import com.squareup.kotlinpoet.KModifier
26+
import com.squareup.kotlinpoet.MemberName
2627
import com.squareup.kotlinpoet.ParameterSpec
2728
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
2829
import com.squareup.kotlinpoet.TypeName
@@ -48,6 +49,7 @@ import io.grpc.kotlin.generator.protoc.methodName
4849
import io.grpc.kotlin.generator.protoc.of
4950
import io.grpc.kotlin.generator.protoc.serviceName
5051
import kotlinx.coroutines.flow.Flow
52+
import kotlin.coroutines.CoroutineContext
5153
import io.grpc.Channel as GrpcChannel
5254
import io.grpc.Metadata as GrpcMetadata
5355

@@ -62,6 +64,10 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co
6264
private val STREAMING_PARAMETER_NAME = MemberSimpleName("requests")
6365
private val GRPC_CHANNEL_PARAMETER_NAME = MemberSimpleName("channel")
6466
private val CALL_OPTIONS_PARAMETER_NAME = MemberSimpleName("callOptions")
67+
private val WITH_COROUTINE_CONTEXT_FUN_NAME = MemberName(ClientCalls::class.java.`package`.name, "withCoroutineContext")
68+
private val COROUTINE_CONTEXT_VAL_NAME = MemberName(CoroutineContext::class.java.`package`.name, "coroutineContext")
69+
private val FLOW_FUN_NAME = MemberName(Flow::class.java.`package`.name, "flow")
70+
private val EMIT_ALL_FUN_NAME = MemberName(Flow::class.java.`package`.name, "emitAll")
6571

6672
private val HEADERS_PARAMETER: ParameterSpec = ParameterSpec
6773
.builder("headers", GrpcMetadata::class)
@@ -94,6 +100,9 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co
94100
} else {
95101
if (isServerStreaming) MethodType.SERVER_STREAMING else MethodType.UNARY
96102
}
103+
104+
private val MethodDescriptor.isSuspendable: Boolean
105+
get() = !isServerStreaming
97106
}
98107

99108
override fun generate(service: ServiceDescriptor): Declarations = declarations {
@@ -189,28 +198,39 @@ class GrpcClientStubGenerator(config: GeneratorConfig) : ServiceCodeGenerator(co
189198
)
190199
}
191200

192-
val codeBlockMap = mapOf(
193-
"helperMethod" to helperMethod,
194-
"methodDescriptor" to method.descriptorCode,
195-
"parameter" to parameter,
196-
"headers" to HEADERS_PARAMETER
197-
)
201+
val codeBlockMap = buildMap {
202+
this["helperMethod"] = helperMethod
203+
this["methodDescriptor"] = method.descriptorCode
204+
this["parameter"] = parameter
205+
this["headers"] = HEADERS_PARAMETER
206+
this["withContext"] = WITH_COROUTINE_CONTEXT_FUN_NAME
207+
this["coroutineContext"] = COROUTINE_CONTEXT_VAL_NAME
208+
if (!method.isSuspendable) {
209+
this["flow"] = FLOW_FUN_NAME
210+
this["emitAll"] = EMIT_ALL_FUN_NAME
211+
}
212+
}
198213

199-
if (!method.isServerStreaming) {
214+
if (method.isSuspendable) {
200215
funSpecBuilder.addModifiers(KModifier.SUSPEND)
201216
}
202217

203-
funSpecBuilder.addNamedCode(
204-
"""
205-
return %helperMethod:M(
218+
val helperCall = """
219+
%helperMethod:M(
206220
channel,
207221
%methodDescriptor:L,
208222
%parameter:N,
209-
callOptions,
223+
callOptions.%withContext:M(%coroutineContext:M),
210224
%headers:N
211225
)
212-
""".trimIndent(),
213-
codeBlockMap
226+
""".trimIndent()
227+
funSpecBuilder.addNamedCode(
228+
if (method.isSuspendable) {
229+
"return $helperCall"
230+
} else {
231+
"return \n%flow:M {\n⇥%emitAll:M(\n$helperCall\n⇤)\n⇤}"
232+
},
233+
codeBlockMap,
214234
)
215235
return funSpecBuilder.build()
216236
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package io.grpc.kotlin
2+
3+
import io.grpc.CallOptions
4+
import kotlin.coroutines.CoroutineContext
5+
import kotlin.coroutines.EmptyCoroutineContext
6+
7+
private val COROUTINE_CONTEXT_OPTION: CallOptions.Key<CoroutineContext> =
8+
CallOptions.Key.createWithDefault("Coroutine context", EmptyCoroutineContext)
9+
10+
/**
11+
* Sets a coroutine context.
12+
*
13+
* @param context coroutine context to put into the call options
14+
* @return [CallOptions] instance with coroutine context
15+
*/
16+
fun CallOptions.withCoroutineContext(context: CoroutineContext): CallOptions =
17+
withOption(COROUTINE_CONTEXT_OPTION, context)
18+
19+
/**
20+
* Gets a coroutine context from the call options.
21+
*
22+
* Default: [EmptyCoroutineContext]
23+
*/
24+
val CallOptions.coroutineContext: CoroutineContext
25+
get() = getOption(COROUTINE_CONTEXT_OPTION)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package io.grpc.kotlin
2+
3+
import com.google.common.truth.Truth.assertThat
4+
import com.google.common.truth.extensions.proto.ProtoTruth
5+
import io.grpc.CallOptions
6+
import io.grpc.Channel
7+
import io.grpc.ClientCall
8+
import io.grpc.ClientInterceptor
9+
import io.grpc.ClientInterceptors
10+
import io.grpc.MethodDescriptor
11+
import io.grpc.examples.helloworld.GreeterGrpcKt
12+
import io.grpc.examples.helloworld.HelloRequest
13+
import io.grpc.examples.helloworld.MultiHelloRequest
14+
import kotlinx.coroutines.flow.Flow
15+
import kotlinx.coroutines.flow.first
16+
import kotlinx.coroutines.flow.flowOf
17+
import kotlinx.coroutines.flow.map
18+
import kotlinx.coroutines.withContext
19+
import org.junit.Test
20+
import org.junit.runner.RunWith
21+
import org.junit.runners.JUnit4
22+
import java.util.UUID
23+
import kotlin.coroutines.CoroutineContext
24+
25+
@RunWith(JUnit4::class)
26+
class ClientCallOptionsCoroutineContextPropagationTest : AbstractCallsTest() {
27+
28+
@Test
29+
fun `should capture coroutine context with unary call`() {
30+
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
31+
override suspend fun sayHello(request: HelloRequest) = helloReply("Hello, ${request.name}!")
32+
}
33+
val interceptor = CoroutineContextCapturingInterceptor()
34+
val contextElement = DummyCoroutineContextElement()
35+
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
36+
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)
37+
38+
runBlocking {
39+
withContext(contextElement) {
40+
ProtoTruth.assertThat(stub.sayHello(helloRequest("Steven")))
41+
.isEqualTo(helloReply("Hello, Steven!"))
42+
}
43+
}
44+
assertThat(interceptor.coroutineContext).isNotNull()
45+
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
46+
}
47+
48+
@Test
49+
fun `should capture coroutine context with client streaming`() {
50+
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
51+
override suspend fun clientStreamSayHello(requests: Flow<HelloRequest>) = requests.map { request ->
52+
helloReply("Hello, ${request.name}!")
53+
}.first()
54+
}
55+
val interceptor = CoroutineContextCapturingInterceptor()
56+
val contextElement = DummyCoroutineContextElement()
57+
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
58+
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)
59+
60+
runBlocking {
61+
withContext(contextElement) {
62+
ProtoTruth.assertThat(stub.clientStreamSayHello(flowOf(helloRequest("Steven"))))
63+
.isEqualTo(helloReply("Hello, Steven!"))
64+
}
65+
}
66+
assertThat(interceptor.coroutineContext).isNotNull()
67+
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
68+
}
69+
70+
@Test
71+
fun `should capture coroutine context with server streaming`() {
72+
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
73+
override fun serverStreamSayHello(request: MultiHelloRequest) = flowOf(
74+
helloReply("Hello, ${request.nameList.joinToString()}!")
75+
)
76+
}
77+
val interceptor = CoroutineContextCapturingInterceptor()
78+
val contextElement = DummyCoroutineContextElement()
79+
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
80+
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)
81+
82+
runBlocking {
83+
withContext(contextElement) {
84+
ProtoTruth.assertThat(stub.serverStreamSayHello(multiHelloRequest("Steven", "Andrew")).first())
85+
.isEqualTo(helloReply("Hello, Steven, Andrew!"))
86+
}
87+
}
88+
assertThat(interceptor.coroutineContext).isNotNull()
89+
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
90+
}
91+
92+
@Test
93+
fun `should capture coroutine context with bidi streaming`() {
94+
val server = object : GreeterGrpcKt.GreeterCoroutineImplBase() {
95+
override fun bidiStreamSayHello(requests: Flow<HelloRequest>) = requests.map { request ->
96+
helloReply("Hello, ${request.name}!")
97+
}
98+
}
99+
val interceptor = CoroutineContextCapturingInterceptor()
100+
val contextElement = DummyCoroutineContextElement()
101+
val channel = ClientInterceptors.intercept(makeChannel(server), interceptor)
102+
val stub = GreeterGrpcKt.GreeterCoroutineStub(channel)
103+
104+
runBlocking {
105+
withContext(contextElement) {
106+
ProtoTruth.assertThat(stub.bidiStreamSayHello(flowOf(helloRequest("Steven"))).first())
107+
.isEqualTo(helloReply("Hello, Steven!"))
108+
}
109+
}
110+
assertThat(interceptor.coroutineContext).isNotNull()
111+
assertThat(interceptor.coroutineContext!![DummyCoroutineContextElement]).isEqualTo(contextElement)
112+
}
113+
}
114+
115+
private data class DummyCoroutineContextElement(val value: UUID = UUID.randomUUID()) : CoroutineContext.Element {
116+
override val key: CoroutineContext.Key<*> = Key
117+
118+
companion object Key : CoroutineContext.Key<DummyCoroutineContextElement>
119+
}
120+
121+
private class CoroutineContextCapturingInterceptor : ClientInterceptor {
122+
123+
var coroutineContext: CoroutineContext? = null
124+
125+
override fun <ReqT : Any?, RespT : Any?> interceptCall(
126+
method: MethodDescriptor<ReqT, RespT>,
127+
callOptions: CallOptions,
128+
next: Channel,
129+
): ClientCall<ReqT, RespT> {
130+
coroutineContext = callOptions.coroutineContext
131+
132+
return next.newCall(method, callOptions)
133+
}
134+
}

0 commit comments

Comments
 (0)