Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 72 additions & 24 deletions crates/anthropic/src/client/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,42 @@ use tracing::{trace, warn};

use crate::{Block, Message, RetryAttempt, RetryConfig, Role, Tool};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamErrorKind {
Transport,
Decode,
}

#[derive(Debug, Clone)]
pub struct StreamError {
pub kind: StreamErrorKind,
pub message: String,
}

impl StreamError {
fn transport(context: &'static str, err: impl std::fmt::Display) -> Self {
Self {
kind: StreamErrorKind::Transport,
message: format!("{context}: {err}"),
}
}

fn decode(context: &'static str, err: impl std::fmt::Display) -> Self {
Self {
kind: StreamErrorKind::Decode,
message: format!("{context}: {err}"),
}
}
}

impl std::fmt::Display for StreamError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}

impl std::error::Error for StreamError {}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
Expand Down Expand Up @@ -377,7 +413,7 @@ impl SseEventStream {
Some(line)
}

fn process_line(&mut self, line: &[u8]) -> Result<(), Whatever> {
fn process_line(&mut self, line: &[u8]) -> std::result::Result<(), StreamError> {
if line.is_empty() {
self.flush_event();
return Ok(());
Expand All @@ -387,8 +423,8 @@ impl SseEventStream {
}
if let Some(rest) = line.strip_prefix(b"event:") {
let rest = trim_leading_space(rest);
let name =
std::str::from_utf8(rest).whatever_context("stream event name is not utf-8")?;
let name = std::str::from_utf8(rest)
.map_err(|err| StreamError::decode("stream event name is not utf-8", err))?;
self.event = Some(name.to_string());
return Ok(());
}
Expand All @@ -413,7 +449,7 @@ impl SseEventStream {
}

impl Stream for SseEventStream {
type Item = Result<SseEvent, Whatever>;
type Item = std::result::Result<SseEvent, StreamError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
Expand All @@ -427,10 +463,10 @@ impl Stream for SseEventStream {
match this.inner.as_mut().poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Err(err))) => {
let message = format!("read stream chunk error: {err}");
return Poll::Ready(Some(Err(
<Whatever as snafu::FromString>::without_source(message),
)));
return Poll::Ready(Some(Err(StreamError::transport(
"read stream chunk error",
err,
))));
}
Poll::Ready(Some(Ok(chunk))) => {
this.buffer.extend_from_slice(&chunk);
Expand Down Expand Up @@ -479,7 +515,7 @@ impl MessagesStream {
}

impl Stream for MessagesStream {
type Item = Result<MessagesStreamEvent, Whatever>;
type Item = std::result::Result<MessagesStreamEvent, StreamError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
Expand All @@ -500,13 +536,16 @@ fn trim_leading_space(bytes: &[u8]) -> &[u8] {
}
}

fn parse_messages_stream_event(event: SseEvent) -> Result<MessagesStreamEvent, Whatever> {
fn parse_messages_stream_event(
event: SseEvent,
) -> std::result::Result<MessagesStreamEvent, StreamError> {
let data = if event.data.is_empty() {
Value::Object(Map::new())
} else {
let text =
std::str::from_utf8(&event.data).whatever_context("stream event data is not utf-8")?;
serde_json::from_str::<Value>(text).whatever_context("decode stream event data")?
let text = std::str::from_utf8(&event.data)
.map_err(|err| StreamError::decode("stream event data is not utf-8", err))?;
serde_json::from_str::<Value>(text)
.map_err(|err| StreamError::decode("decode stream event data", err))?
};
let event_type = data
.get("type")
Expand All @@ -516,36 +555,36 @@ fn parse_messages_stream_event(event: SseEvent) -> Result<MessagesStreamEvent, W
trace!(?event_type, ?data, "received sse event");
match event_type.as_str() {
"message_start" => {
let event: MessageStartEvent =
serde_json::from_value(data).whatever_context("decode message_start event")?;
let event: MessageStartEvent = serde_json::from_value(data)
.map_err(|err| StreamError::decode("decode message_start event", err))?;
Ok(MessagesStreamEvent::MessageStart {
message: event.message,
})
}
"content_block_start" => {
let event: ContentBlockStartEvent = serde_json::from_value(data)
.whatever_context("decode content_block_start event")?;
.map_err(|err| StreamError::decode("decode content_block_start event", err))?;
Ok(MessagesStreamEvent::ContentBlockStart {
index: event.index,
content_block: event.content_block,
})
}
"content_block_delta" => {
let event: ContentBlockDeltaEvent = serde_json::from_value(data)
.whatever_context("decode content_block_delta event")?;
.map_err(|err| StreamError::decode("decode content_block_delta event", err))?;
Ok(MessagesStreamEvent::ContentBlockDelta {
index: event.index,
delta: event.delta,
})
}
"content_block_stop" => {
let event: ContentBlockStopEvent =
serde_json::from_value(data).whatever_context("decode content_block_stop event")?;
let event: ContentBlockStopEvent = serde_json::from_value(data)
.map_err(|err| StreamError::decode("decode content_block_stop event", err))?;
Ok(MessagesStreamEvent::ContentBlockStop { index: event.index })
}
"message_delta" => {
let event: MessageDeltaEvent =
serde_json::from_value(data).whatever_context("decode message_delta event")?;
let event: MessageDeltaEvent = serde_json::from_value(data)
.map_err(|err| StreamError::decode("decode message_delta event", err))?;
Ok(MessagesStreamEvent::MessageDelta {
delta: event.delta.unwrap_or_default(),
usage: event.usage,
Expand All @@ -554,8 +593,8 @@ fn parse_messages_stream_event(event: SseEvent) -> Result<MessagesStreamEvent, W
"message_stop" => Ok(MessagesStreamEvent::MessageStop),
"ping" => Ok(MessagesStreamEvent::Ping),
"error" => {
let event: StreamErrorEvent =
serde_json::from_value(data).whatever_context("decode error event")?;
let event: StreamErrorEvent = serde_json::from_value(data)
.map_err(|err| StreamError::decode("decode error event", err))?;
Ok(MessagesStreamEvent::Error { error: event.error })
}
_ => Ok(MessagesStreamEvent::Unknown {
Expand Down Expand Up @@ -789,7 +828,7 @@ impl AnthropicRequestError {
match self {
Self::Transport { source, .. } => source.is_timeout() || source.is_connect(),
Self::HttpStatus { status, .. } => is_retryable_status(*status),
Self::Decode { .. } => false,
Self::Decode { .. } => true,
}
}

Expand Down Expand Up @@ -850,6 +889,15 @@ mod retry_tests {
assert!(is_retryable_status(StatusCode::SERVICE_UNAVAILABLE));
assert!(!is_retryable_status(StatusCode::BAD_REQUEST));
}

#[test]
fn retryable_decode_errors() {
let err = AnthropicRequestError::Decode {
context: "messages",
message: "bad json".to_string(),
};
assert!(err.is_retryable());
}
}

#[cfg(test)]
Expand Down
71 changes: 58 additions & 13 deletions crates/coco-tui/src/components/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,8 @@ impl Chat<'static> {
| ComboEvent::RecordStart { .. }
| ComboEvent::RecordOutput { .. }
| ComboEvent::RecordEnd { .. }
| ComboEvent::PromptStream { .. } => {
| ComboEvent::PromptStream { .. }
| ComboEvent::PromptStreamReset { .. } => {
self.set_processing();
}
ComboEvent::Prompt { thinking, .. } => {
Expand Down Expand Up @@ -646,6 +647,10 @@ impl Chat<'static> {
},
text: text.clone(),
},
ComboToolEvent::PromptStreamReset { name } => ComboEvent::PromptStreamReset {
id: id.to_string(),
name: name.clone(),
},
ComboToolEvent::ReplyToolUse {
name,
tool_use,
Expand Down Expand Up @@ -2278,19 +2283,39 @@ async fn task_chat(mut agent: Agent, content: ChatContent, cancel_token: Cancell
let thinking_seen_stream = thinking_seen.clone();
let resp = agent
.chat_stream(msg, cancel_token.clone(), move |update| {
let (index, kind, text) = match update {
match update {
ChatStreamUpdate::Reset => {
plain_seen_stream.store(false, Ordering::Relaxed);
thinking_seen_stream.store(false, Ordering::Relaxed);
stream_tx.send(AnswerEvent::BotStreamReset.into()).ok();
}
ChatStreamUpdate::Plain { index, text } => {
plain_seen_stream.store(true, Ordering::Relaxed);
(index, BotStreamKind::Plain, text)
stream_tx
.send(
AnswerEvent::BotStream {
index,
kind: BotStreamKind::Plain,
text,
}
.into(),
)
.ok();
}
ChatStreamUpdate::Thinking { index, text } => {
thinking_seen_stream.store(true, Ordering::Relaxed);
(index, BotStreamKind::Thinking, text)
stream_tx
.send(
AnswerEvent::BotStream {
index,
kind: BotStreamKind::Thinking,
text,
}
.into(),
)
.ok();
}
};
stream_tx
.send(AnswerEvent::BotStream { index, kind, text }.into())
.ok();
})
.await;
streamed_plain = plain_seen.load(Ordering::Relaxed);
Expand Down Expand Up @@ -2349,19 +2374,39 @@ async fn task_chat_with_history(mut agent: Agent, cancel_token: CancellationToke
let thinking_seen_stream = thinking_seen.clone();
let resp = agent
.chat_stream_with_history(cancel_token.clone(), move |update| {
let (index, kind, text) = match update {
match update {
ChatStreamUpdate::Reset => {
plain_seen_stream.store(false, Ordering::Relaxed);
thinking_seen_stream.store(false, Ordering::Relaxed);
stream_tx.send(AnswerEvent::BotStreamReset.into()).ok();
}
ChatStreamUpdate::Plain { index, text } => {
plain_seen_stream.store(true, Ordering::Relaxed);
(index, BotStreamKind::Plain, text)
stream_tx
.send(
AnswerEvent::BotStream {
index,
kind: BotStreamKind::Plain,
text,
}
.into(),
)
.ok();
}
ChatStreamUpdate::Thinking { index, text } => {
thinking_seen_stream.store(true, Ordering::Relaxed);
(index, BotStreamKind::Thinking, text)
stream_tx
.send(
AnswerEvent::BotStream {
index,
kind: BotStreamKind::Thinking,
text,
}
.into(),
)
.ok();
}
};
stream_tx
.send(AnswerEvent::BotStream { index, kind, text }.into())
.ok();
})
.await;
streamed_plain = plain_seen.load(Ordering::Relaxed);
Expand Down
6 changes: 6 additions & 0 deletions crates/coco-tui/src/components/messages/combo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,12 @@ impl Combo {
self.push_prompt(prompt);
}
}
ComboEvent::PromptStreamReset { id, .. } => {
if self.matches_id(id) {
self.messages.finalize_stream();
self.messages.reset_stream();
}
}
ComboEvent::PromptStream {
id,
index,
Expand Down
4 changes: 4 additions & 0 deletions crates/coco-tui/src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ pub enum ComboEvent {
kind: BotStreamKind,
text: String,
},
PromptStreamReset {
id: String,
name: String,
},
/// Reply tool use from prompt, with optional offload via bash
ReplyToolUse {
id: String,
Expand Down
Loading