Skip to content

Commit 2fe7a8e

Browse files
committed
feat(novita): add Sora 2 (T2V/I2V) + Pro request shapes
Wires Novita's `/v3/async/sora-2-text2video` and `/v3/async/sora-2-img2video` endpoints. Sora 2's `professional` body flag toggles the Pro tier (1024×1792 / 1792×1024 sizes, 1080p resolution); rather than expose it to callers, we route the Pro/non-Pro distinction through separate enum variants and force the body field server-side. This keeps the user-facing model identity stable: the operator picks Pro by routing to the matching variant, not by passing an extra param.

Allowed body keys per shape (per Novita docs):
- Sora2TextToVideo / Sora2ProTextToVideo: `size`, `duration`
- Sora2ImageToVideo / Sora2ProImageToVideo: `image`, `resolution`, `duration`. `image` is a single field accepting either a URL or a Base64 string — no URL/Base64 split like Veo's I2V shape.

`prompt` is required for all four (validated upstream). All four hit `/v3/async/...` endpoints, so they ride the existing `async_submission = true` polling path.
1 parent 8684610 commit 2fe7a8e

1 file changed

Lines changed: 54 additions & 10 deletions

File tree

tensorzero-core/src/providers/novita.rs

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use schemars::JsonSchema;
33
use secrecy::{ExposeSecret, SecretString};
44
use serde::{Deserialize, Serialize};
55
use serde_json::{Value, json};
6-
use std::time::Duration;
76
use std::sync::Arc;
7+
use std::time::Duration;
88
use tokio::time::Instant;
99
use url::Url;
1010
use uuid::Uuid;
@@ -46,6 +46,17 @@ pub enum NovitaRequestShape {
4646
VeoTextToVideo,
4747
VeoImageToVideo,
4848
Veo31ImageToVideo,
49+
/// OpenAI Sora 2, text-to-video (basic). `professional` is forced
50+
/// to `false` server-side; the operator picks the Pro tier by
51+
/// routing to the `Sora2ProTextToVideo` variant instead.
52+
Sora2TextToVideo,
53+
/// OpenAI Sora 2, text-to-video (Pro). Same Novita endpoint as
54+
/// `Sora2TextToVideo` but `professional` is forced to `true`.
55+
Sora2ProTextToVideo,
56+
/// OpenAI Sora 2, image-to-video (basic).
57+
Sora2ImageToVideo,
58+
/// OpenAI Sora 2, image-to-video (Pro).
59+
Sora2ProImageToVideo,
4960
}
5061

5162
impl NovitaProvider {
@@ -240,6 +251,20 @@ fn build_body(shape: &NovitaRequestShape, input: &Value) -> Result<Value, Error>
240251
"sample_count",
241252
"seed",
242253
],
254+
// Sora 2 text-to-video. Per Novita's
255+
// `/v3/async/sora-2-text2video` doc: prompt (auto), size,
256+
// duration. `professional` is set explicitly below based on
257+
// the shape (basic vs Pro).
258+
NovitaRequestShape::Sora2TextToVideo | NovitaRequestShape::Sora2ProTextToVideo => {
259+
&["size", "duration"]
260+
}
261+
// Sora 2 image-to-video. Per `/v3/async/sora-2-img2video`:
262+
// prompt (auto), image (URL or Base64 string — passed
263+
// through as a single `image` field, no URL/Base64 split),
264+
// resolution, duration.
265+
NovitaRequestShape::Sora2ImageToVideo | NovitaRequestShape::Sora2ProImageToVideo => {
266+
&["image", "resolution", "duration"]
267+
}
243268
};
244269

245270
if let Some(input_obj) = input.as_object() {
@@ -287,6 +312,24 @@ fn build_body(shape: &NovitaRequestShape, input: &Value) -> Result<Value, Error>
287312
}
288313
}
289314

315+
// Sora 2 routes the Pro/non-Pro distinction through a single Novita
316+
// endpoint with a `professional` body field. Force the value
317+
// server-side so the user can't accidentally upgrade to Pro by
318+
// passing `professional: true` to the basic variant.
319+
if matches!(
320+
shape,
321+
NovitaRequestShape::Sora2TextToVideo
322+
| NovitaRequestShape::Sora2ProTextToVideo
323+
| NovitaRequestShape::Sora2ImageToVideo
324+
| NovitaRequestShape::Sora2ProImageToVideo
325+
) {
326+
let pro = matches!(
327+
shape,
328+
NovitaRequestShape::Sora2ProTextToVideo | NovitaRequestShape::Sora2ProImageToVideo
329+
);
330+
body.insert("professional".into(), Value::Bool(pro));
331+
}
332+
290333
if matches!(shape, NovitaRequestShape::GptImageEdit) && !body.contains_key("image") {
291334
if let Some(first) = input
292335
.get("image_urls")
@@ -371,14 +414,12 @@ fn parse_urls(body: &Value) -> Vec<String> {
371414
let urls: Vec<String> = arr
372415
.iter()
373416
.filter_map(|item| {
374-
item.as_str()
375-
.map(ToString::to_string)
376-
.or_else(|| {
377-
item.get("video_url")
378-
.or_else(|| item.get("url"))
379-
.and_then(Value::as_str)
380-
.map(ToString::to_string)
381-
})
417+
item.as_str().map(ToString::to_string).or_else(|| {
418+
item.get("video_url")
419+
.or_else(|| item.get("url"))
420+
.and_then(Value::as_str)
421+
.map(ToString::to_string)
422+
})
382423
})
383424
.collect();
384425
if !urls.is_empty() {
@@ -396,7 +437,10 @@ async fn poll_async_result(
396437
api_key: &str,
397438
task_id: &str,
398439
) -> Result<Value, Error> {
399-
let url = format!("{}/v3/async/task-result?task_id={task_id}", *NOVITA_API_BASE);
440+
let url = format!(
441+
"{}/v3/async/task-result?task_id={task_id}",
442+
*NOVITA_API_BASE
443+
);
400444
let deadline = Instant::now() + REQUEST_TIMEOUT;
401445
let poll_interval = Duration::from_secs(4);
402446

0 commit comments

Comments
 (0)