@@ -4,7 +4,10 @@ use futures::future::FutureExt as _;
44use spawned_rt:: tasks:: { self as rt, mpsc, oneshot, timeout, CancellationToken } ;
55use std:: { fmt:: Debug , future:: Future , panic:: AssertUnwindSafe , time:: Duration } ;
66
7- use crate :: error:: GenServerError ;
7+ use crate :: {
8+ error:: GenServerError ,
9+ tasks:: InitResult :: { NoSuccess , Success } ,
10+ } ;
811
912const DEFAULT_CALL_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
1013
@@ -120,6 +123,11 @@ pub enum CastResponse<G: GenServer> {
120123 Stop ,
121124}
122125
126+ pub enum InitResult < G : GenServer > {
127+ Success ( G ) ,
128+ NoSuccess ( G ) ,
129+ }
130+
123131pub trait GenServer : Send + Sized + Clone {
124132 type CallMsg : Clone + Send + Sized + Sync ;
125133 type CastMsg : Clone + Send + Sized + Sync ;
@@ -145,14 +153,18 @@ pub trait GenServer: Send + Sized + Clone {
145153 rx : & mut mpsc:: Receiver < GenServerInMsg < Self > > ,
146154 ) -> impl Future < Output = Result < ( ) , GenServerError > > + Send {
147155 async {
148- let init_result = self
149- . init ( handle)
150- . await
151- . inspect_err ( |err| tracing:: error!( "Initialization failed: {err:?}" ) ) ;
152-
153- let res = match init_result {
154- Ok ( new_state) => new_state. main_loop ( handle, rx) . await ,
155- Err ( _) => Err ( GenServerError :: Initialization ) ,
156+ let res = match self . init ( handle) . await {
157+ Ok ( Success ( new_state) ) => new_state. main_loop ( handle, rx) . await ,
158+ Ok ( NoSuccess ( intermediate_state) ) => {
159+ // new_state is NoSuccess, this means the initialization failed, but the error was handled
160+ // in callback. No need to report the error.
161+ // Just skip main_loop and return the state to teardown the GenServer
162+ Ok ( intermediate_state)
163+ }
164+ Err ( err) => {
165+ tracing:: error!( "Initialization failed with unhandled error: {err:?}" ) ;
166+ Err ( GenServerError :: Initialization )
167+ }
156168 } ;
157169
158170 handle. cancellation_token ( ) . cancel ( ) ;
@@ -171,8 +183,8 @@ pub trait GenServer: Send + Sized + Clone {
171183 fn init (
172184 self ,
173185 _handle : & GenServerHandle < Self > ,
174- ) -> impl Future < Output = Result < Self , Self :: Error > > + Send {
175- async { Ok ( self ) }
186+ ) -> impl Future < Output = Result < InitResult < Self > , Self :: Error > > + Send {
187+ async { Ok ( Success ( self ) ) }
176188 }
177189
178190 fn main_loop (
@@ -297,8 +309,12 @@ pub trait GenServer: Send + Sized + Clone {
297309mod tests {
298310
299311 use super :: * ;
300- use crate :: tasks:: send_after;
301- use std:: { thread, time:: Duration } ;
312+ use crate :: { messages:: Unused , tasks:: send_after} ;
313+ use std:: {
314+ sync:: { Arc , Mutex } ,
315+ thread,
316+ time:: Duration ,
317+ } ;
302318
303319 #[ derive( Clone ) ]
304320 struct BadlyBehavedTask ;
@@ -315,16 +331,16 @@ mod tests {
315331
316332 impl GenServer for BadlyBehavedTask {
317333 type CallMsg = InMessage ;
318- type CastMsg = ( ) ;
319- type OutMsg = ( ) ;
320- type Error = ( ) ;
334+ type CastMsg = Unused ;
335+ type OutMsg = Unused ;
336+ type Error = Unused ;
321337
322338 async fn handle_call (
323339 self ,
324340 _: Self :: CallMsg ,
325341 _: & GenServerHandle < Self > ,
326342 ) -> CallResponse < Self > {
327- CallResponse :: Stop ( ( ) )
343+ CallResponse :: Stop ( Unused )
328344 }
329345
330346 async fn handle_cast (
@@ -345,9 +361,9 @@ mod tests {
345361
346362 impl GenServer for WellBehavedTask {
347363 type CallMsg = InMessage ;
348- type CastMsg = ( ) ;
364+ type CastMsg = Unused ;
349365 type OutMsg = OutMsg ;
350- type Error = ( ) ;
366+ type Error = Unused ;
351367
352368 async fn handle_call (
353369 self ,
@@ -370,7 +386,7 @@ mod tests {
370386 ) -> CastResponse < Self > {
371387 self . count += 1 ;
372388 println ! ( "{:?}: good still alive" , thread:: current( ) . id( ) ) ;
373- send_after ( Duration :: from_millis ( 100 ) , handle. to_owned ( ) , ( ) ) ;
389+ send_after ( Duration :: from_millis ( 100 ) , handle. to_owned ( ) , Unused ) ;
374390 CastResponse :: NoReply ( self )
375391 }
376392 }
@@ -380,9 +396,9 @@ mod tests {
380396 let runtime = rt:: Runtime :: new ( ) . unwrap ( ) ;
381397 runtime. block_on ( async move {
382398 let mut badboy = BadlyBehavedTask . start ( ) ;
383- let _ = badboy. cast ( ( ) ) . await ;
399+ let _ = badboy. cast ( Unused ) . await ;
384400 let mut goodboy = WellBehavedTask { count : 0 } . start ( ) ;
385- let _ = goodboy. cast ( ( ) ) . await ;
401+ let _ = goodboy. cast ( Unused ) . await ;
386402 rt:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
387403 let count = goodboy. call ( InMessage :: GetCount ) . await . unwrap ( ) ;
388404
@@ -400,9 +416,9 @@ mod tests {
400416 let runtime = rt:: Runtime :: new ( ) . unwrap ( ) ;
401417 runtime. block_on ( async move {
402418 let mut badboy = BadlyBehavedTask . start_blocking ( ) ;
403- let _ = badboy. cast ( ( ) ) . await ;
419+ let _ = badboy. cast ( Unused ) . await ;
404420 let mut goodboy = WellBehavedTask { count : 0 } . start ( ) ;
405- let _ = goodboy. cast ( ( ) ) . await ;
421+ let _ = goodboy. cast ( Unused ) . await ;
406422 rt:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
407423 let count = goodboy. call ( InMessage :: GetCount ) . await . unwrap ( ) ;
408424
@@ -428,9 +444,9 @@ mod tests {
428444
429445 impl GenServer for SomeTask {
430446 type CallMsg = SomeTaskCallMsg ;
431- type CastMsg = ( ) ;
432- type OutMsg = ( ) ;
433- type Error = ( ) ;
447+ type CastMsg = Unused ;
448+ type OutMsg = Unused ;
449+ type Error = Unused ;
434450
435451 async fn handle_call (
436452 self ,
@@ -441,12 +457,12 @@ mod tests {
441457 SomeTaskCallMsg :: SlowOperation => {
442458 // Simulate a slow operation that will not resolve in time
443459 rt:: sleep ( TIMEOUT_DURATION * 2 ) . await ;
444- CallResponse :: Reply ( self , ( ) )
460+ CallResponse :: Reply ( self , Unused )
445461 }
446462 SomeTaskCallMsg :: FastOperation => {
447463 // Simulate a fast operation that resolves in time
448464 rt:: sleep ( TIMEOUT_DURATION / 2 ) . await ;
449- CallResponse :: Reply ( self , ( ) )
465+ CallResponse :: Reply ( self , Unused )
450466 }
451467 }
452468 }
@@ -461,12 +477,59 @@ mod tests {
461477 let result = unresolving_task
462478 . call_with_timeout ( SomeTaskCallMsg :: FastOperation , TIMEOUT_DURATION )
463479 . await ;
464- assert ! ( matches!( result, Ok ( ( ) ) ) ) ;
480+ assert ! ( matches!( result, Ok ( Unused ) ) ) ;
465481
466482 let result = unresolving_task
467483 . call_with_timeout ( SomeTaskCallMsg :: SlowOperation , TIMEOUT_DURATION )
468484 . await ;
469485 assert ! ( matches!( result, Err ( GenServerError :: CallTimeout ) ) ) ;
470486 } ) ;
471487 }
488+
489+ #[ derive( Clone ) ]
490+ struct SomeTaskThatFailsOnInit {
491+ sender_channel : Arc < Mutex < mpsc:: Receiver < u8 > > > ,
492+ }
493+
494+ impl SomeTaskThatFailsOnInit {
495+ pub fn new ( sender_channel : Arc < Mutex < mpsc:: Receiver < u8 > > > ) -> Self {
496+ Self { sender_channel }
497+ }
498+ }
499+
500+ impl GenServer for SomeTaskThatFailsOnInit {
501+ type CallMsg = Unused ;
502+ type CastMsg = Unused ;
503+ type OutMsg = Unused ;
504+ type Error = Unused ;
505+
506+ async fn init (
507+ self ,
508+ _handle : & GenServerHandle < Self > ,
509+ ) -> Result < InitResult < Self > , Self :: Error > {
510+ // Simulate an initialization failure by returning NoSuccess
511+ Ok ( NoSuccess ( self ) )
512+ }
513+
514+ async fn teardown ( self , _handle : & GenServerHandle < Self > ) -> Result < ( ) , Self :: Error > {
515+ self . sender_channel . lock ( ) . unwrap ( ) . close ( ) ;
516+ Ok ( ( ) )
517+ }
518+ }
519+
520+ #[ test]
521+ pub fn task_fails_with_intermediate_state ( ) {
522+ let runtime = rt:: Runtime :: new ( ) . unwrap ( ) ;
523+ runtime. block_on ( async move {
524+ let ( rx, tx) = mpsc:: channel :: < u8 > ( ) ;
525+ let sender_channel = Arc :: new ( Mutex :: new ( tx) ) ;
526+ let _task = SomeTaskThatFailsOnInit :: new ( sender_channel) . start ( ) ;
527+
528+ // Wait a while to ensure the task has time to run and fail
529+ rt:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
530+
531+ // We assure that the teardown function has ran by checking that the receiver channel is closed
532+ assert ! ( rx. is_closed( ) )
533+ } ) ;
534+ }
472535}
0 commit comments