diff --git a/README.md b/README.md index c0469cb8a..f382f5f51 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # LTX Desktop -LTX Desktop is an open-source desktop app for generating videos with LTX models — locally on supported Windows/Linux NVIDIA GPUs, with an API mode for unsupported hardware and macOS. +LTX Desktop is an open-source desktop app for generating videos with LTX models — locally on supported Windows/Linux NVIDIA GPUs and Apple Silicon Macs, with an API mode for unsupported hardware. > **Status: Beta.** Expect breaking changes. > Frontend architecture is under active refactor; large UI PRs may be declined for now (see [`CONTRIBUTING.md`](docs/CONTRIBUTING.md)). @@ -34,7 +34,8 @@ LTX Desktop is an open-source desktop app for generating videos with LTX models | Windows (no CUDA, <16GB VRAM, or unknown VRAM) | API-only | **LTX API key required** | | Linux + CUDA GPU with **≥16GB VRAM** | Local generation | Downloads model weights locally | | Linux (no CUDA, <16GB VRAM, or unknown VRAM) | API-only | **LTX API key required** | -| macOS (Apple Silicon builds) | API-only | **LTX API key required** | +| macOS + Apple Silicon with **≥15GB unified memory** | Local generation | Downloads model weights locally | +| macOS + Apple Silicon with <15GB unified memory | API-only | **LTX API key required** | In API-only mode, available resolutions/durations may be limited to what the API supports. 
@@ -55,9 +56,15 @@ In API-only mode, available resolutions/durations may be limited to what the API - 16GB+ RAM (32GB recommended) - Plenty of free disk space for model weights and outputs +### macOS (local generation) + +- Apple Silicon (arm64) with **≥15GB unified memory** +- macOS 13+ (Ventura) +- **160GB+ free disk space** (for model weights, Python environment, and outputs) + ### macOS (API-only) -- Apple Silicon (arm64) +- Apple Silicon (arm64) with <15GB unified memory - macOS 13+ (Ventura) - Stable internet connection @@ -91,10 +98,10 @@ Text encoding: to generate videos you must configure text encoding: The LTX API is used for: - **Cloud text encoding and prompt enhancement** — **FREE**; text encoding is highly recommended to speed up inference and save memory -- API-based video generations (required on macOS and on unsupported Windows hardware) — paid +- API-based video generations (required on unsupported hardware and low-memory Apple Silicon Macs) — paid - Retake — paid -An LTX API key is required in API-only mode, but optional on Windows/Linux local mode if you enable the Local Text Encoder. +An LTX API key is required in API-only mode, but optional on Windows/Linux/macOS local mode if you enable the Local Text Encoder. Generate a FREE API key at the [LTX Console](https://console.ltx.video/). Text encoding is free; video generation API usage is paid. [Read more](https://ltx.io/model/model-blog/ltx-2-better-control-for-real-workflows). 
diff --git a/backend/handlers/health_handler.py b/backend/handlers/health_handler.py index 327380cfd..426b91ce5 100644 --- a/backend/handlers/health_handler.py +++ b/backend/handlers/health_handler.py @@ -12,6 +12,7 @@ from handlers.pipelines_handler import PipelinesHandler from logging_policy import log_background_exception from services.interfaces import GpuInfo +from services.services_utils import get_device_type from state.app_state_types import AppState, GpuSlot, StartupError, StartupLoading, StartupPending, StartupReady, VideoPipelineState, VideoPipelineWarmth if TYPE_CHECKING: @@ -108,14 +109,30 @@ def default_warmup(self) -> None: self.set_startup_loading("Loading Fast pipeline", 30) self._pipelines.load_gpu_pipeline("fast", should_warm=False) - self.set_startup_loading("Warming Fast pipeline", 60) - self._pipelines.warmup_pipeline("fast") - with self._lock: - match self.state.gpu_slot: - case GpuSlot(active_pipeline=VideoPipelineState() as state): - state.warmth = VideoPipelineWarmth.WARM - case _: - pass + if get_device_type(self.config.device) != "mps": + self.set_startup_loading("Warming Fast pipeline", 60) + with self._lock: + match self.state.gpu_slot: + case GpuSlot(active_pipeline=VideoPipelineState() as state): + state.warmth = VideoPipelineWarmth.WARMING + case _: + pass + try: + self._pipelines.warmup_pipeline("fast") + except Exception: + with self._lock: + match self.state.gpu_slot: + case GpuSlot(active_pipeline=VideoPipelineState() as state) if state.warmth == VideoPipelineWarmth.WARMING: + state.warmth = VideoPipelineWarmth.COLD + case _: + pass + raise + with self._lock: + match self.state.gpu_slot: + case GpuSlot(active_pipeline=VideoPipelineState() as state): + state.warmth = VideoPipelineWarmth.WARM + case _: + pass zit_models_path = resolve_model_path(self.models_dir, self.config.model_download_specs,"zit") zit_exists = zit_models_path.exists() and any(zit_models_path.iterdir()) diff --git a/backend/handlers/pipelines_handler.py 
b/backend/handlers/pipelines_handler.py index b8787e181..343dac6f8 100644 --- a/backend/handlers/pipelines_handler.py +++ b/backend/handlers/pipelines_handler.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import time from threading import RLock from typing import TYPE_CHECKING @@ -256,6 +257,16 @@ def load_gpu_pipeline(self, model_type: VideoPipelineModelType, should_warm: boo case _: pass + if state is not None and state.warmth == VideoPipelineWarmth.WARMING: + while state.warmth == VideoPipelineWarmth.WARMING: + time.sleep(1.0) + with self._lock: + match self.state.gpu_slot: + case GpuSlot(active_pipeline=VideoPipelineState() as refreshed): + state = refreshed + case _: + state = None + if state is None: self._evict_gpu_pipeline_for_swap() state = self._create_video_pipeline(model_type) @@ -377,6 +388,15 @@ def load_retake_pipeline(self, *, distilled: bool = True) -> RetakePipelineState return state def warmup_pipeline(self, model_type: VideoPipelineModelType) -> None: - state = self.load_gpu_pipeline(model_type, should_warm=False) + with self._lock: + match self.state.gpu_slot: + case GpuSlot(active_pipeline=VideoPipelineState() as existing_state): + state: VideoPipelineState | None = existing_state + case _: + state = None + + if state is None or not self._pipeline_matches_model_type(model_type): + state = self.load_gpu_pipeline(model_type, should_warm=False) + warmup_path = self.config.outputs_dir / f"_warmup_{model_type}.mp4" state.pipeline.warmup(output_path=str(warmup_path)) diff --git a/backend/handlers/video_generation_handler.py b/backend/handlers/video_generation_handler.py index a0ce15b3a..66aff0e6e 100644 --- a/backend/handlers/video_generation_handler.py +++ b/backend/handlers/video_generation_handler.py @@ -136,10 +136,8 @@ def get_9_16_size(res: str) -> tuple[int, int]: seed = self._resolve_seed() try: - self._pipelines.load_gpu_pipeline("fast", should_warm=False) - self._generation.start_generation(generation_id) - 
output_path = self.generate_video( + generation_id=generation_id, prompt=req.prompt, image=image, height=height, @@ -164,6 +162,7 @@ def get_9_16_size(res: str) -> tuple[int, int]: def generate_video( self, + generation_id: str, prompt: str, image: Image.Image | None, height: int, @@ -186,12 +185,13 @@ def generate_video( total_steps = 8 - self._generation.update_progress("loading_model", 5, 0, total_steps) t_load_start = time.perf_counter() pipeline_state = self._pipelines.load_gpu_pipeline("fast", should_warm=False) t_load_end = time.perf_counter() logger.info("[%s] Pipeline load: %.2fs", gen_mode, t_load_end - t_load_start) + self._generation.start_generation(generation_id) + self._generation.update_progress("loading_model", 5, 0, total_steps) self._generation.update_progress("encoding_text", 10, 0, total_steps) enhanced_prompt = prompt + self.config.camera_motion_prompts.get(camera_motion, "") @@ -224,6 +224,10 @@ def generate_video( height = round(height / 64) * 64 width = round(width / 64) * 64 + def _on_denoising_step(current_step: int, denoising_total: int) -> None: + pct = 15 + int(75 * current_step / denoising_total) + self._generation.update_progress("inference", pct, current_step, denoising_total) + t_inference_start = time.perf_counter() pipeline_state.pipeline.generate( prompt=enhanced_prompt, @@ -234,6 +238,7 @@ def generate_video( frame_rate=fps, images=images, output_path=str(output_path), + progress_callback=_on_denoising_step, ) t_inference_end = time.perf_counter() logger.info("[%s] Inference: %.2fs", gen_mode, t_inference_end - t_inference_start) @@ -286,6 +291,7 @@ def _generate_a2v( try: a2v_state = self._pipelines.load_a2v_pipeline() self._generation.start_generation(generation_id) + self._generation.update_progress("loading_model", 5, 0, 11) enhanced_prompt = req.prompt + self.config.camera_motion_prompts.get(req.cameraMotion, "") neg = req.negativePrompt if req.negativePrompt else self.config.default_negative_prompt @@ -306,8 +312,6 @@ def 
_generate_a2v( a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_i2v else: a2v_enhance = a2v_use_api and a2v_settings.prompt_enhancer_enabled_t2v - - self._generation.update_progress("loading_model", 5, 0, total_steps) self._generation.update_progress("encoding_text", 10, 0, total_steps) self._text.prepare_text_encoding(enhanced_prompt, enhance_prompt=a2v_enhance) self._generation.update_progress("inference", 15, 0, total_steps) diff --git a/backend/ltx2_server.py b/backend/ltx2_server.py index 359272154..b3878b38d 100644 --- a/backend/ltx2_server.py +++ b/backend/ltx2_server.py @@ -34,6 +34,14 @@ del _safetensors_loader_fix import services.patches.safetensors_metadata_fix as _safetensors_metadata_fix # pyright: ignore[reportUnusedImport] # Remove once safetensors supports read-only mmap del _safetensors_metadata_fix +import services.patches.mps_layer_streaming_fix as _mps_layer_streaming_fix # pyright: ignore[reportUnusedImport] # Remove once ltx-core adds MPS awareness to _LayerStore +del _mps_layer_streaming_fix +import services.patches.mps_gpu_model_fix as _mps_gpu_model_fix # pyright: ignore[reportUnusedImport] # Remove once ltx-pipelines adds MPS awareness to gpu_model +del _mps_gpu_model_fix +import services.patches.mps_vocoder_fix as _mps_vocoder_fix # pyright: ignore[reportUnusedImport] # Remove once ltx-core adds MPS awareness to VocoderWithBWE +del _mps_vocoder_fix +import services.patches.mps_chunked_attention_fix as _mps_chunked_attention_fix # pyright: ignore[reportUnusedImport] # Remove once ltx-core ships a memory-efficient attention path for MPS +del _mps_chunked_attention_fix from state.app_settings import AppSettings diff --git a/backend/runtime_config/runtime_policy.py b/backend/runtime_config/runtime_policy.py index 5ba38ddc9..68903d98a 100644 --- a/backend/runtime_config/runtime_policy.py +++ b/backend/runtime_config/runtime_policy.py @@ -6,7 +6,9 @@ def decide_force_api_generations(system: str, cuda_available: bool, vram_gb: 
int | None) -> bool: """Return whether API-only generation must be forced for this runtime.""" if system == "Darwin": - return True + if vram_gb is None: + return True + return vram_gb < 15 if system in ("Windows", "Linux"): if not cuda_available: diff --git a/backend/services/a2v_pipeline/ltx_a2v_pipeline.py b/backend/services/a2v_pipeline/ltx_a2v_pipeline.py index 991a4fcba..162de05cb 100644 --- a/backend/services/a2v_pipeline/ltx_a2v_pipeline.py +++ b/backend/services/a2v_pipeline/ltx_a2v_pipeline.py @@ -9,7 +9,7 @@ from api_types import ImageConditioningInput from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number -from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8 +from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8, get_device_type class LTXa2vPipeline: @@ -45,6 +45,11 @@ def __init__( device=device, quantization=QuantizationPolicy.fp8_cast() if device_supports_fp8(device) else None, ) + # MPS does not support CUDA streams or pin_memory(), so prefetch_count is set to 1 + # (synchronous layer streaming; the MPS patch ignores the count) rather than None (no streaming — loads the full + # transformer into GPU memory at once, which causes OOM on large generations). + # The mps_layer_streaming_fix patch makes synchronous streaming safe on MPS. 
+ self._streaming_prefetch_count: int | None = 1 if get_device_type(device) == "mps" else 2 def _run_inference( self, @@ -74,7 +79,7 @@ def _run_inference( audio_start_time=audio_start_time, audio_max_duration=audio_max_duration, tiling_config=tiling_config, - streaming_prefetch_count=2, + streaming_prefetch_count=self._streaming_prefetch_count, ) @torch.inference_mode() diff --git a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py index ccdaa24ee..546b2c28c 100644 --- a/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py +++ b/backend/services/fast_video_pipeline/ltx_fast_video_pipeline.py @@ -2,15 +2,58 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Callable, Iterator +from contextlib import contextmanager import os -from typing import Final, cast +from typing import Any, Final, cast import torch from api_types import ImageConditioningInput from services.ltx_pipeline_common import default_tiling_config, encode_video_output, video_chunks_number -from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8 +from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8, get_device_type + +# Stage 1: 8 denoising steps, Stage 2: 3 denoising steps. +_STAGE1_STEPS = 8 +_STAGE2_STEPS = 3 +_TOTAL_DENOISING_STEPS = _STAGE1_STEPS + _STAGE2_STEPS + +StepCallback = Callable[[int, int], None] # (current_step, total_steps) + + +@contextmanager +def _tqdm_progress_interceptor(callback: StepCallback) -> Iterator[None]: + """Patch tqdm in ltx_pipelines.utils.samplers to forward step updates to callback. + + The denoising loops in samplers.py use tqdm directly with no external + callback hook. We replace tqdm there with a thin wrapper that calls + callback(current_step, total_steps) on each update() call. 
+ """ + import ltx_pipelines.utils.samplers as _samplers_module + + _step_counter: list[int] = [0] + + original_tqdm = _samplers_module.tqdm + + class _ProgressTqdm: + def __init__(self, iterable: Any = None, **kwargs: Any) -> None: + self._items = list(iterable) if iterable is not None else [] + self._tqdm = original_tqdm(self._items, **kwargs) + + def __iter__(self) -> Iterator[Any]: + for item in self._tqdm: + yield item + _step_counter[0] += 1 + callback(_step_counter[0], _TOTAL_DENOISING_STEPS) + + def __len__(self) -> int: + return len(self._items) + + try: + _samplers_module.tqdm = _ProgressTqdm # type: ignore[attr-defined] + yield + finally: + _samplers_module.tqdm = original_tqdm # type: ignore[attr-defined] class LTXFastVideoPipeline: @@ -39,6 +82,11 @@ def __init__(self, checkpoint_path: str, gemma_root: str | None, upsampler_path: self._upsampler_path = upsampler_path self._device = device self._quantization = QuantizationPolicy.fp8_cast() if device_supports_fp8(device) else None + # MPS does not support CUDA streams or pin_memory(), so prefetch_count is set to 1 + # (synchronous layer streaming; the MPS patch ignores the count) rather than None (no streaming — loads the full + # transformer into GPU memory at once, which causes OOM on large generations). + # The mps_layer_streaming_fix patch makes synchronous streaming safe on MPS. 
+ self._streaming_prefetch_count: int | None = 1 if get_device_type(device) == "mps" else 2 self.pipeline = DistilledPipeline( distilled_checkpoint_path=checkpoint_path, @@ -71,7 +119,7 @@ def _run_inference( frame_rate=frame_rate, images=[_LtxImageInput(img.path, img.frame_idx, img.strength) for img in images], tiling_config=tiling_config, - streaming_prefetch_count=2, + streaming_prefetch_count=self._streaming_prefetch_count, ) @torch.inference_mode() @@ -85,18 +133,32 @@ def generate( frame_rate: float, images: list[ImageConditioningInput], output_path: str, + progress_callback: StepCallback | None = None, ) -> None: tiling_config = default_tiling_config() - video, audio = self._run_inference( - prompt=prompt, - seed=seed, - height=height, - width=width, - num_frames=num_frames, - frame_rate=frame_rate, - images=images, - tiling_config=tiling_config, - ) + if progress_callback is not None: + with _tqdm_progress_interceptor(progress_callback): + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) + else: + video, audio = self._run_inference( + prompt=prompt, + seed=seed, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + images=images, + tiling_config=tiling_config, + ) chunks = video_chunks_number(num_frames, tiling_config) encode_video_output(video=video, audio=audio, fps=int(frame_rate), output_path=output_path, video_chunks_number_value=chunks) diff --git a/backend/services/ic_lora_pipeline/ltx_ic_lora_pipeline.py b/backend/services/ic_lora_pipeline/ltx_ic_lora_pipeline.py index 14393edf6..298593604 100644 --- a/backend/services/ic_lora_pipeline/ltx_ic_lora_pipeline.py +++ b/backend/services/ic_lora_pipeline/ltx_ic_lora_pipeline.py @@ -9,7 +9,7 @@ from api_types import ImageConditioningInput from services.ltx_pipeline_common import default_tiling_config, encode_video_output, 
video_chunks_number -from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8 +from services.services_utils import AudioOrNone, TilingConfigType, device_supports_fp8, get_device_type class LTXIcLoraPipeline: @@ -51,6 +51,11 @@ def __init__( device=device, quantization=QuantizationPolicy.fp8_cast() if device_supports_fp8(device) else None, ) + # MPS does not support CUDA streams or pin_memory(), so prefetch_count is set to 1 + # (synchronous layer streaming; the MPS patch ignores the count) rather than None (no streaming — loads the full + # transformer into GPU memory at once, which causes OOM on large generations). + # The mps_layer_streaming_fix patch makes synchronous streaming safe on MPS. + self._streaming_prefetch_count: int | None = 1 if get_device_type(device) == "mps" else 2 def _run_inference( self, @@ -76,7 +81,7 @@ def _run_inference( images=[_LtxImageInput(img.path, img.frame_idx, img.strength) for img in images], video_conditioning=video_conditioning, tiling_config=tiling_config, - streaming_prefetch_count=2, + streaming_prefetch_count=self._streaming_prefetch_count, ) @torch.inference_mode() diff --git a/backend/services/patches/mps_chunked_attention_fix.py b/backend/services/patches/mps_chunked_attention_fix.py new file mode 100644 index 000000000..9cdf23262 --- /dev/null +++ b/backend/services/patches/mps_chunked_attention_fix.py @@ -0,0 +1,141 @@ +"""Monkey-patch: chunked scaled_dot_product_attention for MPS (Apple Silicon). + +On MPS, PyTorch does not support Flash Attention or memory-efficient attention +(xformers/FlashAttention-3 are CUDA-only). LTX uses PytorchAttention, which +calls torch.nn.functional.scaled_dot_product_attention with the full Q/K/V +tensors. 
For long video sequences this allocates an attention matrix of size +O(N²) where N can exceed 10 000 tokens, producing a ~26 GB MTLBuffer that +immediately OOMs: + + MPSCore: failed assertion `Failed to allocate private MTLBuffer for size + 28341043200' + +This patch replaces PytorchAttention.__call__ on MPS with a chunked +implementation that processes the query sequence in fixed-size chunks, keeping +peak memory at O(N × chunk_size) instead of O(N²). CUDA and CPU paths are +left entirely unchanged. + +The chunk size is controlled by the environment variable +LTX_MPS_ATTN_CHUNK_SIZE (default: 512). Larger chunks are faster but use +more memory; smaller chunks are safer on constrained hardware. + +Remove this patch once ltx-core ships a memory-efficient attention path for +MPS (e.g. via torch.nn.attention.SDPBackend.CHUNKED_PREFILL or a first-party +chunked kernel). + +Usage: + import services.patches.mps_chunked_attention_fix # noqa: F401 +""" + +from __future__ import annotations + +import os +from typing import Any + +import torch + +import ltx_core.model.transformer.attention as _attn_module +from ltx_core.model.transformer.attention import PytorchAttention + +_DEFAULT_CHUNK_SIZE = 512 +_CHUNK_SIZE: int = int(os.environ.get("LTX_MPS_ATTN_CHUNK_SIZE", _DEFAULT_CHUNK_SIZE)) + +_original_pytorch_attention_call = PytorchAttention.__call__ + + +def _chunked_mps_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None, +) -> torch.Tensor: + """Chunked SDPA for MPS. + + Splits Q into chunks of _CHUNK_SIZE tokens and accumulates the output, + attending each Q-chunk against the full K/V. This keeps the intermediate + attention scores at O(chunk_size × N) instead of O(N²). 
+ + Args: + q, k, v: [B, S, H*D] tensors (ltx_core's packed format before reshaping) + heads: number of attention heads + mask: optional attention mask — may arrive in any of the shapes that + PytorchAttention / SDPA accept: + 2-D [seq_q, seq_k] + 3-D [B_or_1, seq_q, seq_k] + 4-D [B_or_1, H_or_1, seq_q, seq_k] + A mask whose Q-dimension does not match seq_q (e.g. size 0 for + the text-encoder's audio-only path) is passed through as-is and + never sliced, letting SDPA broadcast it correctly. + + Returns: + Output tensor in the same [B, S, H*D] packed format. + """ + b, seq_q, _ = q.shape + dim_head = q.shape[-1] // heads + + # Reshape to [B, H, S, D] (standard SDPA layout) + q = q.view(b, seq_q, heads, dim_head).transpose(1, 2) + k = k.view(b, k.shape[1], heads, dim_head).transpose(1, 2) + v = v.view(b, v.shape[1], heads, dim_head).transpose(1, 2) + + # Normalise mask to 4-D [B, H, seq_q_mask, seq_k] so we can inspect dims. + if mask is not None: + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + # Determine whether we can safely slice the mask along the Q dimension. + # The mask Q-dim (dim 2) may be: + # == seq_q → slice per chunk + # == 1 → broadcast singleton, pass as-is to every chunk + # == 0 → empty / no-op mask (e.g. 
text-encoder audio-only path); + # treat as None so SDPA doesn't try to broadcast 0 → chunk_size + if mask is not None and mask.shape[2] == 0: + mask = None + mask_seq_q = mask.shape[2] if mask is not None else 0 + can_slice_mask = mask is not None and mask_seq_q == seq_q + + chunks: list[torch.Tensor] = [] + for start in range(0, seq_q, _CHUNK_SIZE): + end = min(start + _CHUNK_SIZE, seq_q) + q_chunk = q[:, :, start:end, :] + + if can_slice_mask: + assert mask is not None # narrowing for pyright + mask_chunk: torch.Tensor | None = mask[:, :, start:end, :] + else: + mask_chunk = mask # broadcast as-is (covers size-0 and size-1 cases) + + chunk_out = torch.nn.functional.scaled_dot_product_attention( + q_chunk, k, v, attn_mask=mask_chunk, dropout_p=0.0, is_causal=False + ) + chunks.append(chunk_out) + + # Reassemble and return in packed [B, S, H*D] format + out = torch.cat(chunks, dim=2) # [B, H, S, D] + out = out.transpose(1, 2).reshape(b, seq_q, heads * dim_head) + return out + + +def _patched_pytorch_attention_call( + self: PytorchAttention, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, +) -> torch.Tensor: + """Dispatch to chunked attention on MPS, original on other devices.""" + if q.device.type == "mps": + return _chunked_mps_attention(q, k, v, heads, mask) + return _original_pytorch_attention_call(self, q, k, v, heads, mask) + + +# Apply patch. +assert hasattr(PytorchAttention, "__call__") and callable(PytorchAttention.__call__), ( + "PytorchAttention.__call__ not found — was it renamed? " + "The mps_chunked_attention_fix patch needs updating." 
+) +PytorchAttention.__call__ = _patched_pytorch_attention_call # type: ignore[method-assign] diff --git a/backend/services/patches/mps_gpu_model_fix.py b/backend/services/patches/mps_gpu_model_fix.py new file mode 100644 index 000000000..bd136b818 --- /dev/null +++ b/backend/services/patches/mps_gpu_model_fix.py @@ -0,0 +1,273 @@ +"""Monkey-patch: replace torch.cuda.synchronize() calls with device-aware sync. + +ltx_pipelines calls torch.cuda.synchronize() unconditionally in two places: +- gpu_model() in ltx_pipelines.utils.gpu_model +- _streaming_model() in ltx_pipelines.utils.blocks + +On MPS (Apple Silicon) both raise: + + AssertionError: Torch not compiled with CUDA enabled + +This patch replaces both context managers with implementations that dispatch to +the correct synchronization primitive based on the actual device in use: + CUDA → torch.cuda.synchronize, MPS → torch.mps.synchronize, CPU → no-op. + +For _streaming_model() on MPS, LayerStreamingWrapper cannot be used because it +relies on torch.cuda.Stream, torch.cuda.Event, and torch.cuda.current_stream +throughout (in _AsyncPrefetcher and _register_hooks). Instead, we provide +_MpsLayerStreamingWrapper, a synchronous layer-streaming implementation that +moves each layer to MPS immediately before its forward pass and evicts it +immediately after, without any CUDA stream machinery. + +cleanup_memory() is called indirectly via _cleanup_memory(), which looks up the +function through ltx_pipelines.utils.helpers at call time rather than capturing +it at import time. This ensures we pick up any later patches applied to +cleanup_memory (e.g. by LTXTextEncoder._install_cleanup_memory_patch). + +Remove this patch once ltx-pipelines adds MPS awareness to these functions. 
+ +Usage: + import services.patches.mps_gpu_model_fix # noqa: F401 +""" + +from __future__ import annotations + +import functools +import itertools +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any, TypeVar + +import torch +import torch.nn as nn + +import ltx_pipelines.utils.gpu_model as _gpu_model_module +import ltx_pipelines.utils.helpers as _helpers_module + +_M = TypeVar("_M", bound=torch.nn.Module) + + +def _cleanup_memory() -> None: + # Always look up through the module so we pick up any later patches + # (e.g. the cleanup_memory patch installed by LTXTextEncoder). + _helpers_module.cleanup_memory() + + +def _synchronize_device(device: torch.device) -> None: + """Run a device synchronization appropriate for the given device.""" + if device.type == "cuda": + torch.cuda.synchronize(device) + elif device.type == "mps": + torch.mps.synchronize() + # CPU: no synchronization needed. + + +def _synchronize_model(model: torch.nn.Module) -> None: + """Run a device synchronization for all devices a model's tensors live on.""" + devices: set[torch.device] = set() + for tensor in list(model.parameters()) + list(model.buffers()): + devices.add(tensor.device) + for device in devices: + _synchronize_device(device) + + +def _resolve_attr(module: nn.Module, dotted_path: str) -> nn.ModuleList: + obj: Any = module + for part in dotted_path.split("."): + obj = getattr(obj, part) + if not isinstance(obj, nn.ModuleList): + raise TypeError(f"Expected nn.ModuleList at '{dotted_path}', got {type(obj).__name__}") + return obj + + +class _MpsLayerStreamingWrapper(nn.Module): + """Synchronous layer-streaming wrapper for MPS (Apple Silicon). + + LayerStreamingWrapper cannot be used on MPS because _AsyncPrefetcher and + _register_hooks use torch.cuda.Stream/Event/current_stream throughout. 
+ + This wrapper achieves the same memory-reduction goal — keep transformer + layers on CPU and move them to MPS one at a time — using synchronous + transfers and forward hooks. There is no async prefetch, so throughput + may be slightly lower than CUDA, but peak GPU memory is the same. + + Non-layer parameters and buffers are moved to MPS at setup time (matching + the behaviour of LayerStreamingWrapper._setup). + """ + + def __init__( + self, + model: nn.Module, + layers_attr: str, + target_device: torch.device, + ) -> None: + super().__init__() + self._model = model + self._layers = _resolve_attr(model, layers_attr) + self._target_device = target_device + self._hooks: list[torch.utils.hooks.RemovableHandle] = [] + + # Record source (CPU) tensors for each layer so we can restore them + # after eviction — same approach as _LayerStore. + self._source_data: list[dict[str, torch.Tensor]] = [] + for layer in self._layers: + source: dict[str, torch.Tensor] = {} + for name, tensor in itertools.chain(layer.named_parameters(), layer.named_buffers()): + source[name] = tensor.data + self._source_data.append(source) + + self._setup() + + def _setup(self) -> None: + # Move non-layer params/buffers to MPS so the rest of the model is + # ready for computation without waiting for layer transfers. 
+ layer_tensor_ids: set[int] = set() + for layer in self._layers: + for t in itertools.chain(layer.parameters(), layer.buffers()): + layer_tensor_ids.add(id(t)) + + for p in self._model.parameters(): + if id(p) not in layer_tensor_ids: + p.data = p.data.to(self._target_device) + for b in self._model.buffers(): + if id(b) not in layer_tensor_ids: + b.data = b.data.to(self._target_device) + + self._register_hooks() + + def _move_layer_to_device(self, idx: int, layer: nn.Module) -> None: + source = self._source_data[idx] + for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()): + param.data = source[name].to(self._target_device, non_blocking=False) + + def _evict_layer_to_cpu(self, idx: int, layer: nn.Module) -> None: + source = self._source_data[idx] + for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()): + param.data = source[name] + + def _register_hooks(self) -> None: + idx_map: dict[int, int] = {id(layer): idx for idx, layer in enumerate(self._layers)} + + def _pre_hook(module: nn.Module, _args: Any, *, idx: int) -> None: + self._move_layer_to_device(idx, module) + + def _post_hook(module: nn.Module, _args: Any, _output: Any, *, idx: int) -> None: + self._evict_layer_to_cpu(idx, module) + + for layer in self._layers: + idx = idx_map[id(layer)] + h1 = layer.register_forward_pre_hook(functools.partial(_pre_hook, idx=idx)) + h2 = layer.register_forward_hook(functools.partial(_post_hook, idx=idx)) + self._hooks.extend([h1, h2]) + + def teardown(self) -> None: + for h in self._hooks: + h.remove() + self._hooks.clear() + self._source_data.clear() + + def forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + return self._model(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: # noqa: ANN401 + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self._model, name) + + +@contextmanager +def _patched_gpu_model(model: _M) -> Iterator[_M]: + """Device-aware 
replacement for gpu_model(). + + Identical to the original except it dispatches synchronization based on + the model's actual device rather than assuming CUDA. + """ + try: + yield model + finally: + _synchronize_model(model) + model.to("meta") + _cleanup_memory() + + +import ltx_pipelines.utils.blocks as _blocks_module # noqa: E402 +from ltx_pipelines.utils.blocks import LayerStreamingWrapper # noqa: E402 + + +@contextmanager +def _patched_streaming_model( + model: _M, + layers_attr: str, + target_device: torch.device, + prefetch_count: int, +) -> Iterator[_M]: + """Device-aware replacement for _streaming_model(). + + On CUDA: delegates to LayerStreamingWrapper (async prefetch with CUDA streams). + On MPS: uses _MpsLayerStreamingWrapper (synchronous, no CUDA streams). + """ + if target_device.type == "mps": + mps_wrapped = _MpsLayerStreamingWrapper( + model, + layers_attr=layers_attr, + target_device=target_device, + ) + try: + yield mps_wrapped # type: ignore[misc] + finally: + mps_wrapped.teardown() + mps_wrapped.to("meta") + _cleanup_memory() + torch.mps.synchronize() + try: + if hasattr(torch._C, "_host_emptyCache"): + torch._C._host_emptyCache() # type: ignore[attr-defined] + except Exception: + pass + else: + wrapped = LayerStreamingWrapper( + model, + layers_attr=layers_attr, + target_device=target_device, + prefetch_count=prefetch_count, + ) + try: + yield wrapped # type: ignore[misc] + finally: + wrapped.teardown() + wrapped.to("meta") + _cleanup_memory() + _synchronize_device(target_device) + try: + if hasattr(torch._C, "_host_emptyCache"): + torch._C._host_emptyCache() # type: ignore[attr-defined] + except Exception: + pass + + +# Apply patches. +# 1. Replace gpu_model in the defining module (catches any future dynamic lookups). +assert hasattr(_gpu_model_module, "gpu_model") and callable(getattr(_gpu_model_module, "gpu_model")), ( + "ltx_pipelines.utils.gpu_model.gpu_model not found — was it renamed? " + "The mps_gpu_model_fix patch needs updating." 
+) +_gpu_model_module.gpu_model = _patched_gpu_model # type: ignore[assignment] + +# 2. Replace gpu_model in ltx_pipelines.utils.blocks, which does +# `from ltx_pipelines.utils.gpu_model import gpu_model` at import time, +# binding the old function directly into its own namespace. +assert hasattr(_blocks_module, "gpu_model") and callable(getattr(_blocks_module, "gpu_model")), ( + "ltx_pipelines.utils.blocks.gpu_model not found — was it renamed or removed? " + "The mps_gpu_model_fix patch needs updating." +) +_blocks_module.gpu_model = _patched_gpu_model # type: ignore[assignment] + +# 3. Replace _streaming_model in ltx_pipelines.utils.blocks. +# This function also calls torch.cuda.synchronize() unconditionally on +# target_device in its finally block. +assert hasattr(_blocks_module, "_streaming_model") and callable(getattr(_blocks_module, "_streaming_model")), ( + "ltx_pipelines.utils.blocks._streaming_model not found — was it renamed or removed? " + "The mps_gpu_model_fix patch needs updating." +) +_blocks_module._streaming_model = _patched_streaming_model # type: ignore[assignment] diff --git a/backend/services/patches/mps_layer_streaming_fix.py b/backend/services/patches/mps_layer_streaming_fix.py new file mode 100644 index 000000000..ae00b5b46 --- /dev/null +++ b/backend/services/patches/mps_layer_streaming_fix.py @@ -0,0 +1,62 @@ +"""Monkey-patch: skip pin_memory() in _LayerStore.move_to_gpu on MPS. + +pin_memory() is a CUDA concept for host-pinned memory that enables async +DMA H2D transfers. MPS (Apple Metal Performance Shaders) does not support +it: calling tensor.pin_memory() on a CPU tensor when the target device is +MPS raises: + + RuntimeError: Attempted to set the storage of a tensor on device "cpu" + to a storage on different device "mps:0". This is no longer allowed; + the devices must match. 
+ +This patch replaces _LayerStore.move_to_gpu with an implementation that +skips pin_memory() when the target device is MPS and does a direct +synchronous .to(device) copy instead. + +Remove this patch once ltx-core adds MPS awareness to _LayerStore. + +Usage: + import services.patches.mps_layer_streaming_fix # noqa: F401 +""" + +from __future__ import annotations + +import itertools + +import torch +from torch import nn + +from ltx_core.layer_streaming import _LayerStore # type: ignore[reportPrivateImportUsage] + + +_original_move_to_gpu = _LayerStore.move_to_gpu + + +def _patched_move_to_gpu( + self: _LayerStore, + idx: int, + layer: nn.Module, + *, + non_blocking: bool = False, +) -> None: + """Move layer idx to GPU, skipping pin_memory() on MPS targets.""" + self._check_idx(idx) + if idx in self._on_gpu: + return + source = self._source_data[idx] + + if self.target_device.type == "mps": + # MPS does not support pinned host memory — copy directly. + for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()): + param.data = source[name].to(self.target_device, non_blocking=False) + # No pinned_in_flight tracking needed for MPS (transfers are synchronous). + self._on_gpu.add(idx) + else: + _original_move_to_gpu(self, idx, layer, non_blocking=non_blocking) + + +# Apply patch. +assert hasattr(_LayerStore, "move_to_gpu") and callable(getattr(_LayerStore, "move_to_gpu")), ( + "_LayerStore.move_to_gpu not found — was it renamed? The mps_layer_streaming_fix patch needs updating." +) +_LayerStore.move_to_gpu = _patched_move_to_gpu # type: ignore[assignment] diff --git a/backend/services/patches/mps_vocoder_fix.py b/backend/services/patches/mps_vocoder_fix.py new file mode 100644 index 000000000..c6fb3c98f --- /dev/null +++ b/backend/services/patches/mps_vocoder_fix.py @@ -0,0 +1,64 @@ +"""Monkey-patch: fix VocoderWithBWE dtype mismatch on MPS. 
+ +VocoderWithBWE.forward() uses torch.autocast(dtype=torch.float32) to upcast +bf16 weights per-op during inference. MPS does not support float32 autocast: + + UserWarning: In MPS autocast, but the target dtype is not supported. + Disabling autocast. MPS Autocast only supports dtypes of torch.bfloat16, + torch.float16 currently. + +When autocast is disabled, the model weights remain in bfloat16 but the input +is cast to float32 (mel_spec.float()), causing: + + RuntimeError: Input type (float) and bias type (c10::BFloat16) should be + the same + +On MPS we temporarily convert the vocoder submodule to float32 before the +forward pass and restore it to bfloat16 afterward. This is the same approach +the upstream code explicitly benchmarked and rejected for CUDA ("+324 MB peak +VRAM, 149 ms") — on MPS the memory penalty is acceptable because unified +memory is not subject to the same VRAM constraints. + +Remove this patch once ltx-core adds MPS awareness to VocoderWithBWE.forward. + +Usage: + import services.patches.mps_vocoder_fix # noqa: F401 +""" + +from __future__ import annotations + +import warnings +from typing import Any + +import torch +import torch.nn as nn + +from ltx_core.model.audio_vae.vocoder import VocoderWithBWE # type: ignore[reportMissingImports] + +_original_forward = VocoderWithBWE.forward + + +def _patched_forward(self: VocoderWithBWE, mel_spec: torch.Tensor) -> torch.Tensor: + if mel_spec.device.type != "mps": + return _original_forward(self, mel_spec) + + # MPS does not support float32 autocast — convert weights temporarily. + # Suppress the expected "MPS Autocast only supports dtypes of torch.bfloat16" + # warning: _original_forward still contains the autocast block but our + # self.float() conversion makes the weights match the float32 input, so + # the autocast no-op is harmless. 
+ original_dtype = next(self.parameters()).dtype + try: + self.float() # convert all weights to float32 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="In MPS autocast", category=UserWarning) + return _original_forward(self, mel_spec) + finally: + self.to(dtype=original_dtype) # restore original dtype (bfloat16) + + +assert hasattr(VocoderWithBWE, "forward") and callable(VocoderWithBWE.forward), ( + "VocoderWithBWE.forward not found — was it renamed? " + "The mps_vocoder_fix patch needs updating." +) +VocoderWithBWE.forward = _patched_forward # type: ignore[method-assign] diff --git a/backend/services/patches/safetensors_loader_fix.py b/backend/services/patches/safetensors_loader_fix.py index 56c6d54aa..c0c195fe0 100644 --- a/backend/services/patches/safetensors_loader_fix.py +++ b/backend/services/patches/safetensors_loader_fix.py @@ -100,7 +100,15 @@ def _patched_load( expected_name = name if sd_ops is None else sd_ops.apply_to_key(name) if expected_name is None: continue - value = value.to(device=device, non_blocking=True, copy=False) + # copy=False is not valid when moving from CPU mmap storage to a + # non-CPU device (e.g. MPS): the devices must match for in-place + # storage reuse. Always copy when crossing device boundaries. + # non_blocking=True is also unsafe on MPS with mmap-backed tensors: + # MPS has no equivalent of CUDA's pinned-memory DMA, so the async + # transfer can read freed/remapped mmap pages → segfault. 
+ needs_copy = value.device != device + non_blocking = device.type == "cuda" + value = value.to(device=device, non_blocking=non_blocking, copy=needs_copy) key_value_pairs = ((expected_name, value),) if sd_ops is not None: key_value_pairs = sd_ops.apply_to_key_value(expected_name, value) diff --git a/backend/services/retake_pipeline/ltx_retake_pipeline.py b/backend/services/retake_pipeline/ltx_retake_pipeline.py index 70522bf8f..624ce50ae 100644 --- a/backend/services/retake_pipeline/ltx_retake_pipeline.py +++ b/backend/services/retake_pipeline/ltx_retake_pipeline.py @@ -64,8 +64,14 @@ def __init__( VideoDecoder, ) + from services.services_utils import get_device_type self.device = device self.dtype = torch.bfloat16 + # MPS does not support CUDA streams or pin_memory(), so we keep prefetch_count + # minimal (1, synchronous layer streaming) rather than None (no streaming — + # loads the full transformer into GPU memory at once, which causes OOM on large + # generations). The mps_layer_streaming_fix patch makes this safe on MPS.
+ self._streaming_prefetch_count: int | None = 1 if get_device_type(device) == "mps" else 2 self.prompt_encoder = PromptEncoder( checkpoint_path=checkpoint_path, @@ -284,7 +290,7 @@ def generate( regenerate_audio=regenerate_audio, enhance_prompt=enhance_prompt, distilled=distilled, - streaming_prefetch_count=2, + streaming_prefetch_count=self._streaming_prefetch_count, ) audio_out: Audio | None = audio tiling_config = TilingConfig.default() diff --git a/backend/services/text_encoder/ltx_text_encoder.py b/backend/services/text_encoder/ltx_text_encoder.py index 5fb760396..085dde1fa 100644 --- a/backend/services/text_encoder/ltx_text_encoder.py +++ b/backend/services/text_encoder/ltx_text_encoder.py @@ -126,9 +126,11 @@ def _install_cleanup_memory_patch(self, state_getter: Callable[[], AppState]) -> return try: + import gc + from ltx_pipelines.utils import helpers as ltx_utils - original_cleanup_memory = ltx_utils.cleanup_memory + device = self.device def patched_cleanup_memory() -> None: state = state_getter() @@ -138,7 +140,16 @@ def patched_cleanup_memory() -> None: te_state.cached_encoder.to(torch.device("cpu")) except Exception: logger.warning("Failed to move cached text encoder to CPU", exc_info=True) - original_cleanup_memory() + # Inline a device-aware cleanup instead of delegating to the + # original cleanup_memory (which calls torch.cuda.synchronize() + # unconditionally and crashes on MPS / CPU-only builds). 
+ gc.collect() + if device.type == "cuda": + torch.cuda.empty_cache() + torch.cuda.synchronize() + elif device.type == "mps": + torch.mps.empty_cache() + torch.mps.synchronize() setattr(ltx_utils, "cleanup_memory", patched_cleanup_memory) @@ -150,7 +161,6 @@ def patched_cleanup_memory() -> None: "ltx_pipelines.ic_lora", "ltx_pipelines.a2vid_two_stage", "ltx_pipelines.retake", - "ltx_pipelines.retake_pipeline", ): try: module = __import__(module_name, fromlist=["cleanup_memory"]) diff --git a/backend/tests/test_runtime_policy_decision.py b/backend/tests/test_runtime_policy_decision.py index b05cdd8bf..d12e79af5 100644 --- a/backend/tests/test_runtime_policy_decision.py +++ b/backend/tests/test_runtime_policy_decision.py @@ -5,11 +5,20 @@ from runtime_config.runtime_policy import decide_force_api_generations -def test_darwin_always_forces_api() -> None: - assert decide_force_api_generations(system="Darwin", cuda_available=True, vram_gb=24) is True +def test_darwin_with_unknown_vram_forces_api() -> None: assert decide_force_api_generations(system="Darwin", cuda_available=False, vram_gb=None) is True +def test_darwin_with_low_vram_forces_api() -> None: + assert decide_force_api_generations(system="Darwin", cuda_available=False, vram_gb=14) is True + + +def test_darwin_with_required_vram_allows_local_mode() -> None: + assert decide_force_api_generations(system="Darwin", cuda_available=False, vram_gb=15) is False + assert decide_force_api_generations(system="Darwin", cuda_available=False, vram_gb=24) is False + assert decide_force_api_generations(system="Darwin", cuda_available=False, vram_gb=48) is False + + def test_windows_without_cuda_forces_api() -> None: assert decide_force_api_generations(system="Windows", cuda_available=False, vram_gb=24) is True