From 50f4381745b98e4ccffaa10b7b56e8ad99327bf2 Mon Sep 17 00:00:00 2001 From: Stephen Belanger Date: Tue, 28 Apr 2026 22:04:56 -0700 Subject: [PATCH 1/2] Fix streaming aggregation issues --- src/stream.rs | 429 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 401 insertions(+), 28 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index 06e759d..cfed8db 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -13,6 +13,53 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; +/// Accumulates partial tool call information from streaming chunks. +/// +/// In OpenAI streaming, each chunk may contain partial tool call data. +/// This struct accumulates the id, type, function name, and raw arguments +/// across multiple chunks, then parses the complete arguments string +/// as JSON when aggregation is complete. +#[derive(Clone, Debug, Default)] +struct StreamToolCallAccumulator { + id: String, + call_type: String, + function_name: String, + raw_arguments: String, +} + +impl StreamToolCallAccumulator { + /// Creates a `ToolCall` from the accumulated data. + /// + /// The raw arguments string is parsed as JSON to extract the function + /// name and arguments, falling back to the raw string if parsing fails. + pub fn build(self) -> crate::stream::ToolCall { + let args: serde_json::Value = serde_json::from_str(&self.raw_arguments).unwrap_or_default(); + + // Extract function name from JSON object, falling back to stored name + let function_name = if let Some(obj) = args.as_object() { + obj.get("name") + .and_then(|v| v.as_str()) + .unwrap_or(&self.function_name) + .to_string() + } else { + self.function_name.clone() + }; + + // Extract arguments string from JSON object, falling back to raw string + let arguments = if let Some(obj) = args.as_object() { + obj.get("arguments") + .and_then(|v| v.as_str()) + .unwrap_or(&self.raw_arguments) + .to_string() + } else { + self.raw_arguments.clone() + }; + + let function = crate::stream::FunctionCall::new(function_name, arguments); + crate::stream::ToolCall::new(self.id, self.call_type, function) + } +} + use anyhow::Result; use futures::Stream; use serde::{Deserialize, Serialize}; @@ -646,13 +693,38 @@ struct StreamChunk { usage: Option, } -/// Delta from a streaming chunk (typed for role/content). +/// Delta from a streaming chunk (typed for role/content/tool_calls). #[derive(Debug, Clone, Default, Deserialize, Serialize)] struct StreamDelta { #[serde(default)] role: Option, #[serde(default)] content: Option, + #[serde(default)] + tool_calls: Option>, +} + +/// A partial tool call from a streaming delta. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +struct StreamToolCall { + #[serde(default)] + index: Option, + #[serde(default)] + id: Option, + #[serde(default)] + #[serde(rename = "type")] + call_type: Option, + #[serde(default)] + function: Option, +} + +/// A partial function call from a streaming delta. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +struct StreamFunctionCall { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -663,16 +735,45 @@ struct StreamChoice { finish_reason: Option, } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Default, Deserialize, Serialize)] struct StreamUsage { - #[serde(default)] - prompt_tokens: Option, #[serde(default, alias = "input_tokens")] + prompt_tokens: Option, + #[serde(default, alias = "output_tokens")] completion_tokens: Option, - #[serde(default, alias = "cache_read_input_tokens")] - prompt_cached_tokens: Option, - #[serde(default, alias = "cache_creation_input_tokens")] - prompt_cache_creation_tokens: Option, + #[serde(default, alias = "completion_reasoning_tokens")] + completion_reasoning_tokens: Option, + #[serde(default, alias = "accepted_prediction_tokens")] + accepted_prediction_tokens: Option, + #[serde(default, alias = "rejected_prediction_tokens")] + rejected_prediction_tokens: Option, + #[serde(default, alias = "prompt_tokens_details", alias = "input_tokens_details")] + prompt_tokens_details: Option, + #[serde(default, alias = "completion_tokens_details", alias = "output_tokens_details")] + completion_tokens_details: Option, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +struct PromptTokensDetails { + #[serde(default, alias = "audio_tokens", alias = "prompt_audio_tokens")] + pub(crate) audio_tokens: Option, + #[serde(default, alias = "cached_tokens", alias = "prompt_cached_tokens")] + pub(crate) cached_tokens: Option, + #[serde( + default, + alias = "cache_creation_tokens", + alias = "cache_creation_input_tokens", + alias = "prompt_cache_creation_tokens" + )] + pub(crate) cache_creation_tokens: Option, +} + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +struct CompletionTokensDetails { + #[serde(default, alias = "audio_tokens", alias = "completion_audio_tokens")] + pub(crate) audio_tokens: Option, + #[serde(default, alias = "reasoning_tokens", alias = "completion_reasoning_tokens")] + pub(crate) reasoning_tokens: Option, } impl BraintrustStream { @@ -733,6 +834,9 @@ impl BraintrustStream { let mut aggregated_content = String::new(); let mut role: Option = None; + // Accumulate tool calls by index across chunks + let mut tool_calls_by_index: HashMap = HashMap::new(); + for raw in &self.raw_chunks { // Try to parse as OpenAI-style streaming chunk let chunk: StreamChunk = match serde_json::from_value(raw.clone()) { @@ -745,25 +849,64 @@ impl BraintrustStream { model = chunk.model; } - // Extract usage (take last non-None) + // Extract detailed usage (take last non-None, preserving all details) if let Some(ref u) = chunk.usage { + let prompt_tokens = u.prompt_tokens.and_then(|v| u32::try_from(v).ok()); + let completion_tokens = u.completion_tokens.and_then(|v| u32::try_from(v).ok()); + + let prompt_details = u.prompt_tokens_details.clone().map(|p| { + crate::types::PromptTokensDetails::new( + p.audio_tokens.and_then(|v| u32::try_from(v).ok()), + p.cached_tokens.and_then(|v| u32::try_from(v).ok()), + p.cache_creation_tokens.and_then(|v| u32::try_from(v).ok()), + ) + }); + + let completion_details = u.completion_tokens_details.clone().map(|c| { + crate::types::CompletionTokensDetails::new( + c.audio_tokens.and_then(|v| u32::try_from(v).ok()), + c.reasoning_tokens.and_then(|v| u32::try_from(v).ok()), + u.accepted_prediction_tokens.and_then(|v| u32::try_from(v).ok()), + u.rejected_prediction_tokens.and_then(|v| u32::try_from(v).ok()), + ) + }); + usage = Some(UsageMetrics { - prompt_tokens: u.prompt_tokens.and_then(|v| u32::try_from(v).ok()), - completion_tokens: u.completion_tokens.and_then(|v| u32::try_from(v).ok()), - total_tokens: match (u.prompt_tokens, u.completion_tokens) { - (Some(p), Some(c)) => u32::try_from(p + c).ok(), + prompt_tokens, + completion_tokens, + total_tokens: match (prompt_tokens, completion_tokens) { + (Some(p), Some(c)) => Some(p + c), _ => None, }, reasoning_tokens: None, - prompt_cached_tokens: u - .prompt_cached_tokens - .and_then(|v| u32::try_from(v).ok()), - prompt_cache_creation_tokens: u - .prompt_cache_creation_tokens - .and_then(|v| u32::try_from(v).ok()), - completion_reasoning_tokens: None, - prompt_tokens_details: None, - completion_tokens_details: None, + prompt_cached_tokens: prompt_details + .as_ref() + .and_then(|d| d.cached_tokens()) + .or_else(|| { + u.prompt_tokens_details + .as_ref() + .and_then(|p| p.cached_tokens) + .and_then(|v| u32::try_from(v).ok()) + }), + prompt_cache_creation_tokens: prompt_details + .as_ref() + .and_then(|d| d.cache_creation_tokens()) + .or_else(|| { + u.prompt_tokens_details + .as_ref() + .and_then(|p| p.cache_creation_tokens) + .and_then(|v| u32::try_from(v).ok()) + }), + completion_reasoning_tokens: u + .completion_reasoning_tokens + .and_then(|v| u32::try_from(v).ok()) + .or_else(|| { + completion_details + .as_ref() + .and_then(|d| d.reasoning_tokens()) + }), + prompt_tokens_details: prompt_details, + completion_tokens_details: completion_details, }); } @@ -776,7 +919,7 @@ impl BraintrustStream { } } - // Extract content from delta + // Extract content and tool calls from delta if let Some(ref delta) = choice.delta { // Extract role (take first) if role.is_none() { @@ -787,6 +930,37 @@ impl BraintrustStream { if let Some(ref content) = delta.content { aggregated_content.push_str(content); } + + // Accumulate tool calls by index + if let Some(ref tool_calls) = delta.tool_calls { + for tc in tool_calls { + let idx = tc.index.unwrap_or(0); + let entry = tool_calls_by_index.entry(idx).or_default(); + + // Store id if present and not yet stored + if entry.id.is_empty() { + entry.id = tc.id.clone().unwrap_or_default(); + } + + // Store call_type if present and not yet stored + if entry.call_type.is_empty() { + entry.call_type = tc.call_type.clone().unwrap_or_default(); + } + + // Store function name if present + if let Some(ref func) = tc.function { + if let Some(ref name) = func.name { + entry.function_name = name.clone(); + } + // Accumulate raw argument fragments + if let Some(ref args) = func.arguments { + if !args.is_empty() { + entry.raw_arguments.push_str(args); + } + } + } + } + } } } } @@ -795,11 +969,25 @@ impl BraintrustStream { let metadata = StreamMetadata::new(model, HashMap::new()); // Build typed output (matches OpenAI response format) - let message = ChatMessage::new( - Some(role.unwrap_or_else(|| "assistant".to_string())), - Some(aggregated_content), - None, // TODO: implement tool call aggregation - ); + let message = if tool_calls_by_index.is_empty() { + ChatMessage::new( + Some(role.unwrap_or_else(|| "assistant".to_string())), + Some(aggregated_content), + None, + ) + } else { + let mut tool_calls = Vec::new(); + for idx in 0..tool_calls_by_index.len() { + if let Some(accumulator) = tool_calls_by_index.remove(&idx) { + tool_calls.push(accumulator.build()); + } + } + ChatMessage::new( + Some(role.unwrap_or_else(|| "assistant".to_string())), + Some(aggregated_content), + Some(tool_calls), + ) + }; let choice = OutputChoice::new(0, message, finish_reason); @@ -1155,6 +1343,191 @@ mod tests { assert_eq!(usage.total_tokens(), Some(15)); } + #[test] + fn aggregates_usage_with_input_output_token_aliases() { + let chunks = vec![ + json!({ + "id": "chunk1", + "model": "gpt-4.1", + "choices": [{ + "index": 0, + "delta": { "role": "assistant", "content": "Hi" }, + "finish_reason": "stop" + }], + "created": 1 + }), + json!({ + "id": "chunk2", + "model": "gpt-4.1", + "choices": [], + "created": 1, + "usage": { + "input_tokens": 12, + "output_tokens": 7 + } + }), + ]; + + let mut stream = BraintrustStream::new(); + for chunk in chunks { + stream.push(chunk); + } + + let finalized = stream.final_value().expect("should finalize"); + let usage = finalized.usage().expect("should have usage"); + + assert_eq!(usage.prompt_tokens(), Some(12)); + assert_eq!(usage.completion_tokens(), Some(7)); + assert_eq!(usage.total_tokens(), Some(19)); + } + + #[test] + fn aggregates_tool_calls_from_streaming_values() { + let chunks = vec![ + json!({ + "id": "chunk1", + "model": "gpt-4o", + "choices": [{ + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": [{ + "index": 0, + "id": "call_abc", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\":\"" + } + }] + } + }], + "created": 1 + }), + json!({ + "id": "chunk2", + "model": "gpt-4o", + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [{ + "index": 0, + "function": { + "arguments": "NYC\",\"unit\":\"" + } + }] + } + }], + "created": 1 + }), + json!({ + "id": "chunk3", + "model": "gpt-4o", + "choices": [{ + "index": 0, + "delta": { + "tool_calls": [{ + "index": 0, + "function": { + "arguments": "celsius\"}" + } + }] + }, + "finish_reason": "tool_calls" + }], + "created": 1 + }), + ]; + + let mut stream = BraintrustStream::new(); + for chunk in chunks { + stream.push(chunk); + } + + let finalized = stream.final_value().expect("should finalize"); + let choice = &finalized.output()[0]; + let tool_calls = choice + .message() + .tool_calls() + .expect("should aggregate tool calls"); + + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id(), "call_abc"); + assert_eq!(tool_calls[0].call_type(), "function"); + assert_eq!(tool_calls[0].function().name(), "get_weather"); + assert_eq!( + tool_calls[0].function().arguments(), + "{\"location\":\"NYC\",\"unit\":\"celsius\"}" + ); + assert_eq!(choice.finish_reason(), Some("tool_calls")); + } + + #[test] + fn aggregates_detailed_usage_from_final_chunk() { + let chunks = vec![ + json!({ + "id": "chunk1", + "model": "o4-mini", + "choices": [{ + "index": 0, + "delta": { "role": "assistant", "content": "Reasoned answer" }, + "finish_reason": "stop" + }], + "created": 1 + }), + json!({ + "id": "chunk2", + "model": "o4-mini", + "choices": [], + "created": 1, + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "input_tokens_details": { + "audio_tokens": 3, + "cached_tokens": 20, + "cache_creation_tokens": 5 + }, + "output_tokens_details": { + "audio_tokens": 2, + "reasoning_tokens": 17 + }, + "accepted_prediction_tokens": 11, + "rejected_prediction_tokens": 4 + } + }), + ]; + + let mut stream = BraintrustStream::new(); + for chunk in chunks { + stream.push(chunk); + } + + let finalized = stream.final_value().expect("should finalize"); + let usage = finalized.usage().expect("should have usage"); + + assert_eq!(usage.prompt_tokens(), Some(100)); + assert_eq!(usage.completion_tokens(), Some(50)); + assert_eq!(usage.total_tokens(), Some(150)); + assert_eq!(usage.completion_reasoning_tokens(), Some(17)); + assert_eq!(usage.reasoning_tokens(), None); + + let prompt_details = usage + .prompt_tokens_details() + .expect("should have prompt token details"); + assert_eq!(prompt_details.audio_tokens(), Some(3)); + assert_eq!(prompt_details.cached_tokens(), Some(20)); + assert_eq!(prompt_details.cache_creation_tokens(), Some(5)); + + let completion_details = usage + .completion_tokens_details() + .expect("should have completion token details"); + assert_eq!(completion_details.audio_tokens(), Some(2)); + assert_eq!(completion_details.reasoning_tokens(), Some(17)); + assert_eq!(completion_details.accepted_prediction_tokens(), Some(11)); + assert_eq!(completion_details.rejected_prediction_tokens(), Some(4)); + } + #[test] fn skips_keep_alive_markers() { let mut stream = BraintrustStream::new(); From ae0bfc79896b74809f2de3c04c62614d856be963 Mon Sep 17 00:00:00 2001 From: Stephen Belanger Date: Wed, 29 Apr 2026 12:15:35 -0700 Subject: [PATCH 2/2] Format streaming aggregation changes --- src/stream.rs | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index cfed8db..615de93 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -747,9 +747,17 @@ struct StreamUsage { accepted_prediction_tokens: Option, #[serde(default, alias = "rejected_prediction_tokens")] rejected_prediction_tokens: Option, - #[serde(default, alias = "prompt_tokens_details", alias = "input_tokens_details")] + #[serde( + default, + alias = "prompt_tokens_details", + alias = "input_tokens_details" + )] prompt_tokens_details: Option, - #[serde(default, alias = "completion_tokens_details", alias = "output_tokens_details")] + #[serde( + default, + alias = "completion_tokens_details", + alias = "output_tokens_details" + )] completion_tokens_details: Option, } @@ -772,7 +780,11 @@ struct PromptTokensDetails { struct CompletionTokensDetails { #[serde(default, alias = "audio_tokens", alias = "completion_audio_tokens")] pub(crate) audio_tokens: Option, - #[serde(default, alias = "reasoning_tokens", alias = "completion_reasoning_tokens")] + #[serde( + default, + alias = "reasoning_tokens", + alias = "completion_reasoning_tokens" + )] pub(crate) reasoning_tokens: Option, } @@ -866,8 +878,10 @@ impl BraintrustStream { crate::types::CompletionTokensDetails::new( c.audio_tokens.and_then(|v| u32::try_from(v).ok()), c.reasoning_tokens.and_then(|v| u32::try_from(v).ok()), - u.accepted_prediction_tokens.and_then(|v| u32::try_from(v).ok()), - u.rejected_prediction_tokens.and_then(|v| u32::try_from(v).ok()), + u.accepted_prediction_tokens + .and_then(|v| u32::try_from(v).ok()), + u.rejected_prediction_tokens + .and_then(|v| u32::try_from(v).ok()), ) });