Skip to content

Commit df3f3c7

Browse files
feat(rig-926): make agent multi stream prompting more granular (#796)
* feat(rig-926): make agent multi stream prompting more granular * chore: clippy
1 parent c06ac72 commit df3f3c7

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed

rig-core/src/agent/prompt_request/streaming.rs

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@ use crate::{
1818
};
1919

2020
#[cfg(not(target_arch = "wasm32"))]
21-
type StreamingResult =
22-
Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>> + Send>>;
21+
pub type StreamingResult<R> =
22+
Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;
2323

2424
#[cfg(target_arch = "wasm32")]
25-
type StreamingResult = Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>>>>;
25+
pub type StreamingResult<R> =
26+
Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
2627

2728
#[derive(Deserialize, Serialize, Debug, Clone)]
2829
#[serde(tag = "type", rename_all = "camelCase")]
29-
pub enum MultiTurnStreamItem {
30-
Text(Text),
30+
#[non_exhaustive]
31+
pub enum MultiTurnStreamItem<R> {
32+
StreamItem(StreamedAssistantContent<R>),
3133
FinalResponse(FinalResponse),
3234
}
3335

@@ -55,11 +57,9 @@ impl FinalResponse {
5557
}
5658
}
5759

58-
impl MultiTurnStreamItem {
59-
pub(crate) fn text(text: &str) -> Self {
60-
Self::Text(Text {
61-
text: text.to_string(),
62-
})
60+
impl<R> MultiTurnStreamItem<R> {
61+
pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
62+
Self::StreamItem(item)
6363
}
6464

6565
pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
@@ -151,11 +151,14 @@ where
151151
}
152152

153153
#[cfg_attr(feature = "worker", worker::send)]
154-
async fn send(self) -> StreamingResult {
154+
async fn send(self) -> StreamingResult<M::StreamingResponse> {
155155
let agent_name = self.agent.name_owned();
156156

157157
#[tracing::instrument(skip_all, fields(agent_name = agent_name))]
158-
fn inner<M, P>(req: StreamingPromptRequest<M, P>, agent_name: String) -> StreamingResult
158+
fn inner<M, P>(
159+
req: StreamingPromptRequest<M, P>,
160+
agent_name: String,
161+
) -> StreamingResult<M::StreamingResponse>
159162
where
160163
M: CompletionModel + 'static,
161164
<M as CompletionModel>::StreamingResponse: Send,
@@ -230,7 +233,7 @@ where
230233
is_text_response = true;
231234
}
232235
last_text_response.push_str(&text.text);
233-
yield Ok(MultiTurnStreamItem::text(&text.text));
236+
yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
234237
did_call_tool = false;
235238
},
236239
Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
@@ -256,25 +259,21 @@ where
256259
chat_history.write().await.push(rig::message::Message::Assistant {
257260
id: None,
258261
content: OneOrMany::one(AssistantContent::Reasoning(Reasoning {
259-
reasoning: reasoning.clone(), id
262+
reasoning: reasoning.clone(), id: id.clone()
260263
}))
261264
});
262-
let text = reasoning.into_iter().collect::<Vec<String>>().join("");
263-
yield Ok(MultiTurnStreamItem::text(&text));
265+
yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })));
264266
did_call_tool = false;
265267
},
266268
Ok(StreamedAssistantContent::Final(final_resp)) => {
269+
if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
267270
if is_text_response {
268271
if let Some(ref hook) = req.hook {
269272
hook.on_stream_completion_response_finish(&prompt, &final_resp).await;
270273
}
271-
yield Ok(MultiTurnStreamItem::text("\n"));
274+
yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
272275
is_text_response = false;
273276
}
274-
if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
275-
// Do nothing here, since at the moment the final generic is actually unreachable.
276-
// We need to implement a trait that aggregates token usage.
277-
// TODO: Add a way to aggregate token responses from the generic variant
278277
}
279278
Err(e) => {
280279
yield Err(e.into());
@@ -345,7 +344,7 @@ where
345344
<M as CompletionModel>::StreamingResponse: Send,
346345
P: PromptHook<M> + 'static,
347346
{
348-
type Output = StreamingResult; // what `.await` returns
347+
type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
349348
type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + Send>>;
350349

351350
fn into_future(self) -> Self::IntoFuture {
@@ -355,23 +354,31 @@ where
355354
}
356355

357356
/// helper function to stream a completion request to stdout
358-
pub async fn stream_to_stdout(
359-
stream: &mut StreamingResult,
357+
pub async fn stream_to_stdout<R>(
358+
stream: &mut StreamingResult<R>,
360359
) -> Result<FinalResponse, std::io::Error> {
361360
let mut final_res = FinalResponse::empty();
362361
print!("Response: ");
363362
while let Some(content) = stream.next().await {
364363
match content {
365-
Ok(MultiTurnStreamItem::Text(Text { text })) => {
364+
Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { text }))) => {
366365
print!("{text}");
367-
std::io::Write::flush(&mut std::io::stdout())?;
366+
std::io::Write::flush(&mut std::io::stdout()).unwrap();
367+
}
368+
Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Reasoning(
369+
Reasoning { reasoning, .. },
370+
))) => {
371+
let reasoning = reasoning.join("\n");
372+
print!("{reasoning}");
373+
std::io::Write::flush(&mut std::io::stdout()).unwrap();
368374
}
369375
Ok(MultiTurnStreamItem::FinalResponse(res)) => {
370376
final_res = res;
371377
}
372378
Err(err) => {
373379
eprintln!("Error: {err}");
374380
}
381+
_ => {}
375382
}
376383
}
377384

rig-core/src/cli_chatbot.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ use std::io::{self, Write};
33
use futures::StreamExt;
44

55
use crate::{
6-
agent::{Agent, prompt_request::streaming::MultiTurnStreamItem},
6+
agent::{Agent, Text, prompt_request::streaming::MultiTurnStreamItem},
77
completion::{Chat, CompletionError, CompletionModel, Message, PromptError},
8-
streaming::StreamingPrompt,
8+
streaming::{StreamedAssistantContent, StreamingPrompt},
99
};
1010

1111
/// Type-state representing an empty `agent` field in `ChatbotBuilder`
@@ -150,22 +150,23 @@ where
150150

151151
while let Some(chunk) = stream_response.next().await {
152152
match chunk {
153-
Ok(MultiTurnStreamItem::Text(s)) => {
154-
let text = s.text.as_str();
153+
Ok(MultiTurnStreamItem::StreamItem(
154+
StreamedAssistantContent::Text(Text { text }),
155+
)) => {
155156
print!("{text}");
156-
response.push_str(text);
157+
response.push_str(&text);
157158
}
158159
Ok(MultiTurnStreamItem::FinalResponse(r)) => {
159160
if self.show_usage {
160161
usage = Some(r.usage());
161162
}
162163
}
163-
164164
Err(e) => {
165165
return Err(PromptError::CompletionError(
166166
CompletionError::ResponseError(e.to_string()),
167167
));
168168
}
169+
_ => {}
169170
}
170171
}
171172

0 commit comments

Comments (0)