diff --git a/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/CoroutineBinding.kt b/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/CoroutineBinding.kt index 7de3cc0..129df64 100644 --- a/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/CoroutineBinding.kt +++ b/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/CoroutineBinding.kt @@ -56,7 +56,7 @@ public suspend inline fun coroutineBinding(crossinline block: suspend Cor } } } catch (ex: BindCancellationException) { - receiver.result!! + receiver.result ?: throw ex } } diff --git a/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/AsyncCoroutineBindingTest.kt b/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/AsyncCoroutineBindingTest.kt index a190210..171d6a4 100644 --- a/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/AsyncCoroutineBindingTest.kt +++ b/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/AsyncCoroutineBindingTest.kt @@ -164,4 +164,128 @@ class AsyncCoroutineBindingTest { assertTrue(yStateChange) assertFalse(zStateChange) } + + @Test + fun shouldHandleNestedBindings() = runTest { + var xStateChange = false + var yStateChange = false + var zStateChange = false + + suspend fun provideX(): Result { + delay(1) + xStateChange = true + return Ok(1) + } + + suspend fun provideXWrapped() = coroutineBinding { + provideX().bind() + } + + suspend fun provideY(): Result { + delay(20) + yStateChange = true + return Ok(1) + } + + suspend fun provideYWrapped() = coroutineBinding { + provideY().bind() + } + + suspend fun provideZ(): Result { + delay(100) + zStateChange = true + return Ok(1) + } + + suspend fun provideZWrapped() = coroutineBinding { + provideZ().bind() + } + + val dispatcherA = StandardTestDispatcher(testScheduler) + val dispatcherB = StandardTestDispatcher(testScheduler) + val dispatcherC = StandardTestDispatcher(testScheduler) + + val result: Result = coroutineBinding { + val x = async(dispatcherA) { provideXWrapped().bind() } + val y = async(dispatcherB) { provideYWrapped().bind() } + + testScheduler.advanceTimeBy(2) + testScheduler.runCurrent() + + val z = async(dispatcherC) { provideZWrapped().bind() } + + x.await() + y.await() + z.await() + } + + assertEquals( + expected = Ok(3), + actual = result + ) + + assertTrue(xStateChange) + assertTrue(yStateChange) + assertTrue(zStateChange) + } + + @Test + fun shouldHandleExceptionsWithNestedBindings() = runTest { + var xStateChange = false + var yStateChange = false + var zStateChange = false + + suspend fun provideX(): Result { + delay(1) + xStateChange = true + return Ok(1) + } + + suspend fun provideXWrapped() = coroutineBinding { + provideX().bind() + } + + suspend fun provideY(): Result { + delay(20) + yStateChange = true + return Err(BindingError.BindingErrorA) + } + + suspend fun provideYWrapped() = coroutineBinding { + provideY().bind() + } + + suspend fun provideZ(): Result { + delay(100) + zStateChange = true + return Ok(1) + } + + suspend fun provideZWrapped() = coroutineBinding { + provideZ().bind() + } + + val dispatcherA = StandardTestDispatcher(testScheduler) + val dispatcherB = StandardTestDispatcher(testScheduler) + val dispatcherC = StandardTestDispatcher(testScheduler) + + val result: Result = coroutineBinding { + val x = async(dispatcherA) { provideXWrapped().bind() } + val y = async(dispatcherB) { provideYWrapped().bind() } + + testScheduler.advanceTimeBy(2) + testScheduler.runCurrent() + + val z = async(dispatcherC) { provideZWrapped().bind() } + + x.await() + y.await() + z.await() + } + + assertEquals( + expected = Err(BindingError.BindingErrorA), + actual = result + ) + + assertTrue(xStateChange) + assertTrue(yStateChange) + assertFalse(zStateChange) + } } diff --git a/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/CoroutineBindingTest.kt b/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/CoroutineBindingTest.kt index 318b546..b2d62b7 100644 --- a/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/CoroutineBindingTest.kt +++ b/kotlin-result-coroutines/src/commonTest/kotlin/com/github/michaelbull/result/coroutines/CoroutineBindingTest.kt @@ -12,7 +12,11 @@ import kotlin.test.assertTrue class CoroutineBindingTest { - private object BindingError + private sealed interface BindingError { + data object BindingErrorA : BindingError + data object BindingErrorB : BindingError + data object BindingErrorC : BindingError + } @Test fun returnsOkIfAllBindsSuccessful() = runTest { @@ -71,7 +75,7 @@ class CoroutineBindingTest { suspend fun provideY(): Result { delay(1) - return Err(BindingError) + return Err(BindingError.BindingErrorA) } suspend fun provideZ(): Result { @@ -87,7 +91,7 @@ class CoroutineBindingTest { } assertEquals( - expected = Err(BindingError), + expected = Err(BindingError.BindingErrorA), actual = result, ) } @@ -107,13 +111,13 @@ class CoroutineBindingTest { suspend fun provideY(): Result { delay(10) yStateChange = true - return Err(BindingError) + return Err(BindingError.BindingErrorA) } suspend fun provideZ(): Result { delay(1) zStateChange = true - return Err(BindingError) + return Err(BindingError.BindingErrorA) } val result: Result = coroutineBinding { @@ -124,7 +128,7 @@ class CoroutineBindingTest { } assertEquals( - expected = Err(BindingError), + expected = Err(BindingError.BindingErrorA), actual = result, ) @@ -142,7 +146,7 @@ class CoroutineBindingTest { suspend fun provideY(): Result { delay(1) - return Err(BindingError) + return Err(BindingError.BindingErrorA) } suspend fun provideZ(): Result { @@ -158,8 +162,32 @@ class CoroutineBindingTest { } assertEquals( - expected = Err(BindingError), + expected = Err(BindingError.BindingErrorA), actual = result, ) } + + @Test + fun shouldHandleExceptionsWithMultipleNestedBindings() = runTest { + val result: Result = coroutineBinding { + val b: Result = coroutineBinding { + val c: Result = coroutineBinding { + Err(BindingError.BindingErrorC).bind() + } + + assertEquals(Err(BindingError.BindingErrorC), c) + + Ok(2).bind() + } + + assertEquals(Ok(2), b) + + Err(BindingError.BindingErrorB).bind() + } + + assertEquals( + expected = Err(BindingError.BindingErrorB), + actual = result + ) + } }