From 9008933c53ccc73cd5a526facceabc3d54d8cc9a Mon Sep 17 00:00:00 2001 From: Sathvik-1007 Date: Mon, 18 May 2026 01:52:59 +0530 Subject: [PATCH] fix: add provider/model failover to streaming LLM calls stream_chat_with_system only tried the first streaming-capable provider with the first model. Transient errors propagated immediately while non-streaming methods had full retry + failover. Now iterates all provider+model candidates with exponential backoff between transient failures, matching non-streaming reliability behavior. Closes #1931 --- src/openhuman/inference/provider/reliable.rs | 193 ++++++++++++++----- 1 file changed, 141 insertions(+), 52 deletions(-) diff --git a/src/openhuman/inference/provider/reliable.rs b/src/openhuman/inference/provider/reliable.rs index b2d8b4f587..e311b0e47e 100644 --- a/src/openhuman/inference/provider/reliable.rs +++ b/src/openhuman/inference/provider/reliable.rs @@ -1,5 +1,5 @@ use super::traits::{ - ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamOptions, StreamResult, + ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamError, StreamOptions, StreamResult, }; use super::Provider; use async_trait::async_trait; @@ -59,6 +59,32 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { || msg_lower.contains("invalid")) } +/// Classify a StreamError without losing type information. +/// Inspects the inner reqwest::Error status directly for Http variants. +fn is_stream_error_non_retryable(err: &StreamError) -> bool { + match err { + StreamError::Http(reqwest_err) => { + if let Some(status) = reqwest_err.status() { + let code = status.as_u16(); + // Client errors except 429 (rate limit) and 408 (timeout) are non-retryable + return status.is_client_error() && code != 429 && code != 408; + } + false + } + StreamError::Provider(msg) => { + let lower = msg.to_lowercase(); + lower.contains("invalid api key") + || lower.contains("unauthorized") + || lower.contains("forbidden") + || lower.contains("model") + && (lower.contains("not found") || lower.contains("unsupported")) + } + // JSON/SSE parse errors and IO errors are generally non-retryable + StreamError::Json(_) | StreamError::InvalidSse(_) => true, + StreamError::Io(_) => false, + } +} + fn is_context_window_exceeded(err: &anyhow::Error) -> bool { let lower = err.to_string().to_lowercase(); let hints = [ @@ -924,63 +950,126 @@ impl Provider for ReliableProvider { temperature: f64, options: StreamOptions, ) -> stream::BoxStream<'static, StreamResult> { - // Try each provider/model combination for streaming - // For streaming, we use the first provider that supports it and has streaming enabled - for (provider_name, provider) in &self.providers { - if !provider.supports_streaming() || !options.enabled { - continue; + if !options.enabled { + return stream::once(async move { + Err(super::traits::StreamError::Provider( + "Streaming disabled".to_string(), + )) + }) + .boxed(); + } + + // Collect streaming-capable providers + let streaming_providers: Vec<_> = self + .providers + .iter() + .filter(|(_, p)| p.supports_streaming()) + .collect(); + + if streaming_providers.is_empty() { + return stream::once(async move { + Err(super::traits::StreamError::Provider( + "No provider supports streaming".to_string(), + )) + }) + .boxed(); + } + + // Build model chain and provider info for the spawned task + let models = self.model_chain(model); + let model_chain: Vec = models.into_iter().map(|m| m.to_string()).collect(); + let base_backoff_ms = self.base_backoff_ms; + + // Collect provider streams lazily inside the task — we need owned data + // Provider trait is object-safe, so we call stream_chat_with_system per attempt + // We need to pre-create all possible streams since Provider is behind &self + // Instead, collect the streams for each provider+model combo upfront + let mut candidate_streams: Vec<( + String, + String, + stream::BoxStream<'static, StreamResult>, + )> = Vec::new(); + for current_model in &model_chain { + for (provider_name, provider) in &streaming_providers { + let s = provider.stream_chat_with_system( + system_prompt, + message, + current_model, + temperature, + options, + ); + candidate_streams.push(((*provider_name).clone(), current_model.clone(), s)); } + } - // Clone provider data for the stream - let provider_clone = provider_name.clone(); - - // Try the first model in the chain for streaming - let current_model = match self.model_chain(model).first() { - Some(m) => m.to_string(), - None => model.to_string(), - }; - - // For streaming, we attempt once and propagate errors - // The caller can retry the entire request if needed - let stream = provider.stream_chat_with_system( - system_prompt, - message, - ¤t_model, - temperature, - options, - ); - - // Use a channel to bridge the stream with logging - let (tx, rx) = tokio::sync::mpsc::channel::>(100); - - tokio::spawn(async move { - let mut stream = stream; - while let Some(chunk) = stream.next().await { - if let Err(ref e) = chunk { - tracing::warn!( - provider = provider_clone, - model = current_model, - "Streaming error: {e}" - ); - } - if tx.send(chunk).await.is_err() { - break; // Receiver dropped + let (tx, rx) = tokio::sync::mpsc::channel::>(100); + let max_retries = self.max_retries; + + tokio::spawn(async move { + for (provider_name, current_model, mut candidate_stream) in candidate_streams { + let mut backoff_ms = base_backoff_ms; + let mut attempts = 0u32; + + loop { + match candidate_stream.next().await { + Some(Ok(chunk)) => { + // First chunk succeeded — commit to this stream + if tx.send(Ok(chunk)).await.is_err() { + return; + } + // Forward remaining chunks + while let Some(chunk) = candidate_stream.next().await { + if tx.send(chunk).await.is_err() { + return; + } + } + return; // Done successfully + } + Some(Err(ref e)) => { + let non_retryable = is_stream_error_non_retryable(e); + + tracing::warn!( + provider = provider_name, + model = current_model, + attempt = attempts + 1, + error = %e, + "Streaming failed{}", if non_retryable { " (non-retryable)" } else { "" } + ); + + if non_retryable || attempts >= max_retries { + break; // Move to next candidate + } + + attempts += 1; + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + // Continue inner loop — stream may yield more items + } + None => { + // Stream exhausted without success + if attempts == 0 { + tracing::warn!( + provider = provider_name, + model = current_model, + "Stream returned empty" + ); + } + break; // Move to next candidate + } } } - }); + } - // Convert channel receiver to stream - return stream::unfold(rx, |mut rx| async move { - rx.recv().await.map(|chunk| (chunk, rx)) - }) - .boxed(); - } + // All providers/models exhausted + let _ = tx + .send(Err(super::traits::StreamError::Provider( + "All streaming providers/models failed".to_string(), + ))) + .await; + }); - // No streaming support available - stream::once(async move { - Err(super::traits::StreamError::Provider( - "No provider supports streaming".to_string(), - )) + stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|chunk| (chunk, rx)) }) .boxed() }