diff --git a/tensorzero-core/src/providers/amux.rs b/tensorzero-core/src/providers/amux.rs index 4a1cbe094c..ed439b316c 100644 --- a/tensorzero-core/src/providers/amux.rs +++ b/tensorzero-core/src/providers/amux.rs @@ -268,6 +268,37 @@ fn parse_urls(body: &Value) -> Vec { Vec::new() } +/// Terminal classification of a single poll response. +enum PollOutcome { + Done, + Failed(String), + Pending, +} + +/// Classify an Amux poll response into a terminal/pending state. +/// +/// The universal poll endpoint wraps the task under `data`: +/// `{ code, message, data: { status, url, error } }`. We read the status from +/// there, falling back to the root so a flat shape still works. Amux reports +/// terminal success as `"succeeded"` (not `"completed"`) and failure as +/// `"failed"`; both spellings of success/failure are accepted defensively. +fn classify_poll(body: &Value) -> PollOutcome { + let task = body.get("data").unwrap_or(body); + match task.get("status").and_then(Value::as_str).unwrap_or("") { + "succeeded" | "completed" => PollOutcome::Done, + "failed" | "error" => { + let reason = task + .get("error") + .and_then(|err| err.get("message")) + .and_then(Value::as_str) + .unwrap_or("(no reason given)") + .to_string(); + PollOutcome::Failed(reason) + } + _ => PollOutcome::Pending, + } +} + async fn poll_async_result( http_client: &TensorzeroHttpClient, api_key: &str, @@ -325,17 +356,9 @@ async fn poll_async_result( })); } - // Universal status values: queued | in_progress | completed | failed | - // unknown. - let status_str = body.get("status").and_then(Value::as_str).unwrap_or(""); - match status_str { - "completed" => return Ok(body), - "failed" => { - let reason = body - .get("error") - .and_then(|err| err.get("message")) - .and_then(Value::as_str) - .unwrap_or("(no reason given)"); + match classify_poll(&body) { + PollOutcome::Done => return Ok(body), + PollOutcome::Failed(reason) => { return Err(Error::new(ErrorDetails::InferenceServer { message: format!("Amux generation failed: {reason}"), provider_type: PROVIDER_TYPE.to_string(), @@ -343,7 +366,7 @@ async fn poll_async_result( raw_response: Some(body.to_string()), })); } - _ => {} + PollOutcome::Pending => {} } tokio::time::sleep(poll_interval).await; @@ -402,3 +425,75 @@ async fn post_media_callback( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn classify_poll_detects_nested_succeeded() { + // The real Amux universal poll shape nests status under `data` and + // reports success as "succeeded" — the bug this fixes: the old code + // read root `status` and only matched "completed", so it never + // terminated and timed out after the full async budget. + let body = json!({ + "code": "success", + "data": { "status": "succeeded", "url": "https://cdn.amux.ai/x.mp4" } + }); + assert!( + matches!(classify_poll(&body), PollOutcome::Done), + "nested data.status=succeeded must classify as Done" + ); + } + + #[test] + fn classify_poll_detects_nested_failed_with_reason() { + let body = json!({ + "code": "success", + "data": { "status": "failed", "error": { "message": "render blew up" } } + }); + match classify_poll(&body) { + PollOutcome::Failed(reason) => assert_eq!( + reason, "render blew up", + "failure reason must come from data.error.message" + ), + _ => panic!("nested data.status=failed must classify as Failed"), + } + } + + #[test] + fn classify_poll_treats_in_progress_as_pending() { + for status in ["queued", "in_progress", "unknown", ""] { + let body = json!({ "data": { "status": status } }); + assert!( + matches!(classify_poll(&body), PollOutcome::Pending), + "status={status:?} must keep polling (Pending)" + ); + } + } + + #[test] + fn classify_poll_tolerates_flat_shape() { + // Defensive fallback: if Amux ever returns a flat (un-nested) body, + // the root-level status is still honored. + let body = json!({ "status": "completed", "url": "https://cdn.amux.ai/x.mp4" }); + assert!( + matches!(classify_poll(&body), PollOutcome::Done), + "flat root status=completed must classify as Done" + ); + } + + #[test] + fn parse_urls_reads_nested_data_url() { + let body = json!({ + "code": "success", + "data": { "status": "succeeded", "url": "https://cdn.amux.ai/x.mp4" } + }); + assert_eq!( + parse_urls(&body), + vec!["https://cdn.amux.ai/x.mp4".to_string()], + "the completed video URL must be extracted from data.url" + ); + } +}