diff --git a/backend/handlers/video_generation_handler.py b/backend/handlers/video_generation_handler.py index a0ce15b3a..958fee08f 100644 --- a/backend/handlers/video_generation_handler.py +++ b/backend/handlers/video_generation_handler.py @@ -184,7 +184,8 @@ def generate_video( if not resolve_model_path(self.models_dir, self.config.model_download_specs,"checkpoint").exists(): raise RuntimeError("Models not downloaded. Please download the AI models first using the Model Status menu.") - total_steps = 8 + from services.fast_video_pipeline.ltx_fast_video_pipeline import total_denoising_steps + total_steps = total_denoising_steps() self._generation.update_progress("loading_model", 5, 0, total_steps) t_load_start = time.perf_counter() @@ -224,6 +225,13 @@ def generate_video( height = round(height / 64) * 64 width = round(width / 64) * 64 + def _on_denoising_step(current_step: int, denoising_total: int) -> None: + pct = 15 + int(75 * current_step / denoising_total) + if current_step >= denoising_total: + self._generation.update_progress("decoding", pct, None, None) + else: + self._generation.update_progress("inference", pct, current_step, denoising_total) + t_inference_start = time.perf_counter() pipeline_state.pipeline.generate( prompt=enhanced_prompt, @@ -234,6 +242,7 @@ def generate_video( frame_rate=fps, images=images, output_path=str(output_path), + progress_callback=_on_denoising_step, ) t_inference_end = time.perf_counter() logger.info("[%s] Inference: %.2fs", gen_mode, t_inference_end - t_inference_start) @@ -286,6 +295,7 @@ def _generate_a2v( try: a2v_state = self._pipelines.load_a2v_pipeline() self._generation.start_generation(generation_id) + self._generation.update_progress("loading_model", 5, 0, 11) enhanced_prompt = req.prompt + self.config.camera_motion_prompts.get(req.cameraMotion, "") neg = req.negativePrompt if req.negativePrompt else self.config.default_negative_prompt @@ -306,8 +316,6 @@ def _generate_a2v( a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_i2v else: a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_t2v - - self._generation.update_progress("loading_model", 5, 0, total_steps) self._generation.update_progress("encoding_text", 10, 0, total_steps) self._text.prepare_text_encoding(enhanced_prompt, enhance_prompt=a2v_enhance) self._generation.update_progress("inference", 15, 0, total_steps) diff --git a/backend/services/fast_video_pipeline/fast_video_pipeline.py b/backend/services/fast_video_pipeline/fast_video_pipeline.py index e415524a4..aab82185e 100644 --- a/backend/services/fast_video_pipeline/fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/fast_video_pipeline.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING, ClassVar, Literal, Protocol from api_types import ImageConditioningInput @@ -32,6 +33,7 @@ def generate( frame_rate: float, images: list[ImageConditioningInput], output_path: str, + progress_callback: Callable[[int, int], None] | None = None, ) -> None: ... diff --git a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py index ccdaa24ee..5265ba228 100644 --- a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py @@ -2,9 +2,10 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Callable, Iterator +from contextlib import contextmanager import os -from typing import Final, cast +from typing import Any, Final, cast import torch @@ -12,6 +13,49 @@ from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8 +def total_denoising_steps() -> int: + from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES + + return (len(DISTILLED_SIGMA_VALUES) - 1) + (len(STAGE_2_DISTILLED_SIGMA_VALUES) - 1) + +StepCallback = Callable[[int, int], None] # (current_step, total_steps) + + +@contextmanager +def _tqdm_progress_interceptor(callback: StepCallback) -> Iterator[None]: + """Patch tqdm in ltx_pipelines.utils.samplers to forward step updates to callback. + + The denoising loops in samplers.py use tqdm directly with no external + callback hook. We replace tqdm there with a thin wrapper that calls + callback(current_step, total_steps) on each update() call. + """ + import ltx_pipelines.utils.samplers as _samplers_module + + _step_counter: list[int] = [0] + total_steps = total_denoising_steps() + + original_tqdm = _samplers_module.tqdm + + class _ProgressTqdm: + def __init__(self, iterable: Any = None, **kwargs: Any) -> None: + self._items = list(iterable) if iterable is not None else [] + self._tqdm = original_tqdm(self._items, **kwargs) + + def __iter__(self) -> Iterator[Any]: + for item in self._tqdm: + yield item + _step_counter[0] += 1 + callback(_step_counter[0], total_steps) + + def __len__(self) -> int: + return len(self._items) + + try: + _samplers_module.tqdm = _ProgressTqdm # type: ignore[attr-defined] + yield + finally: + _samplers_module.tqdm = original_tqdm # type: ignore[attr-defined] + class LTXFastVideoPipeline: pipeline_kind: Final = "fast" @@ -85,18 +129,32 @@ def generate( frame_rate: float, images: list[ImageConditioningInput], output_path: str, + progress_callback: StepCallback | None = None, ) -> None: tiling_config = default_tiling_config() - video, audio = self._run_inference( - prompt=prompt, - seed=seed, - height=height, - width=width, - num_frames=num_frames, - frame_rate=frame_rate, - images=images, - tiling_config=tiling_config, - ) + if progress_callback is not None: + with _tqdm_progress_interceptor(progress_callback): + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) + else: + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) chunks = video_chunks_number(num_frames, tiling_config) encode_video_output(video=video, audio=audio, fps=int(frame_rate), output_path=output_path, video_chunks_number_value=chunks) diff --git a/backend/tests/fakes/services.py b/backend/tests/fakes/services.py index 12700875a..4c440cfd5 100644 --- a/backend/tests/fakes/services.py +++ b/backend/tests/fakes/services.py @@ -507,6 +507,7 @@ def generate( frame_rate: float, images: list[ImageConditioningInput], output_path: str, + progress_callback: Callable[[int, int], None] | None = None, ) -> None: self._record_generate( { diff --git a/frontend/hooks/use-generation.ts b/frontend/hooks/use-generation.ts index ea5889d88..dbac33fea 100644 --- a/frontend/hooks/use-generation.ts +++ b/frontend/hooks/use-generation.ts @@ -56,7 +56,10 @@ function getImageDimensions(settings: GenerationSettings): { width: number; heig } // Map phase to user-friendly message -function getPhaseMessage(phase: string): string { +function getPhaseMessage(phase: string, currentStep?: number | null, totalSteps?: number | null): string { + const steps = currentStep != null && totalSteps != null && totalSteps > 0 + ? ` (${currentStep}/${totalSteps})` + : '' switch (phase) { case 'validating_request': return 'Validating request...' @@ -69,7 +72,7 @@ function getPhaseMessage(phase: string): string { case 'encoding_text': return 'Encoding prompt...' case 'inference': - return 'Generating...' + return `Generating...${steps}` case 'downloading_output': return 'Downloading output...' case 'decoding': @@ -101,9 +104,7 @@ export function useGeneration(): UseGenerationReturn { settings: GenerationSettings, audioPath?: string | null, ) => { - const statusMsg = settings.model === 'pro' - ? 'Loading Pro model & generating...' - : 'Generating video...' + const statusMsg = 'Loading model...' setState({ isGenerating: true, @@ -141,30 +142,35 @@ export function useGeneration(): UseGenerationReturn { // Poll for real progress from backend with time-based interpolation let lastPhase = '' - let inferenceStartTime = 0 - // Estimated inference time in seconds based on model - const estimatedInferenceTime = settings.model === 'pro' ? 120 : 45 - + let loadingModelStartTime = 0 + // Estimated loading time in seconds based on model + const estimatedLoadingTime = settings.model === 'pro' ? 60 : 30 + const pollProgress = async () => { if (!shouldApplyPollingUpdates) return try { const data = await ApiClient.getGenerationProgress() if (!shouldApplyPollingUpdates) return + // Ignore idle responses (phase="" means backend hasn't started yet) + if (!data.phase || data.status === 'idle') return + let displayProgress = data.progress - let statusMessage = getPhaseMessage(data.phase) - - // Time-based interpolation during inference phase - if (data.phase === 'inference') { - if (lastPhase !== 'inference') { - inferenceStartTime = Date.now() + let statusMessage = getPhaseMessage(data.phase, data.currentStep, data.totalSteps) + + // Time-based interpolation during loading_model phase + if (data.phase === 'loading_model') { + if (lastPhase !== 'loading_model') { + loadingModelStartTime = Date.now() } - const elapsed = (Date.now() - inferenceStartTime) / 1000 - // Interpolate from 15% to 95% based on estimated time - const inferenceProgress = Math.min(elapsed / estimatedInferenceTime, 0.95) - displayProgress = 15 + Math.floor(inferenceProgress * 80) + const elapsed = (Date.now() - loadingModelStartTime) / 1000 + // Interpolate from 5% to 12% based on estimated loading time + const loadingProgress = Math.min(elapsed / estimatedLoadingTime, 0.95) + displayProgress = 5 + Math.floor(loadingProgress * 7) } + // inference: use real progress from backend (set by _on_denoising_step callback) + // Keep API/local completion as a terminal response state, not polling state. // Polling complete means backend state is finalized, but request can still be in-flight. if (data.phase === 'complete' || data.status === 'complete') {