Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions backend/handlers/video_generation_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def generate_video(
if not resolve_model_path(self.models_dir, self.config.model_download_specs,"checkpoint").exists():
raise RuntimeError("Models not downloaded. Please download the AI models first using the Model Status menu.")

total_steps = 8
from services.fast_video_pipeline.ltx_fast_video_pipeline import total_denoising_steps
total_steps = total_denoising_steps()

self._generation.update_progress("loading_model", 5, 0, total_steps)
t_load_start = time.perf_counter()
Expand Down Expand Up @@ -224,6 +225,13 @@ def generate_video(
height = round(height / 64) * 64
width = round(width / 64) * 64

def _on_denoising_step(current_step: int, denoising_total: int) -> None:
    # Map denoising progress onto the 15%..90% band of the overall bar.
    percent = 15 + int(75 * current_step / denoising_total)
    if current_step < denoising_total:
        # Still denoising: report real step counts for the UI "(x/y)" label.
        self._generation.update_progress("inference", percent, current_step, denoising_total)
    else:
        # Final step reached: decoding begins, step counts no longer apply.
        self._generation.update_progress("decoding", percent, None, None)

t_inference_start = time.perf_counter()
pipeline_state.pipeline.generate(
prompt=enhanced_prompt,
Expand All @@ -234,6 +242,7 @@ def generate_video(
frame_rate=fps,
images=images,
output_path=str(output_path),
progress_callback=_on_denoising_step,
)
t_inference_end = time.perf_counter()
logger.info("[%s] Inference: %.2fs", gen_mode, t_inference_end - t_inference_start)
Expand Down Expand Up @@ -286,6 +295,7 @@ def _generate_a2v(
try:
a2v_state = self._pipelines.load_a2v_pipeline()
self._generation.start_generation(generation_id)
self._generation.update_progress("loading_model", 5, 0, 11)

enhanced_prompt = req.prompt + self.config.camera_motion_prompts.get(req.cameraMotion, "")
neg = req.negativePrompt if req.negativePrompt else self.config.default_negative_prompt
Expand All @@ -306,8 +316,6 @@ def _generate_a2v(
a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_i2v
else:
a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_t2v

self._generation.update_progress("loading_model", 5, 0, total_steps)
self._generation.update_progress("encoding_text", 10, 0, total_steps)
self._text.prepare_text_encoding(enhanced_prompt, enhance_prompt=a2v_enhance)
self._generation.update_progress("inference", 15, 0, total_steps)
Expand Down
2 changes: 2 additions & 0 deletions backend/services/fast_video_pipeline/fast_video_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol

from api_types import ImageConditioningInput
Expand Down Expand Up @@ -32,6 +33,7 @@ def generate(
frame_rate: float,
images: list[ImageConditioningInput],
output_path: str,
progress_callback: Callable[[int, int], None] | None = None,
) -> None:
...

Expand Down
82 changes: 70 additions & 12 deletions backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,60 @@

from __future__ import annotations

from collections.abc import Iterator
from collections.abc import Callable, Iterator
from contextlib import contextmanager
import os
from typing import Final, cast
from typing import Any, Final, cast

import torch

from api_types import ImageConditioningInput
from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number
from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8

def total_denoising_steps() -> int:
    """Return the combined number of denoising steps across both distilled stages.

    Each stage performs one step per sigma transition, i.e. one fewer step
    than it has sigma values. Imported lazily to avoid loading the heavy
    ltx_pipelines package at module import time.
    """
    from ltx_pipelines.utils.constants import (
        DISTILLED_SIGMA_VALUES,
        STAGE_2_DISTILLED_SIGMA_VALUES,
    )

    stage_one_steps = len(DISTILLED_SIGMA_VALUES) - 1
    stage_two_steps = len(STAGE_2_DISTILLED_SIGMA_VALUES) - 1
    return stage_one_steps + stage_two_steps

StepCallback = Callable[[int, int], None] # (current_step, total_steps)


@contextmanager
def _tqdm_progress_interceptor(callback: StepCallback) -> Iterator[None]:
    """Temporarily replace tqdm in ltx_pipelines.utils.samplers to report progress.

    The denoising loops in samplers.py use tqdm directly with no external
    callback hook, so we swap in a thin wrapper whose iteration invokes
    ``callback(current_step, total_steps)`` after each yielded item (i.e.
    after each denoising step completes). The step counter is shared across
    every tqdm instance created while the patch is active, so progress is
    cumulative over both distilled stages.

    NOTE(review): patching a module attribute is process-global; concurrent
    generations would share the counter — assumed single-generation here.
    """
    import ltx_pipelines.utils.samplers as _samplers_module

    # Mutable holder so the nested class can advance a counter shared by all
    # wrapper instances (the stage-2 bar continues where stage 1 left off).
    _step_counter: list[int] = [0]
    total_steps = total_denoising_steps()

    original_tqdm = _samplers_module.tqdm

    class _ProgressTqdm:
        def __init__(self, iterable: Any = None, **kwargs: Any) -> None:
            # Materialize up front so __len__ works even for generator inputs.
            self._items = list(iterable) if iterable is not None else []
            self._tqdm = original_tqdm(self._items, **kwargs)

        def __iter__(self) -> Iterator[Any]:
            for item in self._tqdm:
                yield item
                _step_counter[0] += 1
                callback(_step_counter[0], total_steps)

        def __len__(self) -> int:
            return len(self._items)

        def __getattr__(self, name: str) -> Any:
            # Forward everything else (update, set_description, close, ...)
            # to the real tqdm so sampler code using its API keeps working.
            return getattr(self._tqdm, name)

    try:
        _samplers_module.tqdm = _ProgressTqdm  # type: ignore[attr-defined]
        yield
    finally:
        # Always restore the original, even if generation raised mid-run.
        _samplers_module.tqdm = original_tqdm  # type: ignore[attr-defined]


class LTXFastVideoPipeline:
pipeline_kind: Final = "fast"
Expand Down Expand Up @@ -85,18 +129,32 @@ def generate(
frame_rate: float,
images: list[ImageConditioningInput],
output_path: str,
progress_callback: StepCallback | None = None,
) -> None:
tiling_config = default_tiling_config()
video, audio = self._run_inference(
prompt=prompt,
seed=seed,
height=height,
width=width,
num_frames=num_frames,
frame_rate=frame_rate,
images=images,
tiling_config=tiling_config,
)
if progress_callback is not None:
with _tqdm_progress_interceptor(progress_callback):
video, audio = self._run_inference(
prompt=prompt,
seed=seed,
height=height,
width=width,
num_frames=num_frames,
frame_rate=frame_rate,
images=images,
tiling_config=tiling_config,
)
else:
video, audio = self._run_inference(
prompt=prompt,
seed=seed,
height=height,
width=width,
num_frames=num_frames,
frame_rate=frame_rate,
images=images,
tiling_config=tiling_config,
)
chunks = video_chunks_number(num_frames, tiling_config)
encode_video_output(video=video, audio=audio, fps=int(frame_rate), output_path=output_path, video_chunks_number_value=chunks)

Expand Down
1 change: 1 addition & 0 deletions backend/tests/fakes/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def generate(
frame_rate: float,
images: list[ImageConditioningInput],
output_path: str,
progress_callback: Callable[[int, int], None] | None = None,
) -> None:
self._record_generate(
{
Expand Down
44 changes: 25 additions & 19 deletions frontend/hooks/use-generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ function getImageDimensions(settings: GenerationSettings): { width: number; heig
}

// Map phase to user-friendly message
function getPhaseMessage(phase: string): string {
function getPhaseMessage(phase: string, currentStep?: number | null, totalSteps?: number | null): string {
const steps = currentStep != null && totalSteps != null && totalSteps > 0
? ` (${currentStep}/${totalSteps})`
: ''
switch (phase) {
case 'validating_request':
return 'Validating request...'
Expand All @@ -69,7 +72,7 @@ function getPhaseMessage(phase: string): string {
case 'encoding_text':
return 'Encoding prompt...'
case 'inference':
return 'Generating...'
return `Generating...${steps}`
case 'downloading_output':
return 'Downloading output...'
case 'decoding':
Expand Down Expand Up @@ -101,9 +104,7 @@ export function useGeneration(): UseGenerationReturn {
settings: GenerationSettings,
audioPath?: string | null,
) => {
const statusMsg = settings.model === 'pro'
? 'Loading Pro model & generating...'
: 'Generating video...'
const statusMsg = 'Loading model...'

setState({
isGenerating: true,
Expand Down Expand Up @@ -141,30 +142,35 @@ export function useGeneration(): UseGenerationReturn {

// Poll for real progress from backend with time-based interpolation
let lastPhase = ''
let inferenceStartTime = 0
// Estimated inference time in seconds based on model
const estimatedInferenceTime = settings.model === 'pro' ? 120 : 45
let loadingModelStartTime = 0
// Estimated loading time in seconds based on model
const estimatedLoadingTime = settings.model === 'pro' ? 60 : 30

const pollProgress = async () => {
if (!shouldApplyPollingUpdates) return
try {
const data = await ApiClient.getGenerationProgress()
if (!shouldApplyPollingUpdates) return

// Ignore idle responses (phase="" means backend hasn't started yet)
if (!data.phase || data.status === 'idle') return

let displayProgress = data.progress
let statusMessage = getPhaseMessage(data.phase)
// Time-based interpolation during inference phase
if (data.phase === 'inference') {
if (lastPhase !== 'inference') {
inferenceStartTime = Date.now()
let statusMessage = getPhaseMessage(data.phase, data.currentStep, data.totalSteps)

// Time-based interpolation during loading_model phase
if (data.phase === 'loading_model') {
if (lastPhase !== 'loading_model') {
loadingModelStartTime = Date.now()
}
const elapsed = (Date.now() - inferenceStartTime) / 1000
// Interpolate from 15% to 95% based on estimated time
const inferenceProgress = Math.min(elapsed / estimatedInferenceTime, 0.95)
displayProgress = 15 + Math.floor(inferenceProgress * 80)
const elapsed = (Date.now() - loadingModelStartTime) / 1000
// Interpolate from 5% to 12% based on estimated loading time
const loadingProgress = Math.min(elapsed / estimatedLoadingTime, 0.95)
displayProgress = 5 + Math.floor(loadingProgress * 7)
}

// inference: use real progress from backend (set by _on_denoising_step callback)

// Keep API/local completion as a terminal response state, not polling state.
// Polling complete means backend state is finalized, but request can still be in-flight.
if (data.phase === 'complete' || data.status === 'complete') {
Expand Down