From 8405775c78ab0420b1e856e6e4a4e7ac82d30fbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Fri, 23 Jan 2026 20:38:17 +0800 Subject: [PATCH 1/2] feat: implement exponential backoff retry for provider stream errors --- crates/anthropic/src/client/messages.rs | 96 ++++++-- crates/coco-tui/src/components/chat.rs | 71 +++++- .../coco-tui/src/components/messages/combo.rs | 6 + crates/coco-tui/src/events.rs | 4 + crates/openai/src/client.rs | 73 +++++- src/agent.rs | 222 ++++++++++++++---- src/error.rs | 25 ++ src/lib.rs | 2 + src/provider.rs | 17 +- src/provider/openai.rs | 15 +- src/stream_error.rs | 48 ++++ src/tools/run_combo.rs | 103 +++++--- src/tools/run_task.rs | 27 ++- 13 files changed, 565 insertions(+), 144 deletions(-) create mode 100644 src/stream_error.rs diff --git a/crates/anthropic/src/client/messages.rs b/crates/anthropic/src/client/messages.rs index 36af7b0..4645235 100644 --- a/crates/anthropic/src/client/messages.rs +++ b/crates/anthropic/src/client/messages.rs @@ -17,6 +17,42 @@ use tracing::{trace, warn}; use crate::{Block, Message, RetryAttempt, RetryConfig, Role, Tool}; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamErrorKind { + Transport, + Decode, +} + +#[derive(Debug, Clone)] +pub struct StreamError { + pub kind: StreamErrorKind, + pub message: String, +} + +impl StreamError { + fn transport(context: &'static str, err: impl std::fmt::Display) -> Self { + Self { + kind: StreamErrorKind::Transport, + message: format!("{context}: {err}"), + } + } + + fn decode(context: &'static str, err: impl std::fmt::Display) -> Self { + Self { + kind: StreamErrorKind::Decode, + message: format!("{context}: {err}"), + } + } +} + +impl std::fmt::Display for StreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for StreamError {} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolChoice { @@ -377,7 +413,7 @@ impl SseEventStream { Some(line) } - fn process_line(&mut self, line: &[u8]) -> Result<(), Whatever> { + fn process_line(&mut self, line: &[u8]) -> std::result::Result<(), StreamError> { if line.is_empty() { self.flush_event(); return Ok(()); @@ -387,8 +423,8 @@ impl SseEventStream { } if let Some(rest) = line.strip_prefix(b"event:") { let rest = trim_leading_space(rest); - let name = - std::str::from_utf8(rest).whatever_context("stream event name is not utf-8")?; + let name = std::str::from_utf8(rest) + .map_err(|err| StreamError::decode("stream event name is not utf-8", err))?; self.event = Some(name.to_string()); return Ok(()); } @@ -413,7 +449,7 @@ impl SseEventStream { } impl Stream for SseEventStream { - type Item = Result; + type Item = std::result::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -427,10 +463,10 @@ impl Stream for SseEventStream { match this.inner.as_mut().poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Some(Err(err))) => { - let message = format!("read stream chunk error: {err}"); - return Poll::Ready(Some(Err( - ::without_source(message), - ))); + return Poll::Ready(Some(Err(StreamError::transport( + "read stream chunk error", + err, + )))); } Poll::Ready(Some(Ok(chunk))) => { this.buffer.extend_from_slice(&chunk); @@ -479,7 +515,7 @@ impl MessagesStream { } impl Stream for MessagesStream { - type Item = Result; + type Item = std::result::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -500,13 +536,16 @@ fn trim_leading_space(bytes: &[u8]) -> &[u8] { } } -fn parse_messages_stream_event(event: SseEvent) -> Result { +fn parse_messages_stream_event( + event: SseEvent, +) -> std::result::Result { let data = if event.data.is_empty() { Value::Object(Map::new()) } else { - let text = - std::str::from_utf8(&event.data).whatever_context("stream event data is not utf-8")?; - serde_json::from_str::(text).whatever_context("decode stream event data")? + let text = std::str::from_utf8(&event.data) + .map_err(|err| StreamError::decode("stream event data is not utf-8", err))?; + serde_json::from_str::(text) + .map_err(|err| StreamError::decode("decode stream event data", err))? }; let event_type = data .get("type") @@ -516,15 +555,15 @@ fn parse_messages_stream_event(event: SseEvent) -> Result { - let event: MessageStartEvent = - serde_json::from_value(data).whatever_context("decode message_start event")?; + let event: MessageStartEvent = serde_json::from_value(data) + .map_err(|err| StreamError::decode("decode message_start event", err))?; Ok(MessagesStreamEvent::MessageStart { message: event.message, }) } "content_block_start" => { let event: ContentBlockStartEvent = serde_json::from_value(data) - .whatever_context("decode content_block_start event")?; + .map_err(|err| StreamError::decode("decode content_block_start event", err))?; Ok(MessagesStreamEvent::ContentBlockStart { index: event.index, content_block: event.content_block, @@ -532,20 +571,20 @@ fn parse_messages_stream_event(event: SseEvent) -> Result { let event: ContentBlockDeltaEvent = serde_json::from_value(data) - .whatever_context("decode content_block_delta event")?; + .map_err(|err| StreamError::decode("decode content_block_delta event", err))?; Ok(MessagesStreamEvent::ContentBlockDelta { index: event.index, delta: event.delta, }) } "content_block_stop" => { - let event: ContentBlockStopEvent = - serde_json::from_value(data).whatever_context("decode content_block_stop event")?; + let event: ContentBlockStopEvent = serde_json::from_value(data) + .map_err(|err| StreamError::decode("decode content_block_stop event", err))?; Ok(MessagesStreamEvent::ContentBlockStop { index: event.index }) } "message_delta" => { - let event: MessageDeltaEvent = - serde_json::from_value(data).whatever_context("decode message_delta event")?; + let event: MessageDeltaEvent = serde_json::from_value(data) + .map_err(|err| StreamError::decode("decode message_delta event", err))?; Ok(MessagesStreamEvent::MessageDelta { delta: event.delta.unwrap_or_default(), usage: event.usage, @@ -554,8 +593,8 @@ fn parse_messages_stream_event(event: SseEvent) -> Result Ok(MessagesStreamEvent::MessageStop), "ping" => Ok(MessagesStreamEvent::Ping), "error" => { - let event: StreamErrorEvent = - serde_json::from_value(data).whatever_context("decode error event")?; + let event: StreamErrorEvent = serde_json::from_value(data) + .map_err(|err| StreamError::decode("decode error event", err))?; Ok(MessagesStreamEvent::Error { error: event.error }) } _ => Ok(MessagesStreamEvent::Unknown { @@ -789,7 +828,7 @@ impl AnthropicRequestError { match self { Self::Transport { source, .. } => source.is_timeout() || source.is_connect(), Self::HttpStatus { status, .. } => is_retryable_status(*status), - Self::Decode { .. } => false, + Self::Decode { .. } => true, } } @@ -850,6 +889,15 @@ mod retry_tests { assert!(is_retryable_status(StatusCode::SERVICE_UNAVAILABLE)); assert!(!is_retryable_status(StatusCode::BAD_REQUEST)); } + + #[test] + fn retryable_decode_errors() { + let err = AnthropicRequestError::Decode { + context: "messages", + message: "bad json".to_string(), + }; + assert!(err.is_retryable()); + } } #[cfg(test)] diff --git a/crates/coco-tui/src/components/chat.rs b/crates/coco-tui/src/components/chat.rs index 2f5e743..fa1e3f9 100644 --- a/crates/coco-tui/src/components/chat.rs +++ b/crates/coco-tui/src/components/chat.rs @@ -422,7 +422,8 @@ impl Chat<'static> { | ComboEvent::RecordStart { .. } | ComboEvent::RecordOutput { .. } | ComboEvent::RecordEnd { .. } - | ComboEvent::PromptStream { .. } => { + | ComboEvent::PromptStream { .. } + | ComboEvent::PromptStreamReset { .. } => { self.set_processing(); } ComboEvent::Prompt { thinking, .. } => { @@ -646,6 +647,10 @@ impl Chat<'static> { }, text: text.clone(), }, + ComboToolEvent::PromptStreamReset { name } => ComboEvent::PromptStreamReset { + id: id.to_string(), + name: name.clone(), + }, ComboToolEvent::ReplyToolUse { name, tool_use, @@ -2278,19 +2283,39 @@ async fn task_chat(mut agent: Agent, content: ChatContent, cancel_token: Cancell let thinking_seen_stream = thinking_seen.clone(); let resp = agent .chat_stream(msg, cancel_token.clone(), move |update| { - let (index, kind, text) = match update { + match update { + ChatStreamUpdate::Reset => { + plain_seen_stream.store(false, Ordering::Relaxed); + thinking_seen_stream.store(false, Ordering::Relaxed); + stream_tx.send(AnswerEvent::BotStreamReset.into()).ok(); + } ChatStreamUpdate::Plain { index, text } => { plain_seen_stream.store(true, Ordering::Relaxed); - (index, BotStreamKind::Plain, text) + stream_tx + .send( + AnswerEvent::BotStream { + index, + kind: BotStreamKind::Plain, + text, + } + .into(), + ) + .ok(); } ChatStreamUpdate::Thinking { index, text } => { thinking_seen_stream.store(true, Ordering::Relaxed); - (index, BotStreamKind::Thinking, text) + stream_tx + .send( + AnswerEvent::BotStream { + index, + kind: BotStreamKind::Thinking, + text, + } + .into(), + ) + .ok(); } }; - stream_tx - .send(AnswerEvent::BotStream { index, kind, text }.into()) - .ok(); }) .await; streamed_plain = plain_seen.load(Ordering::Relaxed); @@ -2349,19 +2374,39 @@ async fn task_chat_with_history(mut agent: Agent, cancel_token: CancellationToke let thinking_seen_stream = thinking_seen.clone(); let resp = agent .chat_stream_with_history(cancel_token.clone(), move |update| { - let (index, kind, text) = match update { + match update { + ChatStreamUpdate::Reset => { + plain_seen_stream.store(false, Ordering::Relaxed); + thinking_seen_stream.store(false, Ordering::Relaxed); + stream_tx.send(AnswerEvent::BotStreamReset.into()).ok(); + } ChatStreamUpdate::Plain { index, text } => { plain_seen_stream.store(true, Ordering::Relaxed); - (index, BotStreamKind::Plain, text) + stream_tx + .send( + AnswerEvent::BotStream { + index, + kind: BotStreamKind::Plain, + text, + } + .into(), + ) + .ok(); } ChatStreamUpdate::Thinking { index, text } => { thinking_seen_stream.store(true, Ordering::Relaxed); - (index, BotStreamKind::Thinking, text) + stream_tx + .send( + AnswerEvent::BotStream { + index, + kind: BotStreamKind::Thinking, + text, + } + .into(), + ) + .ok(); } }; - stream_tx - .send(AnswerEvent::BotStream { index, kind, text }.into()) - .ok(); }) .await; streamed_plain = plain_seen.load(Ordering::Relaxed); diff --git a/crates/coco-tui/src/components/messages/combo.rs b/crates/coco-tui/src/components/messages/combo.rs index 95a7184..b74e953 100644 --- a/crates/coco-tui/src/components/messages/combo.rs +++ b/crates/coco-tui/src/components/messages/combo.rs @@ -353,6 +353,12 @@ impl Combo { self.push_prompt(prompt); } } + ComboEvent::PromptStreamReset { id, .. } => { + if self.matches_id(id) { + self.messages.finalize_stream(); + self.messages.reset_stream(); + } + } ComboEvent::PromptStream { id, index, diff --git a/crates/coco-tui/src/events.rs b/crates/coco-tui/src/events.rs index df6c3bc..2cf47f8 100644 --- a/crates/coco-tui/src/events.rs +++ b/crates/coco-tui/src/events.rs @@ -132,6 +132,10 @@ pub enum ComboEvent { kind: BotStreamKind, text: String, }, + PromptStreamReset { + id: String, + name: String, + }, /// Reply tool use from prompt, with optional offload via bash ReplyToolUse { id: String, diff --git a/crates/openai/src/client.rs b/crates/openai/src/client.rs index 7f7cee2..7e05666 100644 --- a/crates/openai/src/client.rs +++ b/crates/openai/src/client.rs @@ -35,6 +35,42 @@ pub struct RetryAttempt { pub type RetryNotifier = Arc; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamErrorKind { + Transport, + Decode, +} + +#[derive(Debug, Clone)] +pub struct StreamError { + pub kind: StreamErrorKind, + pub message: String, +} + +impl StreamError { + fn transport(context: &'static str, err: impl std::fmt::Display) -> Self { + Self { + kind: StreamErrorKind::Transport, + message: format!("{context}: {err}"), + } + } + + fn decode(context: &'static str, err: impl std::fmt::Display) -> Self { + Self { + kind: StreamErrorKind::Decode, + message: format!("{context}: {err}"), + } + } +} + +impl std::fmt::Display for StreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for StreamError {} + #[derive(Clone)] pub struct RetryConfig { pub max_attempts: usize, @@ -280,7 +316,7 @@ impl OpenAIRequestError { match self { Self::Transport { source, .. } => source.is_timeout() || source.is_connect(), Self::HttpStatus { status, .. } => is_retryable_status(*status), - Self::Decode { .. } => false, + Self::Decode { .. } => true, } } @@ -366,7 +402,7 @@ impl SseEventStream { Some(line) } - fn process_line(&mut self, line: &[u8]) -> Result<()> { + fn process_line(&mut self, line: &[u8]) -> std::result::Result<(), StreamError> { if line.is_empty() { self.flush_event(); return Ok(()); @@ -391,7 +427,7 @@ impl SseEventStream { } impl Stream for SseEventStream { - type Item = Result; + type Item = std::result::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -405,10 +441,10 @@ impl Stream for SseEventStream { match Pin::new(&mut this.inner).poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Some(Err(err))) => { - let message = format!("read stream chunk error: {err}"); - return Poll::Ready(Some(Err( - ::without_source(message), - ))); + return Poll::Ready(Some(Err(StreamError::transport( + "read stream chunk error", + err, + )))); } Poll::Ready(Some(Ok(chunk))) => { this.buffer.extend_from_slice(&chunk); @@ -449,7 +485,7 @@ impl ChatCompletionStream { } impl Stream for ChatCompletionStream { - type Item = Result; + type Item = std::result::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -483,18 +519,20 @@ fn trim_leading_space(bytes: &[u8]) -> &[u8] { } } -fn parse_chat_completion_chunk(event: SseEvent) -> Result> { +fn parse_chat_completion_chunk( + event: SseEvent, +) -> std::result::Result, StreamError> { if event.data.is_empty() { return Ok(None); } - let text = - std::str::from_utf8(&event.data).whatever_context("stream event data is not utf-8")?; + let text = std::str::from_utf8(&event.data) + .map_err(|err| StreamError::decode("stream event data is not utf-8", err))?; trace!(%text, "openai stream event"); if text.trim() == "[DONE]" { return Ok(None); } - let chunk: ChatCompletionChunk = - serde_json::from_str(text).whatever_context("decode chat completion chunk")?; + let chunk: ChatCompletionChunk = serde_json::from_str(text) + .map_err(|err| StreamError::decode("decode chat completion chunk", err))?; Ok(Some(chunk)) } @@ -509,4 +547,13 @@ mod tests { assert!(is_retryable_status(StatusCode::SERVICE_UNAVAILABLE)); assert!(!is_retryable_status(StatusCode::BAD_REQUEST)); } + + #[test] + fn retryable_decode_errors() { + let err = OpenAIRequestError::Decode { + context: "chat completions", + message: "bad json".to_string(), + }; + assert!(err.is_retryable()); + } } diff --git a/src/agent.rs b/src/agent.rs index 069a65c..58ac1b7 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -4,7 +4,7 @@ //! - [`Agent`] - The main agent struct for chat interactions //! - [`AgentConfig`] - Configuration types for customizing agents -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use crate::provider::{ Block, Client, ContentBlockDelta, MessagesStreamEvent, Role, Thinking, ToolChoice, @@ -17,8 +17,8 @@ use tokio_util::sync::CancellationToken; use tracing::warn; use crate::{ - Config, PromptSchema, ProviderConfig, RequestOptions, Result, ResultDisplayExt, - ThinkingBlocksMode, ThinkingConfig, + Config, Error, PromptSchema, ProviderConfig, RequestOptions, Result, ResultDisplayExt, + RetryAttempt as CoreRetryAttempt, RetryUpdate, StreamError, ThinkingBlocksMode, ThinkingConfig, tools::{ComboInfo, RunComboContext, RunComboTool, RunTaskContext, RunTaskTool}, }; use executor::PermissionControl; @@ -78,6 +78,7 @@ pub struct ChatResponse { pub enum ChatStreamUpdate { Plain { index: usize, text: String }, Thinking { index: usize, text: String }, + Reset, } pub struct PromptReply { @@ -99,6 +100,8 @@ struct StreamAccumulator { usage: Option, } +const STREAM_RETRY_BASE_DELAY_MS: u64 = 200; + impl StreamAccumulator { fn new() -> Self { Self { @@ -294,6 +297,53 @@ impl StreamAccumulator { } } +fn should_retry_stream_error(err: &StreamError) -> bool { + err.is_retryable() +} + +fn stream_retry_delay(attempt: usize, max_delay: Duration) -> Duration { + let shift = attempt.saturating_sub(1).min(30) as u32; + let multiplier = 1u64 << shift; + let delay_ms = STREAM_RETRY_BASE_DELAY_MS.saturating_mul(multiplier); + let mut delay = Duration::from_millis(delay_ms); + if max_delay != Duration::from_millis(0) && delay > max_delay { + delay = max_delay; + } + delay +} + +fn notify_stream_retry_attempt( + request_options: &RequestOptions, + attempt: usize, + delay: Duration, + err: &StreamError, +) { + if let Some(notifier) = &request_options.retry_notifier { + notifier.notify(RetryUpdate::Attempt(CoreRetryAttempt { + attempt, + max_attempts: request_options.retry_max_attempts, + delay, + error: err.to_string(), + })); + } +} + +fn notify_stream_retry_finished(request_options: &RequestOptions, success: bool) { + if let Some(notifier) = &request_options.retry_notifier { + notifier.notify(RetryUpdate::Finished { success }); + } +} + +async fn wait_for_retry(delay: Duration, cancel_token: &CancellationToken) -> bool { + if delay == Duration::from_millis(0) { + return true; + } + tokio::select! { + _ = cancel_token.cancelled() => false, + _ = tokio::time::sleep(delay) => true, + } +} + impl Agent { pub fn new(config: Config) -> Self { let mut executor = Executor::default(); @@ -663,55 +713,115 @@ impl Agent { messages.clone() }; let messages = self.prepare_messages_for_request(messages, request_options); - let tools = self.provider_tools_for_request(request_options); + let max_attempts = request_options.retry_max_attempts; + let max_delay = Duration::from_millis(request_options.retry_max_delay_ms); + let mut attempt = 0usize; + let mut retried = false; - let mut stream = client - .messages_stream( - Some(&self.system_prompt), - messages, - tools, - thinking, - request_options, - ) - .await - .inspect_err(|err| { - warn!("send messsages stream error: {err:?}"); - }) - .whatever_context_display("failed to send messages stream")?; + loop { + let tools = self.provider_tools_for_request(request_options); + let stream_result = client + .messages_stream( + Some(&self.system_prompt), + messages.clone(), + tools, + thinking.clone(), + request_options, + ) + .await + .inspect_err(|err| { + warn!("send messsages stream error: {err:?}"); + }) + .whatever_context_display("failed to send messages stream"); + + let mut stream = match stream_result { + Ok(stream) => stream, + Err(err) => { + if retried { + notify_stream_retry_finished(request_options, false); + } + return Err(err); + } + }; - let mut accumulator = StreamAccumulator::new(); - while let Some(event) = tokio::select! { - _ = cancel_token.cancelled() => { - whatever!("chat stream cancelled"); + let mut accumulator = StreamAccumulator::new(); + let mut stream_error: Option = None; + loop { + let event = tokio::select! { + _ = cancel_token.cancelled() => { + if retried { + notify_stream_retry_finished(request_options, false); + } + whatever!("chat stream cancelled"); + } + event = stream.next() => event, + }; + let Some(event) = event else { + break; + }; + let event = match event { + Ok(event) => event, + Err(err) => { + stream_error = Some(err.with_context("read messages stream error")); + break; + } + }; + let action = match accumulator.handle_event(event, &mut on_update) { + Ok(action) => action, + Err(err) => { + stream_error = Some(StreamError::decode(format!( + "parse messages stream error: {err}" + ))); + break; + } + }; + if matches!(action, StreamAction::Stop) { + break; + } } - event = stream.next() => event, - } { - let event = event.whatever_context_display("read messages stream error")?; - let action = accumulator - .handle_event(event, &mut on_update) - .whatever_context("parse messages stream error")?; - if matches!(action, StreamAction::Stop) { - break; + + if let Some(err) = stream_error { + if should_retry_stream_error(&err) && attempt < max_attempts { + attempt += 1; + retried = true; + let delay = stream_retry_delay(attempt, max_delay); + notify_stream_retry_attempt(request_options, attempt, delay, &err); + on_update(ChatStreamUpdate::Reset); + if !wait_for_retry(delay, &cancel_token).await { + if retried { + notify_stream_retry_finished(request_options, false); + } + whatever!("chat stream cancelled"); + } + continue; + } + if retried { + notify_stream_retry_finished(request_options, false); + } + return Err(Error::stream(err.kind, err.message)); } - } - let (blocks, stop_reason, usage) = accumulator.finish(); - let message = if blocks.is_empty() { - Message::assistant(Content::Multiple(Vec::default())) - } else { - let mut msg = Message::assistant(Content::Multiple(blocks)); - if request_options.stringify_nested_tool_inputs { - parse_stringified_tool_inputs_in_message(&mut msg, &self.executor); + let (blocks, stop_reason, usage) = accumulator.finish(); + let message = if blocks.is_empty() { + Message::assistant(Content::Multiple(Vec::default())) + } else { + let mut msg = Message::assistant(Content::Multiple(blocks)); + if request_options.stringify_nested_tool_inputs { + parse_stringified_tool_inputs_in_message(&mut msg, &self.executor); + } + self.messages.lock().await.push(msg.clone()); + msg + }; + self.mark_thinking_cleanup_pending(stop_reason.as_ref()); + if retried { + notify_stream_retry_finished(request_options, true); } - self.messages.lock().await.push(msg.clone()); - msg - }; - self.mark_thinking_cleanup_pending(stop_reason.as_ref()); - Ok(ChatResponse { - message, - stop_reason, - usage, - }) + return Ok(ChatResponse { + message, + stop_reason, + usage, + }); + } } pub async fn reply_prompt( @@ -918,7 +1028,13 @@ impl Agent { } event = stream.next() => event, } { - let event = event.whatever_context_display("read prompt reply stream error")?; + let event = match event { + Ok(event) => event, + Err(err) => { + let err = err.with_context("read prompt reply stream error"); + return Err(Error::stream(err.kind, err.message)); + } + }; let action = accumulator .handle_event(event, &mut on_update) .whatever_context("parse prompt reply stream error")?; @@ -1692,6 +1808,18 @@ mod tests { } } + #[test] + fn retry_stream_error_heuristics() { + let err = StreamError::transport("read stream chunk error: broken pipe".to_string()); + assert!(should_retry_stream_error(&err)); + + let err = StreamError::decode("decode stream event data".to_string()); + assert!(!should_retry_stream_error(&err)); + + let err = StreamError::decode("chat stream cancelled".to_string()); + assert!(!should_retry_stream_error(&err)); + } + #[test] fn tool_choice_fallback_requires_disable() { let mut options = RequestOptions { diff --git a/src/error.rs b/src/error.rs index dc288af..868ac0c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,7 @@ use snafu::prelude::*; +use crate::StreamErrorKind; + #[derive(Debug, Snafu)] #[snafu(visibility(pub))] pub enum Error { @@ -10,6 +12,12 @@ pub enum Error { source: Option>, backtrace: snafu::Backtrace, }, + #[snafu(display("{message}"))] + Stream { + kind: StreamErrorKind, + message: String, + backtrace: snafu::Backtrace, + }, } pub type Result = std::result::Result; @@ -29,3 +37,20 @@ where }) } } + +impl Error { + pub fn stream(kind: StreamErrorKind, message: impl Into) -> Self { + StreamSnafu { + kind, + message: message.into(), + } + .build() + } + + pub fn stream_kind(&self) -> Option { + match self { + Error::Stream { kind, .. } => Some(*kind), + _ => None, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index ff8ebf5..2b7fdee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ mod mcp; mod provider; mod retry; mod runtime_overrides; +mod stream_error; mod text_edit; pub mod tools; pub mod version; @@ -26,6 +27,7 @@ pub use mcp::*; pub use provider::*; pub use retry::*; pub use runtime_overrides::*; +pub use stream_error::*; pub use text_edit::*; #[cfg(test)] diff --git a/src/provider.rs b/src/provider.rs index 77ad4bd..6585e17 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -4,9 +4,9 @@ mod types; use std::{pin::Pin, time::Duration}; +use crate::StreamError; use futures_core::Stream; use futures_util::StreamExt; -use snafu::Whatever; use ::anthropic as anthropic_api; use ::openai as openai_api; @@ -19,7 +19,7 @@ use crate::{ pub use types::*; pub type MessagesStream = - Pin> + Send>>; + Pin> + Send>>; pub enum Client { Anthropic(anthropic_api::Client), @@ -126,7 +126,8 @@ impl Client { .call() .await .whatever_context_display("failed to send messages stream")?; - let mapped = stream.map(|event| event.map(Into::into)); + let mapped = + stream.map(|event| event.map(Into::into).map_err(map_anthropic_stream_error)); Ok(Box::pin(mapped)) } Client::OpenAI(client) => { @@ -224,7 +225,8 @@ impl Client { ) .await .whatever_context_display("failed to request tool choice stream")?; - let mapped = stream.map(|event| event.map(Into::into)); + let mapped = + stream.map(|event| event.map(Into::into).map_err(map_anthropic_stream_error)); Ok(Box::pin(mapped)) } Client::OpenAI(client) => { @@ -272,3 +274,10 @@ fn anthropic_retry_config(request_options: &RequestOptions) -> anthropic_api::Re notifier, } } + +fn map_anthropic_stream_error(err: anthropic_api::StreamError) -> StreamError { + match err.kind { + anthropic_api::StreamErrorKind::Transport => StreamError::transport(err.message), + anthropic_api::StreamErrorKind::Decode => StreamError::decode(err.message), + } +} diff --git a/src/provider/openai.rs b/src/provider/openai.rs index fc47d59..21bf338 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -16,7 +16,7 @@ use crate::provider::types::{ MessagesStreamEvent, Role, StopReason, StreamErrorDetail, Thinking, Tool, ToolChoice, ToolUse, UsageStats, }; -use crate::{RequestOptions, RetryAttempt as CoreRetryAttempt, RetryUpdate}; +use crate::{RequestOptions, RetryAttempt as CoreRetryAttempt, RetryUpdate, StreamError}; struct ToolCallState { #[allow(dead_code)] @@ -376,6 +376,13 @@ fn parse_tool_arguments(arguments: &str) -> Value { serde_json::from_str(arguments).unwrap_or_else(|_| Value::String(arguments.to_string())) } +fn map_stream_error(err: openai_api::StreamError) -> StreamError { + match err.kind { + openai_api::StreamErrorKind::Transport => StreamError::transport(err.message), + openai_api::StreamErrorKind::Decode => StreamError::decode(err.message), + } +} + fn map_finish_reason(reason: String) -> Option { match reason.as_str() { "stop" => Some(StopReason::EndTurn), @@ -468,7 +475,7 @@ impl OpenAIStream { } impl Stream for OpenAIStream { - type Item = Result; + type Item = std::result::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -481,7 +488,9 @@ impl Stream for OpenAIStream { } match Pin::new(&mut this.inner).poll_next(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Some(Err(map_stream_error(err)))); + } Poll::Ready(Some(Ok(chunk))) => { let openai_api::ChatCompletionChunk { choices, usage } = chunk; if let Some(usage) = usage { diff --git a/src/stream_error.rs b/src/stream_error.rs new file mode 100644 index 0000000..7389f8a --- /dev/null +++ b/src/stream_error.rs @@ -0,0 +1,48 @@ +use std::fmt; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamErrorKind { + Transport, + Decode, +} + +#[derive(Debug, Clone)] +pub struct StreamError { + pub kind: StreamErrorKind, + pub message: String, +} + +impl StreamError { + pub fn transport(message: impl Into) -> Self { + Self { + kind: StreamErrorKind::Transport, + message: message.into(), + } + } + + pub fn decode(message: impl Into) -> Self { + Self { + kind: StreamErrorKind::Decode, + message: message.into(), + } + } + + pub fn is_retryable(&self) -> bool { + matches!(self.kind, StreamErrorKind::Transport) + } + + pub fn with_context(self, context: impl fmt::Display) -> Self { + Self { + kind: self.kind, + message: format!("{context}: {}", self.message), + } + } +} + +impl fmt::Display for StreamError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for StreamError {} diff --git a/src/tools/run_combo.rs b/src/tools/run_combo.rs index 9407d8e..a1a3a4a 100644 --- a/src/tools/run_combo.rs +++ b/src/tools/run_combo.rs @@ -127,6 +127,11 @@ pub enum ComboEvent { /// Streamed text. text: String, }, + /// Reset prompt stream for combo reply. + PromptStreamReset { + /// Combo name. + name: String, + }, /// Reply tool use from prompt. ReplyToolUse { /// Combo name. @@ -638,25 +643,46 @@ async fn execute_combo( thinking.clone(), cancel_token.clone(), move |update| { - let (index, kind, text) = match update { + match update { + ChatStreamUpdate::Reset => { + thinking_seen_stream.store( + false, + std::sync::atomic::Ordering::Relaxed, + ); + emit_combo_event( + &on_event_stream, + ComboEvent::PromptStreamReset { + name: stream_name.clone(), + }, + ); + } ChatStreamUpdate::Plain { index, text } => { - (index, ComboStreamKind::Plain, text) + emit_combo_event( + &on_event_stream, + ComboEvent::PromptStream { + name: stream_name.clone(), + index, + kind: ComboStreamKind::Plain, + text, + }, + ); } ChatStreamUpdate::Thinking { index, text } => { - thinking_seen_stream - .store(true, std::sync::atomic::Ordering::Relaxed); - (index, ComboStreamKind::Thinking, text) + thinking_seen_stream.store( + true, + std::sync::atomic::Ordering::Relaxed, + ); + emit_combo_event( + &on_event_stream, + ComboEvent::PromptStream { + name: stream_name.clone(), + index, + kind: ComboStreamKind::Thinking, + text, + }, + ); } - }; - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind, - text, - }, - ); + } }, ) .await @@ -1268,22 +1294,37 @@ async fn handle_offload_combo_reply( let on_event_stream = on_event.clone(); let stream_name = combo_name.to_string(); let chat_response = agent - .chat_stream_with_history(cancel_token.clone(), move |update| { - let (index, kind, text) = match update { - ChatStreamUpdate::Plain { index, text } => (index, ComboStreamKind::Plain, text), - ChatStreamUpdate::Thinking { index, text } => { - (index, ComboStreamKind::Thinking, text) - } - }; - emit_combo_event( - &on_event_stream, - ComboEvent::PromptStream { - name: stream_name.clone(), - index, - kind, - text, - }, - ); + .chat_stream_with_history(cancel_token.clone(), move |update| match update { + ChatStreamUpdate::Reset => { + emit_combo_event( + &on_event_stream, + ComboEvent::PromptStreamReset { + name: stream_name.clone(), + }, + ); + } + ChatStreamUpdate::Plain { index, text } => { + emit_combo_event( + &on_event_stream, + ComboEvent::PromptStream { + name: stream_name.clone(), + index, + kind: ComboStreamKind::Plain, + text, + }, + ); + } + ChatStreamUpdate::Thinking { index, text } => { + emit_combo_event( + &on_event_stream, + ComboEvent::PromptStream { + name: stream_name.clone(), + index, + kind: ComboStreamKind::Thinking, + text, + }, + ); + } }) .await .map_err(|e| ComboReplyError::ChatFailed { diff --git a/src/tools/run_task.rs b/src/tools/run_task.rs index 042d86d..55fc4b9 100644 --- a/src/tools/run_task.rs +++ b/src/tools/run_task.rs @@ -491,16 +491,25 @@ fn execute_subagent( // Helper to emit stream updates as Output events, aggregating into lines let emit_update = move |update: ChatStreamUpdate| { - let (buffer, stream, prefix) = match &update { - ChatStreamUpdate::Plain { .. } => (&plain_buf_for_emit, StreamKind::Stdout, ""), - ChatStreamUpdate::Thinking { .. } => { - (&thinking_buf_for_emit, StreamKind::Stderr, "[thinking] ") + let (buffer, stream, prefix, text) = match update { + ChatStreamUpdate::Reset => { + if let Ok(mut buf) = plain_buf_for_emit.lock() { + buf.clear(); + } + if let Ok(mut buf) = thinking_buf_for_emit.lock() { + buf.clear(); + } + return; } - }; - - let text = match update { - ChatStreamUpdate::Plain { text, .. } => text, - ChatStreamUpdate::Thinking { text, .. } => text, + ChatStreamUpdate::Plain { text, .. } => { + (&plain_buf_for_emit, StreamKind::Stdout, "", text) + } + ChatStreamUpdate::Thinking { text, .. } => ( + &thinking_buf_for_emit, + StreamKind::Stderr, + "[thinking] ", + text, + ), }; let mut buf = buffer.lock().unwrap(); From 372a1dc22bd60f5433607d7836f4f30bccb9256a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Fri, 23 Jan 2026 20:42:48 +0800 Subject: [PATCH 2/2] refactor: remove redundant should_retry_stream_error helper function --- src/agent.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/agent.rs b/src/agent.rs index 58ac1b7..cb8332a 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -297,10 +297,6 @@ impl StreamAccumulator { } } -fn should_retry_stream_error(err: &StreamError) -> bool { - err.is_retryable() -} - fn stream_retry_delay(attempt: usize, max_delay: Duration) -> Duration { let shift = attempt.saturating_sub(1).min(30) as u32; let multiplier = 1u64 << shift; @@ -781,7 +777,7 @@ impl Agent { } if let Some(err) = stream_error { - if should_retry_stream_error(&err) && attempt < max_attempts { + if err.is_retryable() && attempt < max_attempts { attempt += 1; retried = true; let delay = stream_retry_delay(attempt, max_delay); @@ -1811,13 +1807,13 @@ mod tests { #[test] fn retry_stream_error_heuristics() { let err = StreamError::transport("read stream chunk error: broken pipe".to_string()); - assert!(should_retry_stream_error(&err)); + assert!(err.is_retryable()); let err = StreamError::decode("decode stream event data".to_string()); - assert!(!should_retry_stream_error(&err)); + assert!(!err.is_retryable()); let err = StreamError::decode("chat stream cancelled".to_string()); - assert!(!should_retry_stream_error(&err)); + assert!(!err.is_retryable()); } #[test]