@@ -18,16 +18,18 @@ use crate::{
1818} ;
1919
2020#[ cfg( not( target_arch = "wasm32" ) ) ]
21- type StreamingResult =
22- Pin < Box < dyn Stream < Item = Result < MultiTurnStreamItem , StreamingError > > + Send > > ;
21+ pub type StreamingResult < R > =
22+ Pin < Box < dyn Stream < Item = Result < MultiTurnStreamItem < R > , StreamingError > > + Send > > ;
2323
2424#[ cfg( target_arch = "wasm32" ) ]
25- type StreamingResult = Pin < Box < dyn Stream < Item = Result < MultiTurnStreamItem , StreamingError > > > > ;
25+ pub type StreamingResult < R > =
26+ Pin < Box < dyn Stream < Item = Result < MultiTurnStreamItem < R > , StreamingError > > > > ;
2627
2728#[ derive( Deserialize , Serialize , Debug , Clone ) ]
2829#[ serde( tag = "type" , rename_all = "camelCase" ) ]
29- pub enum MultiTurnStreamItem {
30- Text ( Text ) ,
30+ #[ non_exhaustive]
31+ pub enum MultiTurnStreamItem < R > {
32+ StreamItem ( StreamedAssistantContent < R > ) ,
3133 FinalResponse ( FinalResponse ) ,
3234}
3335
@@ -55,11 +57,9 @@ impl FinalResponse {
5557 }
5658}
5759
58- impl MultiTurnStreamItem {
59- pub ( crate ) fn text ( text : & str ) -> Self {
60- Self :: Text ( Text {
61- text : text. to_string ( ) ,
62- } )
60+ impl < R > MultiTurnStreamItem < R > {
61+ pub ( crate ) fn stream_item ( item : StreamedAssistantContent < R > ) -> Self {
62+ Self :: StreamItem ( item)
6363 }
6464
6565 pub fn final_response ( response : & str , aggregated_usage : crate :: completion:: Usage ) -> Self {
@@ -151,11 +151,14 @@ where
151151 }
152152
153153 #[ cfg_attr( feature = "worker" , worker:: send) ]
154- async fn send ( self ) -> StreamingResult {
154+ async fn send ( self ) -> StreamingResult < M :: StreamingResponse > {
155155 let agent_name = self . agent . name_owned ( ) ;
156156
157157 #[ tracing:: instrument( skip_all, fields( agent_name = agent_name) ) ]
158- fn inner < M , P > ( req : StreamingPromptRequest < M , P > , agent_name : String ) -> StreamingResult
158+ fn inner < M , P > (
159+ req : StreamingPromptRequest < M , P > ,
160+ agent_name : String ,
161+ ) -> StreamingResult < M :: StreamingResponse >
159162 where
160163 M : CompletionModel + ' static ,
161164 <M as CompletionModel >:: StreamingResponse : Send ,
@@ -230,7 +233,7 @@ where
230233 is_text_response = true ;
231234 }
232235 last_text_response. push_str( & text. text) ;
233- yield Ok ( MultiTurnStreamItem :: text ( & text. text ) ) ;
236+ yield Ok ( MultiTurnStreamItem :: stream_item ( StreamedAssistantContent :: Text ( text) ) ) ;
234237 did_call_tool = false ;
235238 } ,
236239 Ok ( StreamedAssistantContent :: ToolCall ( tool_call) ) => {
@@ -256,25 +259,21 @@ where
256259 chat_history. write( ) . await . push( rig:: message:: Message :: Assistant {
257260 id: None ,
258261 content: OneOrMany :: one( AssistantContent :: Reasoning ( Reasoning {
259- reasoning: reasoning. clone( ) , id
262+ reasoning: reasoning. clone( ) , id: id . clone ( )
260263 } ) )
261264 } ) ;
262- let text = reasoning. into_iter( ) . collect:: <Vec <String >>( ) . join( "" ) ;
263- yield Ok ( MultiTurnStreamItem :: text( & text) ) ;
265+ yield Ok ( MultiTurnStreamItem :: stream_item( StreamedAssistantContent :: Reasoning ( rig:: message:: Reasoning { reasoning, id } ) ) ) ;
264266 did_call_tool = false ;
265267 } ,
266268 Ok ( StreamedAssistantContent :: Final ( final_resp) ) => {
269+ if let Some ( usage) = final_resp. token_usage( ) { aggregated_usage += usage; } ;
267270 if is_text_response {
268271 if let Some ( ref hook) = req. hook {
269272 hook. on_stream_completion_response_finish( & prompt, & final_resp) . await ;
270273 }
271- yield Ok ( MultiTurnStreamItem :: text ( " \n " ) ) ;
274+ yield Ok ( MultiTurnStreamItem :: stream_item ( StreamedAssistantContent :: Final ( final_resp ) ) ) ;
272275 is_text_response = false ;
273276 }
274- if let Some ( usage) = final_resp. token_usage( ) { aggregated_usage += usage; } ;
275- // Do nothing here, since at the moment the final generic is actually unreachable.
276- // We need to implement a trait that aggregates token usage.
277- // TODO: Add a way to aggregate token responses from the generic variant
278277 }
279278 Err ( e) => {
280279 yield Err ( e. into( ) ) ;
@@ -345,7 +344,7 @@ where
345344 <M as CompletionModel >:: StreamingResponse : Send ,
346345 P : PromptHook < M > + ' static ,
347346{
348- type Output = StreamingResult ; // what `.await` returns
347+ type Output = StreamingResult < M :: StreamingResponse > ; // what `.await` returns
349348 type IntoFuture = Pin < Box < dyn futures:: Future < Output = Self :: Output > + Send > > ;
350349
351350 fn into_future ( self ) -> Self :: IntoFuture {
@@ -355,23 +354,31 @@ where
355354}
356355
357356/// helper function to stream a completion request to stdout
358- pub async fn stream_to_stdout (
359- stream : & mut StreamingResult ,
357+ pub async fn stream_to_stdout < R > (
358+ stream : & mut StreamingResult < R > ,
360359) -> Result < FinalResponse , std:: io:: Error > {
361360 let mut final_res = FinalResponse :: empty ( ) ;
362361 print ! ( "Response: " ) ;
363362 while let Some ( content) = stream. next ( ) . await {
364363 match content {
365- Ok ( MultiTurnStreamItem :: Text ( Text { text } ) ) => {
364+ Ok ( MultiTurnStreamItem :: StreamItem ( StreamedAssistantContent :: Text ( Text { text } ) ) ) => {
366365 print ! ( "{text}" ) ;
367- std:: io:: Write :: flush ( & mut std:: io:: stdout ( ) ) ?;
366+ std:: io:: Write :: flush ( & mut std:: io:: stdout ( ) ) . unwrap ( ) ;
367+ }
368+ Ok ( MultiTurnStreamItem :: StreamItem ( StreamedAssistantContent :: Reasoning (
369+ Reasoning { reasoning, .. } ,
370+ ) ) ) => {
371+ let reasoning = reasoning. join ( "\n " ) ;
372+ print ! ( "{reasoning}" ) ;
373+ std:: io:: Write :: flush ( & mut std:: io:: stdout ( ) ) . unwrap ( ) ;
368374 }
369375 Ok ( MultiTurnStreamItem :: FinalResponse ( res) ) => {
370376 final_res = res;
371377 }
372378 Err ( err) => {
373379 eprintln ! ( "Error: {err}" ) ;
374380 }
381+ _ => { }
375382 }
376383 }
377384
0 commit comments