@@ -273,161 +273,60 @@ where
273273mod test {
274274 use std:: sync:: Arc ;
275275
276+ use futures:: { future, pin_mut} ;
276277 use tokio:: sync:: Notify ;
277278 use turmoil:: net:: { TcpListener , TcpStream } ;
278- use uuid:: Uuid ;
279279
280280 use super :: * ;
281281
282282 #[ test]
283283 fn invalid_handshake ( ) {
284284 let mut sim = turmoil:: Builder :: new ( ) . build ( ) ;
285285
286- let host_node_id = NodeId :: new_v4 ( ) ;
287- sim. host ( "host" , move || async move {
288- let bus = Bus :: new ( host_node_id) ;
289- let listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" )
290- . await
291- . unwrap ( ) ;
292- let ( s, _) = listener. accept ( ) . await . unwrap ( ) ;
293- let mut connection = Connection :: new_acceptor ( s, bus) ;
294- connection. tick ( ) . await ;
295-
296- Ok ( ( ) )
286+ let host_node_id = 0 ;
287+ let done = Arc :: new ( Notify :: new ( ) ) ;
288+ let done_clone = done. clone ( ) ;
289+ sim. host ( "host" , move || {
290+ let done_clone = done_clone. clone ( ) ;
291+ async move {
292+ let bus = Arc :: new ( Bus :: new ( host_node_id, |_, _| async { } ) ) ;
293+ let listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" )
294+ . await
295+ . unwrap ( ) ;
296+ let ( s, _) = listener. accept ( ) . await . unwrap ( ) ;
297+ let connection = Connection :: new_acceptor ( s, bus) ;
298+ let done = done_clone. notified ( ) ;
299+ let run = connection. run ( ) ;
300+ pin_mut ! ( done) ;
301+ pin_mut ! ( run) ;
302+ future:: select ( run, done) . await ;
303+
304+ Ok ( ( ) )
305+ }
297306 } ) ;
298307
299308 sim. client ( "client" , async move {
300309 let s = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
301- let mut s = AsyncBincodeStream :: < _ , Message , Message , _ > :: from ( s) . for_async ( ) ;
302-
303- s. send ( Message :: Node ( NodeMessage :: Handshake {
304- protocol_version : 1234 ,
305- node_id : Uuid :: new_v4 ( ) ,
306- } ) )
307- . await
308- . unwrap ( ) ;
310+ let mut s = AsyncBincodeStream :: < _ , Enveloppe , Enveloppe , _ > :: from ( s) . for_async ( ) ;
311+
312+ let msg = Enveloppe {
313+ database_id : None ,
314+ message : Message :: Handshake {
315+ protocol_version : 1234 ,
316+ node_id : 1 ,
317+ } ,
318+ } ;
319+ s. send ( msg) . await . unwrap ( ) ;
309320 let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
310321
311322 assert ! ( matches!(
312- m,
313- Message :: Node ( NodeMessage :: Error (
314- NodeError :: HandshakeVersionMismatch { .. }
315- ) )
323+ m. message ,
324+ Message :: Error (
325+ ProtoError :: HandshakeVersionMismatch { .. }
326+ )
316327 ) ) ;
317328
318- Ok ( ( ) )
319- } ) ;
320-
321- sim. run ( ) . unwrap ( ) ;
322- }
323-
324- #[ test]
325- fn stream_closed ( ) {
326- let mut sim = turmoil:: Builder :: new ( ) . build ( ) ;
327-
328- let database_id = DatabaseId :: new_v4 ( ) ;
329- let host_node_id = NodeId :: new_v4 ( ) ;
330- let notify = Arc :: new ( Notify :: new ( ) ) ;
331- sim. host ( "host" , {
332- let notify = notify. clone ( ) ;
333- move || {
334- let notify = notify. clone ( ) ;
335- async move {
336- let bus = Bus :: new ( host_node_id) ;
337- let mut sub = bus. subscribe ( database_id) . unwrap ( ) ;
338- let listener = turmoil:: net:: TcpListener :: bind ( "0.0.0.0:1234" )
339- . await
340- . unwrap ( ) ;
341- let ( s, _) = listener. accept ( ) . await . unwrap ( ) ;
342- let connection = Connection :: new_acceptor ( s, bus) ;
343- tokio:: task:: spawn_local ( connection. run ( ) ) ;
344- let mut streams = Vec :: new ( ) ;
345- loop {
346- tokio:: select! {
347- Some ( mut stream) = sub. next( ) => {
348- let m = stream. next( ) . await . unwrap( ) ;
349- stream. send( m) . await . unwrap( ) ;
350- streams. push( stream) ;
351- }
352- _ = notify. notified( ) => {
353- break ;
354- }
355- }
356- }
357-
358- Ok ( ( ) )
359- }
360- }
361- } ) ;
362-
363- sim. client ( "client" , async move {
364- let stream_id = StreamId :: new ( 1 ) ;
365- let node_id = NodeId :: new_v4 ( ) ;
366- let s = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
367- let mut s = AsyncBincodeStream :: < _ , Message , Message , _ > :: from ( s) . for_async ( ) ;
368-
369- s. send ( Message :: Node ( NodeMessage :: Handshake {
370- protocol_version : CURRENT_PROTO_VERSION ,
371- node_id,
372- } ) )
373- . await
374- . unwrap ( ) ;
375- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
376- assert ! ( matches!( m, Message :: Node ( NodeMessage :: Handshake { .. } ) ) ) ;
377-
378- // send message to unexisting stream:
379- s. send ( Message :: Stream {
380- stream_id,
381- payload : StreamMessage :: Dummy ,
382- } )
383- . await
384- . unwrap ( ) ;
385- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
386- assert_eq ! (
387- m,
388- Message :: Node ( NodeMessage :: Error ( NodeError :: UnknownStream ( stream_id) ) )
389- ) ;
390-
391- // open stream then send message
392- s. send ( Message :: Node ( NodeMessage :: OpenStream {
393- stream_id,
394- database_id,
395- } ) )
396- . await
397- . unwrap ( ) ;
398- s. send ( Message :: Stream {
399- stream_id,
400- payload : StreamMessage :: Dummy ,
401- } )
402- . await
403- . unwrap ( ) ;
404- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
405- assert_eq ! (
406- m,
407- Message :: Stream {
408- stream_id,
409- payload: StreamMessage :: Dummy
410- }
411- ) ;
412-
413- s. send ( Message :: Node ( NodeMessage :: CloseStream {
414- stream_id : StreamId :: new ( 1 ) ,
415- } ) )
416- . await
417- . unwrap ( ) ;
418- s. send ( Message :: Stream {
419- stream_id,
420- payload : StreamMessage :: Dummy ,
421- } )
422- . await
423- . unwrap ( ) ;
424- let m = s. next ( ) . await . unwrap ( ) . unwrap ( ) ;
425- assert_eq ! (
426- m,
427- Message :: Node ( NodeMessage :: Error ( NodeError :: UnknownStream ( stream_id) ) )
428- ) ;
429-
430- notify. notify_waiters ( ) ;
329+ done. notify_waiters ( ) ;
431330
432331 Ok ( ( ) )
433332 } ) ;
@@ -459,7 +358,7 @@ mod test {
459358
460359 sim. client ( "client" , async move {
461360 let stream = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
462- let bus = Bus :: new ( NodeId :: new_v4 ( ) ) ;
361+ let bus = Arc :: new ( Bus :: new ( 1 , |_ , _| async { } ) ) ;
463362 let mut conn = Connection :: new_acceptor ( stream, bus) ;
464363
465364 notify. notify_waiters ( ) ;
@@ -473,57 +372,4 @@ mod test {
473372
474373 sim. run ( ) . unwrap ( ) ;
475374 }
476-
477- #[ test]
478- fn zero_stream_id ( ) {
479- let mut sim = turmoil:: Builder :: new ( ) . build ( ) ;
480-
481- let notify = Arc :: new ( Notify :: new ( ) ) ;
482- sim. host ( "host" , {
483- let notify = notify. clone ( ) ;
484- move || {
485- let notify = notify. clone ( ) ;
486- async move {
487- let listener = TcpListener :: bind ( "0.0.0.0:1234" ) . await . unwrap ( ) ;
488- let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
489- let ( connection_messages_sender, connection_messages) = mpsc:: channel ( 1 ) ;
490- let conn = Connection {
491- peer : Some ( NodeId :: new_v4 ( ) ) ,
492- state : ConnectionState :: Connected ,
493- conn : AsyncBincodeStream :: from ( stream) . for_async ( ) ,
494- streams : HashMap :: new ( ) ,
495- connection_messages,
496- connection_messages_sender,
497- is_initiator : false ,
498- bus : Bus :: new ( NodeId :: new_v4 ( ) ) ,
499- stream_id_allocator : StreamIdAllocator :: new ( false ) ,
500- registration : None ,
501- } ;
502-
503- conn. run ( ) . await ;
504-
505- Ok ( ( ) )
506- }
507- }
508- } ) ;
509-
510- sim. client ( "client" , async move {
511- let stream = TcpStream :: connect ( "host:1234" ) . await . unwrap ( ) ;
512- let mut stream = AsyncBincodeStream :: < _ , Message , Message , _ > :: from ( stream) . for_async ( ) ;
513-
514- stream
515- . send ( Message :: Stream {
516- stream_id : StreamId :: new_unchecked ( 0 ) ,
517- payload : StreamMessage :: Dummy ,
518- } )
519- . await
520- . unwrap ( ) ;
521-
522- assert ! ( stream. next( ) . await . is_none( ) ) ;
523-
524- Ok ( ( ) )
525- } ) ;
526-
527- sim. run ( ) . unwrap ( ) ;
528- }
529375}
0 commit comments