From 7fdaa788f6ac6e722b88f30120d9f6fd642168b7 Mon Sep 17 00:00:00 2001 From: Sergio Gil Date: Mon, 13 Apr 2026 18:58:13 +0200 Subject: [PATCH 1/2] feat(progress): add real-time denoising step progress during inference --- backend/handlers/video_generation_handler.py | 8 +- .../fast_video_pipeline.py | 2 + .../ltx_fast_video_pipeline.py | 81 ++++++++++++++++--- backend/tests/fakes/services.py | 1 + frontend/hooks/use-generation.ts | 44 +++++----- 5 files changed, 103 insertions(+), 33 deletions(-) diff --git a/backend/handlers/video_generation_handler.py b/backend/handlers/video_generation_handler.py index a0ce15b3a..788d91e4c 100644 --- a/backend/handlers/video_generation_handler.py +++ b/backend/handlers/video_generation_handler.py @@ -224,6 +224,10 @@ def generate_video( height = round(height / 64) * 64 width = round(width / 64) * 64 + def _on_denoising_step(current_step: int, denoising_total: int) -> None: + pct = 15 + int(75 * current_step / denoising_total) + self._generation.update_progress("inference", pct, current_step, denoising_total) + t_inference_start = time.perf_counter() pipeline_state.pipeline.generate( prompt=enhanced_prompt, @@ -234,6 +238,7 @@ def generate_video( frame_rate=fps, images=images, output_path=str(output_path), + progress_callback=_on_denoising_step, ) t_inference_end = time.perf_counter() logger.info("[%s] Inference: %.2fs", gen_mode, t_inference_end - t_inference_start) @@ -286,6 +291,7 @@ def _generate_a2v( try: a2v_state = self._pipelines.load_a2v_pipeline() self._generation.start_generation(generation_id) + self._generation.update_progress("loading_model", 5, 0, 11) enhanced_prompt = req.prompt + self.config.camera_motion_prompts.get(req.cameraMotion, "") neg = req.negativePrompt if req.negativePrompt else self.config.default_negative_prompt @@ -306,8 +312,6 @@ def _generate_a2v( a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_i2v else: a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_t2v - - self._generation.update_progress("loading_model", 5, 0, total_steps) self._generation.update_progress("encoding_text", 10, 0, total_steps) self._text.prepare_text_encoding(enhanced_prompt, enhance_prompt=a2v_enhance) self._generation.update_progress("inference", 15, 0, total_steps) diff --git a/backend/services/fast_video_pipeline/fast_video_pipeline.py b/backend/services/fast_video_pipeline/fast_video_pipeline.py index e415524a4..aab82185e 100644 --- a/backend/services/fast_video_pipeline/fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/fast_video_pipeline.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING, ClassVar, Literal, Protocol from api_types import ImageConditioningInput @@ -32,6 +33,7 @@ def generate( frame_rate: float, images: list[ImageConditioningInput], output_path: str, + progress_callback: Callable[[int, int], None] | None = None, ) -> None: ... diff --git a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py index ccdaa24ee..78e1d3f56 100644 --- a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py @@ -2,9 +2,10 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Callable, Iterator +from contextlib import contextmanager import os -from typing import Final, cast +from typing import Any, Final, cast import torch @@ -12,6 +13,48 @@ from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8 +# Stage 1: 8 denoising steps, Stage 2: 3 denoising steps. +_STAGE1_STEPS = 8 +_STAGE2_STEPS = 3 +_TOTAL_DENOISING_STEPS = _STAGE1_STEPS + _STAGE2_STEPS + +StepCallback = Callable[[int, int], None] # (current_step, total_steps) + + +@contextmanager +def _tqdm_progress_interceptor(callback: StepCallback) -> Iterator[None]: + """Patch tqdm in ltx_pipelines.utils.samplers to forward step updates to callback. + + The denoising loops in samplers.py use tqdm directly with no external + callback hook. We replace tqdm there with a thin wrapper that calls + callback(current_step, total_steps) on each update() call. + """ + import ltx_pipelines.utils.samplers as _samplers_module + + _step_counter: list[int] = [0] + + original_tqdm = _samplers_module.tqdm + + class _ProgressTqdm: + def __init__(self, iterable: Any = None, **kwargs: Any) -> None: + self._items = list(iterable) if iterable is not None else [] + self._tqdm = original_tqdm(self._items, **kwargs) + + def __iter__(self) -> Iterator[Any]: + for item in self._tqdm: + yield item + _step_counter[0] += 1 + callback(_step_counter[0], _TOTAL_DENOISING_STEPS) + + def __len__(self) -> int: + return len(self._items) + + try: + _samplers_module.tqdm = _ProgressTqdm # type: ignore[attr-defined] + yield + finally: + _samplers_module.tqdm = original_tqdm # type: ignore[attr-defined] + class LTXFastVideoPipeline: pipeline_kind: Final = "fast" @@ -85,18 +128,32 @@ def generate( frame_rate: float, images: list[ImageConditioningInput], output_path: str, + progress_callback: StepCallback | None = None, ) -> None: tiling_config = default_tiling_config() - video, audio = self._run_inference( - prompt=prompt, - seed=seed, - height=height, - width=width, - num_frames=num_frames, - frame_rate=frame_rate, - images=images, - tiling_config=tiling_config, - ) + if progress_callback is not None: + with _tqdm_progress_interceptor(progress_callback): + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) + else: + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) chunks = video_chunks_number(num_frames, tiling_config) encode_video_output(video=video, audio=audio, fps=int(frame_rate), output_path=output_path, video_chunks_number_value=chunks) diff --git a/backend/tests/fakes/services.py b/backend/tests/fakes/services.py index 12700875a..4c440cfd5 100644 --- a/backend/tests/fakes/services.py +++ b/backend/tests/fakes/services.py @@ -507,6 +507,7 @@ def generate( frame_rate: float, images: list[ImageConditioningInput], output_path: str, + progress_callback: Callable[[int, int], None] | None = None, ) -> None: self._record_generate( { diff --git a/frontend/hooks/use-generation.ts b/frontend/hooks/use-generation.ts index ea5889d88..dbac33fea 100644 --- a/frontend/hooks/use-generation.ts +++ b/frontend/hooks/use-generation.ts @@ -56,7 +56,10 @@ function getImageDimensions(settings: GenerationSettings): { width: number; heig } // Map phase to user-friendly message -function getPhaseMessage(phase: string): string { +function getPhaseMessage(phase: string, currentStep?: number | null, totalSteps?: number | null): string { + const steps = currentStep != null && totalSteps != null && totalSteps > 0 + ? ` (${currentStep}/${totalSteps})` + : '' switch (phase) { case 'validating_request': return 'Validating request...' @@ -69,7 +72,7 @@ function getPhaseMessage(phase: string): string { case 'encoding_text': return 'Encoding prompt...' case 'inference': - return 'Generating...' + return `Generating...${steps}` case 'downloading_output': return 'Downloading output...' case 'decoding': @@ -101,9 +104,7 @@ export function useGeneration(): UseGenerationReturn { settings: GenerationSettings, audioPath?: string | null, ) => { - const statusMsg = settings.model === 'pro' - ? 'Loading Pro model & generating...' - : 'Generating video...' + const statusMsg = 'Loading model...' setState({ isGenerating: true, @@ -141,30 +142,35 @@ export function useGeneration(): UseGenerationReturn { // Poll for real progress from backend with time-based interpolation let lastPhase = '' - let inferenceStartTime = 0 - // Estimated inference time in seconds based on model - const estimatedInferenceTime = settings.model === 'pro' ? 120 : 45 - + let loadingModelStartTime = 0 + // Estimated loading time in seconds based on model + const estimatedLoadingTime = settings.model === 'pro' ? 60 : 30 + const pollProgress = async () => { if (!shouldApplyPollingUpdates) return try { const data = await ApiClient.getGenerationProgress() if (!shouldApplyPollingUpdates) return + // Ignore idle responses (phase="" means backend hasn't started yet) + if (!data.phase || data.status === 'idle') return + let displayProgress = data.progress - let statusMessage = getPhaseMessage(data.phase) - - // Time-based interpolation during inference phase - if (data.phase === 'inference') { - if (lastPhase !== 'inference') { - inferenceStartTime = Date.now() + let statusMessage = getPhaseMessage(data.phase, data.currentStep, data.totalSteps) + + // Time-based interpolation during loading_model phase + if (data.phase === 'loading_model') { + if (lastPhase !== 'loading_model') { + loadingModelStartTime = Date.now() } - const elapsed = (Date.now() - inferenceStartTime) / 1000 - // Interpolate from 15% to 95% based on estimated time - const inferenceProgress = Math.min(elapsed / estimatedInferenceTime, 0.95) - displayProgress = 15 + Math.floor(inferenceProgress * 80) + const elapsed = (Date.now() - loadingModelStartTime) / 1000 + // Interpolate from 5% to 12% based on estimated loading time + const loadingProgress = Math.min(elapsed / estimatedLoadingTime, 0.95) + displayProgress = 5 + Math.floor(loadingProgress * 7) } + // inference: use real progress from backend (set by _on_denoising_step callback) + // Keep API/local completion as a terminal response state, not polling state. // Polling complete means backend state is finalized, but request can still be in-flight. if (data.phase === 'complete' || data.status === 'complete') { From c43f99cecd1d452be19f444f6569b814a57b11bb Mon Sep 17 00:00:00 2001 From: Sergio Gil Date: Mon, 13 Apr 2026 20:25:37 +0200 Subject: [PATCH 2/2] fix(progress): dynamic step count and decoding phase detection --- backend/handlers/video_generation_handler.py | 8 ++++++-- .../fast_video_pipeline/ltx_fast_video_pipeline.py | 11 ++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/backend/handlers/video_generation_handler.py b/backend/handlers/video_generation_handler.py index 788d91e4c..958fee08f 100644 --- a/backend/handlers/video_generation_handler.py +++ b/backend/handlers/video_generation_handler.py @@ -184,7 +184,8 @@ def generate_video( if not resolve_model_path(self.models_dir, self.config.model_download_specs,"checkpoint").exists(): raise RuntimeError("Models not downloaded. Please download the AI models first using the Model Status menu.") - total_steps = 8 + from services.fast_video_pipeline.ltx_fast_video_pipeline import total_denoising_steps + total_steps = total_denoising_steps() self._generation.update_progress("loading_model", 5, 0, total_steps) t_load_start = time.perf_counter() @@ -226,7 +227,10 @@ def generate_video( def _on_denoising_step(current_step: int, denoising_total: int) -> None: pct = 15 + int(75 * current_step / denoising_total) - self._generation.update_progress("inference", pct, current_step, denoising_total) + if current_step >= denoising_total: + self._generation.update_progress("decoding", pct, None, None) + else: + self._generation.update_progress("inference", pct, current_step, denoising_total) t_inference_start = time.perf_counter() pipeline_state.pipeline.generate( diff --git a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py index 78e1d3f56..5265ba228 100644 --- a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py @@ -13,10 +13,10 @@ from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8 -# Stage 1: 8 denoising steps, Stage 2: 3 denoising steps. -_STAGE1_STEPS = 8 -_STAGE2_STEPS = 3 -_TOTAL_DENOISING_STEPS = _STAGE1_STEPS + _STAGE2_STEPS +def total_denoising_steps() -> int: + from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES + + return (len(DISTILLED_SIGMA_VALUES) - 1) + (len(STAGE_2_DISTILLED_SIGMA_VALUES) - 1) StepCallback = Callable[[int, int], None] # (current_step, total_steps) @@ -32,6 +32,7 @@ def _tqdm_progress_interceptor(callback: StepCallback) -> Iterator[None]: import ltx_pipelines.utils.samplers as _samplers_module _step_counter: list[int] = [0] + total_steps = total_denoising_steps() original_tqdm = _samplers_module.tqdm @@ -44,7 +45,7 @@ def __iter__(self) -> Iterator[Any]: for item in self._tqdm: yield item _step_counter[0] += 1 - callback(_step_counter[0], _TOTAL_DENOISING_STEPS) + callback(_step_counter[0], total_steps) def __len__(self) -> int: return len(self._items)