Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 107 additions & 12 deletions tensorzero-core/src/providers/amux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,37 @@ fn parse_urls(body: &Value) -> Vec<String> {
Vec::new()
}

/// Terminal classification of a single poll response.
enum PollOutcome {
Done,
Failed(String),
Pending,
}

/// Classify an Amux poll response into a terminal/pending state.
///
/// The universal poll endpoint wraps the task under `data`:
/// `{ code, message, data: { status, url, error } }`. We read the status from
/// there, falling back to the root so a flat shape still works. Amux reports
/// terminal success as `"succeeded"` (not `"completed"`) and failure as
/// `"failed"`; both spellings of success/failure are accepted defensively.
fn classify_poll(body: &Value) -> PollOutcome {
let task = body.get("data").unwrap_or(body);
match task.get("status").and_then(Value::as_str).unwrap_or("") {
"succeeded" | "completed" => PollOutcome::Done,
"failed" | "error" => {
let reason = task
.get("error")
.and_then(|err| err.get("message"))
.and_then(Value::as_str)
.unwrap_or("(no reason given)")
.to_string();
PollOutcome::Failed(reason)
}
_ => PollOutcome::Pending,
}
}

async fn poll_async_result(
http_client: &TensorzeroHttpClient,
api_key: &str,
Expand Down Expand Up @@ -325,25 +356,17 @@ async fn poll_async_result(
}));
}

// Universal status values: queued | in_progress | completed | failed |
// unknown.
let status_str = body.get("status").and_then(Value::as_str).unwrap_or("");
match status_str {
"completed" => return Ok(body),
"failed" => {
let reason = body
.get("error")
.and_then(|err| err.get("message"))
.and_then(Value::as_str)
.unwrap_or("(no reason given)");
match classify_poll(&body) {
PollOutcome::Done => return Ok(body),
PollOutcome::Failed(reason) => {
return Err(Error::new(ErrorDetails::InferenceServer {
message: format!("Amux generation failed: {reason}"),
provider_type: PROVIDER_TYPE.to_string(),
raw_request: None,
raw_response: Some(body.to_string()),
}));
}
_ => {}
PollOutcome::Pending => {}
}

tokio::time::sleep(poll_interval).await;
Expand Down Expand Up @@ -402,3 +425,75 @@ async fn post_media_callback(

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;

#[test]
fn classify_poll_detects_nested_succeeded() {
// The real Amux universal poll shape nests status under `data` and
// reports success as "succeeded" — the bug this fixes: the old code
// read root `status` and only matched "completed", so it never
// terminated and timed out after the full async budget.
let body = json!({
"code": "success",
"data": { "status": "succeeded", "url": "https://cdn.amux.ai/x.mp4" }
});
assert!(
matches!(classify_poll(&body), PollOutcome::Done),
"nested data.status=succeeded must classify as Done"
);
}

#[test]
fn classify_poll_detects_nested_failed_with_reason() {
let body = json!({
"code": "success",
"data": { "status": "failed", "error": { "message": "render blew up" } }
});
match classify_poll(&body) {
PollOutcome::Failed(reason) => assert_eq!(
reason, "render blew up",
"failure reason must come from data.error.message"
),
_ => panic!("nested data.status=failed must classify as Failed"),
}
}

#[test]
fn classify_poll_treats_in_progress_as_pending() {
for status in ["queued", "in_progress", "unknown", ""] {
let body = json!({ "data": { "status": status } });
assert!(
matches!(classify_poll(&body), PollOutcome::Pending),
"status={status:?} must keep polling (Pending)"
);
}
}

#[test]
fn classify_poll_tolerates_flat_shape() {
// Defensive fallback: if Amux ever returns a flat (un-nested) body,
// the root-level status is still honored.
let body = json!({ "status": "completed", "url": "https://cdn.amux.ai/x.mp4" });
assert!(
matches!(classify_poll(&body), PollOutcome::Done),
"flat root status=completed must classify as Done"
);
}

#[test]
fn parse_urls_reads_nested_data_url() {
let body = json!({
"code": "success",
"data": { "status": "succeeded", "url": "https://cdn.amux.ai/x.mp4" }
});
assert_eq!(
parse_urls(&body),
vec!["https://cdn.amux.ai/x.mp4".to_string()],
"the completed video URL must be extracted from data.url"
);
}
}
Loading