diff --git a/.gitignore b/.gitignore index fb301d52..26c74911 100644 --- a/.gitignore +++ b/.gitignore @@ -221,4 +221,43 @@ images/inputs/* controlnet_test_* # Demo uploads directory -demo/realtime-img2img/uploads/ \ No newline at end of file +demo/realtime-img2img/uploads/logs/ +.cgw.conf + +# Local Claude / session state (per-user, never committed) +.claude/ + +# Test-install debug dumps and snapshots +Debug/ + +# Sibling repo (tracked separately at dotsimulate/StreamDiffusion-installer) +StreamDiffusion-installer/ + +# TD-component dev copy (feeds .tox re-export, lives in dotsimulate/StreamDiffusionTD) +StreamDiffusionTD/ + +# TD custom processors (TD-side, not part of the StreamDiffusion pip package) +custom_processors/ + +# Generated by sd_installer.generate_batch_file — CUDA-variant and path specific +Install_StreamDiffusion.bat +Install_TensorRT.bat +Start_StreamDiffusion.bat + +# Stray file from a previous shell redirect (pip install ...>=0.19.0) +=0.19.0 + +# NVIDIA Nsight profiling outputs +*.nsys-rep +*.qdrep +*.qdstrm +*.ncu-rep +profiles/ +profiler_logs/ +logs/ncu_* + +# Per-session work log (local only, never committed; tracked separately by user) +SESSION_LOG.md + +# Profiling/audit CSV exports (Nsight summaries, kernel stats — generated artifacts) +audit_reports/ diff --git a/configs/profiling/profiling_fp16_cached.yaml b/configs/profiling/profiling_fp16_cached.yaml new file mode 100644 index 00000000..25cbabd6 --- /dev/null +++ b/configs/profiling/profiling_fp16_cached.yaml @@ -0,0 +1,54 @@ +# Profile D — FP16 + cached_attn, no ControlNet, no IPAdapter +# Engine: sdxl-turbo--tiny_vae-True--min_batch-1--max_batch-4--use_cached_attn-True--mode-img2img--trt10.16.1.11--cc89--res-512x512 +# IMPORTANT: build_engines_if_missing=false — fail if engine not present rather than rebuild + +model_id: "stabilityai/sdxl-turbo" + +t_index_list: [16, 35] +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +guidance_scale: 1.0 +num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +prompt: "default fancy banana wall" +negative_prompt: "" + +mode: "img2img" +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: false +static_shapes: true +fp8: false +fp8_allow_fp16_fallback: false +builder_optimization_level: 3 + +scheduler: "lcm" +sampler: "normal" + +use_cached_attn: true +cache_maxframes: 2 +cache_interval: 1 + +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +hf_cache: "" + +engine_dir: "D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/engines/td" + +use_controlnet: false +use_ipadapter: false diff --git a/configs/profiling/profiling_fp16_flexible.yaml b/configs/profiling/profiling_fp16_flexible.yaml new file mode 100644 index 00000000..02c3ad78 --- /dev/null +++ b/configs/profiling/profiling_fp16_flexible.yaml @@ -0,0 +1,54 @@ +# Profile — FP16 Flexible (TrtProfile=Flexible: static_shapes=false + optlvl=3) +# Matches engine: sdxl-turbo--tiny_vae-True--min_batch-1--max_batch-4--use_cached_attn-False--optlvl3--mode-img2img--trt10.16.1.11--cc89--res-512x512 +# IMPORTANT: build_engines_if_missing=false — fail if engine not present rather than rebuild + +model_id: "stabilityai/sdxl-turbo" + +t_index_list: [6, 25] +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +guidance_scale: 1.0 
+num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +prompt: "default fancy banana wall" +negative_prompt: "" + +mode: "img2img" +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: false +static_shapes: false +fp8: false +fp8_allow_fp16_fallback: false +builder_optimization_level: 3 + +scheduler: "lcm" +sampler: "normal" + +use_cached_attn: false +cache_maxframes: 2 +cache_interval: 1 + +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +hf_cache: "" + +engine_dir: "D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/engines/td" + +use_controlnet: false +use_ipadapter: false diff --git a/configs/profiling/profiling_fp16_full.yaml b/configs/profiling/profiling_fp16_full.yaml new file mode 100644 index 00000000..be606556 --- /dev/null +++ b/configs/profiling/profiling_fp16_full.yaml @@ -0,0 +1,66 @@ +# Profile A — FP16 + cached_attn + ControlNet + IPAdapter (full production config) +# Engine: sdxl-turbo--tiny_vae-True--min_batch-1--max_batch-4--tokens4--use_cached_attn-True--controlnet--optlvl3--mode-img2img--trt10.16.1.11--cc89--res-512x512 +# IMPORTANT: build_engines_if_missing=false — fail if engine not present rather than rebuild + +model_id: "stabilityai/sdxl-turbo" + +t_index_list: [16, 35] +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +guidance_scale: 1.0 +num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +prompt: "default fancy banana wall" +negative_prompt: "" + +mode: "img2img" +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: false +static_shapes: true +fp8: false +fp8_allow_fp16_fallback: false +builder_optimization_level: 3 + +scheduler: "lcm" +sampler: "normal" + +use_cached_attn: true +cache_maxframes: 2 +cache_interval: 1 + +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +hf_cache: "" + +engine_dir: "D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/engines/td" + +use_controlnet: true +controlnets: + - model_id: "xinsir/controlnet-canny-sdxl-1.0" + conditioning_scale: 0.3 + preprocessor: "canny" + enabled: true + +use_ipadapter: true +ipadapters: + - ipadapter_model_path: "h94/IP-Adapter/sdxl_models/ip-adapter_sdxl.bin" + image_encoder_path: "h94/IP-Adapter/sdxl_models/image_encoder" + scale: 0.25 + enabled: true + type: regular diff --git a/configs/profiling/profiling_fp16_plain.yaml b/configs/profiling/profiling_fp16_plain.yaml new file mode 100644 index 00000000..93b77648 --- /dev/null +++ b/configs/profiling/profiling_fp16_plain.yaml @@ -0,0 +1,55 @@ +# Profile C (fresh) — FP16 plain, single denoising step, no cached_attn, no ControlNet, no IPAdapter +# t_index_list=[16] gives denoising_steps_num=1 → trt_unet_batch_size=1 → compatible with batch=1 static engine +# Engine: sdxl-turbo--tiny_vae-True--min_batch-1--max_batch-4--use_cached_attn-False--mode-img2img--trt10.16.1.11--cc89--res-512x512 +# IMPORTANT: build_engines_if_missing=false — fail if engine not present rather than rebuild + +model_id: "stabilityai/sdxl-turbo" + +t_index_list: [16] +width: 512 +height: 512 +device: "cuda" +dtype: 
"float16" + +guidance_scale: 1.0 +num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +prompt: "default fancy banana wall" +negative_prompt: "" + +mode: "img2img" +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: false +static_shapes: true +fp8: false +fp8_allow_fp16_fallback: false +builder_optimization_level: + +scheduler: "lcm" +sampler: "normal" + +use_cached_attn: false +cache_maxframes: 2 +cache_interval: 1 + +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +hf_cache: "" + +engine_dir: "D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/engines/td" + +use_controlnet: false +use_ipadapter: false diff --git a/configs/profiling/profiling_fp8v3.yaml b/configs/profiling/profiling_fp8v3.yaml new file mode 100644 index 00000000..b44c897e --- /dev/null +++ b/configs/profiling/profiling_fp8v3.yaml @@ -0,0 +1,54 @@ +# Profile B — FP8v3, no cached_attn, no ControlNet, no IPAdapter +# Engine: sdxl-turbo--tiny_vae-True--min_batch-1--max_batch-4--use_cached_attn-False--fp8v3--mode-img2img--trt10.16.1.11--cc89--res-512x512 +# IMPORTANT: build_engines_if_missing=false — fail if engine not present rather than rebuild + +model_id: "stabilityai/sdxl-turbo" + +t_index_list: [16] +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +guidance_scale: 1.0 +num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +prompt: "default fancy banana wall" +negative_prompt: "" + +mode: "img2img" +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: false +static_shapes: true +fp8: true +fp8_allow_fp16_fallback: false +builder_optimization_level: + +scheduler: "lcm" +sampler: "normal" + +use_cached_attn: false +cache_maxframes: 2 +cache_interval: 1 + +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +hf_cache: "" + +engine_dir: "D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/engines/td" + +use_controlnet: false +use_ipadapter: false diff --git a/configs/profiling/profiling_quality_cn.yaml b/configs/profiling/profiling_quality_cn.yaml new file mode 100644 index 00000000..708e4e61 --- /dev/null +++ b/configs/profiling/profiling_quality_cn.yaml @@ -0,0 +1,61 @@ +# Profile — Quality FP16 + ControlNet (canny) only (cached_attn=false, no ipadapter) +# Matches engine: sdxl-turbo--tiny_vae-True--min_batch-1--max_batch-4--use_cached_attn-False--controlnet--optlvl3--mode-img2img--trt10.16.1.11--cc89--res-512x512 +# B2-2 isolation test: does ControlNet alone reproduce the per-frame bool-op overhead? 
+# IMPORTANT: build_engines_if_missing=false — fail if engine not present rather than rebuild + +model_id: "stabilityai/sdxl-turbo" + +t_index_list: [16, 35] +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +guidance_scale: 1.0 +num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +prompt: "default fancy banana wall" +negative_prompt: "" + +mode: "img2img" +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: false +static_shapes: true +fp8: false +fp8_allow_fp16_fallback: false +builder_optimization_level: 3 + +scheduler: "lcm" +sampler: "normal" + +use_cached_attn: false +cache_maxframes: 2 +cache_interval: 1 + +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +hf_cache: "" + +engine_dir: "D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/engines/td" + +use_controlnet: true +controlnets: + - model_id: "xinsir/controlnet-canny-sdxl-1.0" + conditioning_scale: 0.3 + preprocessor: "canny" + enabled: true + +use_ipadapter: false diff --git a/configs/profiling/profiling_quality_ipa.yaml b/configs/profiling/profiling_quality_ipa.yaml new file mode 100644 index 00000000..e1569ebc --- /dev/null +++ b/configs/profiling/profiling_quality_ipa.yaml @@ -0,0 +1,62 @@ +# Profile — Quality FP16 + IPAdapter only (cached_attn=false, no controlnet) +# Matches engine: sdxl-turbo--tiny_vae-True--min_batch-1--max_batch-4--tokens4--use_cached_attn-False--optlvl3--mode-img2img--trt10.16.1.11--cc89--res-512x512 +# B2-2 isolation test: does IPAdapter alone reproduce the per-frame bool-op overhead? +# IMPORTANT: build_engines_if_missing=false — fail if engine not present rather than rebuild + +model_id: "stabilityai/sdxl-turbo" + +t_index_list: [10, 25] +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +guidance_scale: 1.0 +num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +prompt: "default fancy banana wall" +negative_prompt: "" + +mode: "img2img" +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: false +static_shapes: true +fp8: false +fp8_allow_fp16_fallback: false +builder_optimization_level: 3 + +scheduler: "lcm" +sampler: "normal" + +use_cached_attn: false +cache_maxframes: 2 +cache_interval: 1 + +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +hf_cache: "" + +engine_dir: "D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/engines/td" + +use_controlnet: false + +use_ipadapter: true +ipadapters: + - ipadapter_model_path: "h94/IP-Adapter/sdxl_models/ip-adapter_sdxl.bin" + image_encoder_path: "h94/IP-Adapter/sdxl_models/image_encoder" + scale: 0.25 + enabled: true + type: regular diff --git a/configs/td_config.yaml.example b/configs/td_config.yaml.example new file mode 100644 index 00000000..48a620a2 --- /dev/null +++ b/configs/td_config.yaml.example @@ -0,0 +1,99 @@ +# StreamDiffusionTD — Reference Configuration Template +# +# This is a TRACKED reference template. The active runtime config lives at +# StreamDiffusionTD/td_config.yaml (gitignored — contains absolute paths and +# local overrides). 
Copy this file there and edit for your environment. + +model_id: "stabilityai/sdxl-turbo" + +# Core StreamDiffusion parameters +t_index_list: [16] +width: 512 +height: 512 +device: "cuda" +dtype: "float16" + +# Generation parameters (defaults, can be updated via OSC) +guidance_scale: 1.0 +num_inference_steps: 50 +seed: 3115577 +delta: 1.0 + +# Prompt configuration (supports both single and blending) +prompt: "default fancy banana wall" +negative_prompt: "" + +# Optimization settings +mode: "img2img" # Always use img2img engines (mode switching handled at runtime) +frame_buffer_size: 1 +use_denoising_batch: true +use_tiny_vae: true +acceleration: "tensorrt" +cfg_type: "self" +do_add_noise: false +warmup: 10 +use_safety_checker: false +skip_diffusion: false +compile_engines_only: false +build_engines_if_missing: true +static_shapes: true +fp8: false +fp8_allow_fp16_fallback: false + +# builder_optimization_level: maps to TensorRT IBuilderConfig.builder_optimization_level +# Verified against TensorRT 10.x Python API (range 0-5, TRT default 3). +# TrtProfile UI mapping aligned with NVIDIA reference pipelines +# (demoDiffusion: level 3 for FP16 static+dynamic; TensorRT-Model-Optimizer: level 4 for FP8/INT8 quantized): +# 0 = Flexible static_shapes=false + level 3 — FP16 dynamic, NVIDIA demoDiffusion default +# 2 = Fast Build static_shapes=true + level 2 — heuristic-sorted fastest tactics, +# ~30-40% faster build with minimal runtime loss (build-time tradeoff) +# 4 = Quality static_shapes=true + level 3 — FP16 static, NVIDIA demoDiffusion default +# (level 4 has no NVIDIA-validated benefit for unquantized FP16) +# Performance static_shapes=true + level 4 — FP8 quantized, NVIDIA TensorRT-Model-Optimizer default +# Levels 1 and 5 are valid TRT values but not exposed via TrtProfile UI: +# 1 = degraded (top-heuristic + low compile optimization on dynamic kernels) +# 5 = exhaustive (avoided — used by no NVIDIA reference pipeline; community segfault reports) +# Omit / null = auto-detect per GPU (Ada/Ampere/Blackwell → 4, pre-Ampere → 3) +builder_optimization_level: 3 + +# Scheduler and sampler (TCD/StreamV2V support) +scheduler: "lcm" +sampler: "normal" + +# StreamV2V Cached Attention (Cattenable enables, Cattmaxframes/Cattinterval tune) +use_cached_attn: true +cache_maxframes: 2 +cache_interval: 1 + +# Image filtering (similar frame skip) +enable_similar_image_filter: false +similar_image_filter_threshold: 0.99 +similar_image_filter_max_skip_frame: 1 + +# HuggingFace cache directory (for model downloads); leave empty to use default +hf_cache: "" + +# TensorRT engine directory — use a relative path or your local absolute path +engine_dir: "engines/td" + +# ControlNet configuration (disabled) +use_controlnet: false + +# IPAdapter configuration (disabled) +use_ipadapter: false + + + + +# TouchDesigner specific settings +td_settings: + # OSC communication + osc_receive_port: 8576 + osc_transmit_port: 8588 + + # Memory interface + input_mem_name: 'StreamDiffusionTD_512-512' + output_mem_name: 'StreamDiffusionTD_512-512_out' + + # Debug settings + debug_mode: false diff --git a/scripts/profiling/README.md b/scripts/profiling/README.md new file mode 100644 index 00000000..6baa32c3 --- /dev/null +++ b/scripts/profiling/README.md @@ -0,0 +1,119 @@ +# StreamDiffusion Nsight Profiling + +## Quick Start + +### nsys — GPU timeline (benchmark target, existing cached engine) + +Pass `--config` to load the exact same wrapper kwargs as `td_main.py`, guaranteeing a cache +hit — no engine rebuild. 
The config at `StreamDiffusionTD/td_config.yaml` is the "Quality / FP16" +preset (`stabilityai/sdxl-turbo`, 512×512, fp16, img2img). + +```bat +set NSYS="C:/Program Files/NVIDIA Corporation/Nsight Systems 2025.3.2/target-windows-x64/nsys.exe" +%NSYS% profile --trace=cuda,nvtx,cublas --cuda-memory-usage=true ^ + -o profiles/sdtd_quality_fp16 --force-overwrite true ^ + .venv/Scripts/python scripts/profiling/profile_nsys.py --target benchmark ^ + --config StreamDiffusionTD/td_config.yaml + +REM Open the report: +"C:/Program Files/NVIDIA Corporation/Nsight Systems 2025.3.2/host-windows-x64/nsys-ui.exe" profiles/sdtd_quality_fp16.nsys-rep + +%NSYS% stats --report nvtx_pushpop_trace profiles/sdtd_quality_fp16.nsys-rep > nvtx_trace.txt +%NSYS% stats --report cuda_kern_exec_trace profiles/sdtd_quality_fp16.nsys-rep > kernel_trace.txt +``` + +Without `--config` the script falls back to inline defaults (`KBlueLeaf/kohaku-v2.1`) which will +trigger an engine build if no matching cache exists. + +### nsys — GPU timeline (td_main production path) + +```bat +set NSYS="C:/Program Files/NVIDIA Corporation/Nsight Systems 2025.3.2/target-windows-x64/nsys.exe" + +REM Wrap td_main.py directly; SDTD_NSYS_CAPTURE=1 fires start/stop at precise frame boundaries: +set GPU_PROFILER=1 +set SDTD_NSYS_CAPTURE=1 +set SDTD_NSYS_WARMUP_FRAMES=20 +set SDTD_NSYS_CAPTURE_FRAMES=500 +%NSYS% profile --trace=cuda,nvtx,cublas --capture-range cudaProfilerApi ^ + -o profiles/sdtd_td_main --force-overwrite true ^ + .venv/Scripts/python StreamDiffusionTD/td_main.py + +REM Or let the launcher manage it (no nsys wrapping required for deferred stats): +.venv/Scripts/python scripts/profiling/profile_nsys.py --target td_main --warmup 20 --frames 500 +``` + +### ncu — per-kernel metrics + +```bat +REM Basic metrics (2-3× overhead): +.venv/Scripts/python scripts/profiling/profile_ncu.py --target benchmark --set basic + +REM Roofline analysis: +.venv/Scripts/python scripts/profiling/profile_ncu.py --target benchmark --set roofline --launch-count 100 + +REM See the exact command without running: +.venv/Scripts/python scripts/profiling/profile_ncu.py --target benchmark --dry-run +``` + +--- + +## What's Instrumented + +The following `profiler.region()` names appear in nsys NVTX rows and in the JSON stats file: + +| Region | File | Description | +|---|---|---| +| `frame` | `examples/benchmark/single.py`, `wrapper.py` | Full per-frame round-trip | +| `encode_image` | `pipeline.py` | VAE encode: RGB → latent | +| `predict_x0_batch` | `pipeline.py` | Full diffusion denoising block | +| `unet_step` | `pipeline.py` | UNet forward (inside predict_x0_batch) | +| `scheduler_step` | `pipeline.py` | LCM/TCD scheduler step | +| `trt_infer` | `acceleration/tensorrt/utilities.py` | TRT engine execute_async_v3 / graph launch | +| `decode_image` | `pipeline.py` | VAE decode: latent → RGB | +| `d2h_sync` | `wrapper.py` | Device→Host DMA event sync (np output path) | + +> **CUDA graph note:** `trt_infer` NVTX markers fire only at graph capture time (first 3 warmup +> frames), not on each replay. Set `GPU_PROFILER_NVTX=0` for events-only mode (graph-safe); +> CUDA-event timings in the JSON stats file are always accurate. + +> **CUPTI subscriber note:** When running the benchmark target under nsys, `torch.profiler` +> may print `CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED` — this is benign. nsys and +> torch.profiler both register CUPTI subscribers; CUDA-event timing in `*_stats.json` and +> the nsys GPU timeline are both unaffected. 
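+
+### Adding a custom region (sketch)
+
+`profiler` is the module-level singleton in `streamdiffusion.tools.gpu_profiler` (the
+same object `profile_nsys.py` imports for `report()` / `export_stats()`). The snippet
+below assumes `region(name)` works as a context manager around arbitrary code; that
+usage is inferred from the region table above, not a verified signature. `stream` and
+`image_tensor` stand in for the wrapper and preprocessed input from the benchmark script:
+
+```python
+import os
+
+os.environ.setdefault("GPU_PROFILER", "1")  # master switch, read at wrapper init
+
+from streamdiffusion.tools.gpu_profiler import profiler
+
+# Assumed context-manager usage: emits an NVTX range (unless GPU_PROFILER_NVTX=0)
+# and records CUDA-event timing under the region name "my_custom_span".
+with profiler.region("my_custom_span"):
+    output = stream(image=image_tensor)  # any GPU work to attribute to this span
+
+profiler.report()                                     # print per-region timing table
+profiler.export_stats("profiler_logs/my_stats.json")  # same schema as *_stats.json
+```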
+ +--- + +## Environment Variables + +| Variable | Effect | +|---|---| +| `GPU_PROFILER=1` | Activate profiler (master switch). Auto-read by `configure()` in wrapper `__init__`. | +| `GPU_PROFILER_NVTX=0` | Disable NVTX ranges (safe with CUDA graphs); CUDA-event timing stays on. | +| `GPU_PROFILER_EVENTS=0` | Disable CUDA-event timing (NVTX only). | +| `STREAMDIFFUSION_PROFILE_TRT=1` | Activate existing TRT IProfiler (per-layer times; disables CUDA graphs). | +| `SDTD_NSYS_CAPTURE=1` | Enable deferred-capture handshake in `td_manager._streaming_loop`. | +| `SDTD_NSYS_WARMUP_FRAMES` | Frames before `cudaProfilerStart` (default: 20). | +| `SDTD_NSYS_CAPTURE_FRAMES` | Frames to capture after warmup (default: 500). | +| `SDTD_PROFILE_JSON` | Path for `profiler.export_stats()` on capture stop. | +| `SDTD_MODEL_ID` | Override model for `profile_nsys.py --target benchmark`. | +| `SDTD_ACCELERATION` | Override acceleration (default: `tensorrt`). | +| `SDTD_WIDTH` / `SDTD_HEIGHT` | Override resolution for `profile_nsys.py --target benchmark`. | +| `NSYS` | Override nsys.exe path (auto-discovered if unset). | +| `NCU` | Override ncu.exe path (auto-discovered if unset). | + +--- + +## Output Convention + +| Path | Content | +|---|---| +| `profiles/*.nsys-rep` | Nsight Systems reports — open in `nsys-ui.exe` | +| `profiler_logs/*_trace.json` | Chrome trace from `torch.profiler` — open in Perfetto / `chrome://tracing` | +| `profiler_logs/*_stats.json` | CUDA-event timing stats (mean/p50/p95/p99/min/max/total ms per region) | +| `profiler_logs/*_report.md` | Markdown timing table auto-rendered from stats JSON | +| `logs/ncu_*.ncu-rep` | Nsight Compute reports — open in Nsight Compute UI | +| `logs/ncu_*.csv` | Kernel details CSV (with `--csv` flag) | + +All output directories are gitignored. Use the `.nsys-rep` / `.ncu-rep` files directly with the +Nsight UI, or share the `*_stats.json` for lightweight timing comparison. diff --git a/scripts/profiling/profile_ncu.py b/scripts/profiling/profile_ncu.py new file mode 100644 index 00000000..a21be9f4 --- /dev/null +++ b/scripts/profiling/profile_ncu.py @@ -0,0 +1,256 @@ +""" +StreamDiffusion Nsight Compute (ncu) profiling launcher. + +Captures per-kernel performance metrics from StreamDiffusion's TRT inference path. + +── Quick start ───────────────────────────────────────────────────────────────── +Basic metrics (2-3× overhead), first 50 kernels after 10 skipped: + .venv/Scripts/python scripts/profiling/profile_ncu.py --target benchmark --set basic + +Full metrics (20-50× overhead): + .venv/Scripts/python scripts/profiling/profile_ncu.py --target benchmark --set full + +Roofline analysis (3-5× overhead): + .venv/Scripts/python scripts/profiling/profile_ncu.py --target benchmark --set roofline + +── Output ────────────────────────────────────────────────────────────────────── + logs/ncu___.ncu-rep — open in Nsight Compute UI + logs/ncu___.csv — (with --csv) kernel summary table + +── Overhead factors ──────────────────────────────────────────────────────────── + basic: 2-3× (arithmetic throughput, memory bandwidth, occupancy) + full: 20-50× (all hardware counters, multi-pass) + roofline: 3-5× (achievable roofline — adds SM throughput counters) + memoryworkload: 5-10× (L1/L2/DRAM access patterns) + source: 10-30× (source-level annotation, needs -lineinfo in compilation) + +── Notes ──────────────────────────────────────────────────────────────────────── + - ncu attaches to the target process; TRT CUDA graphs must be disabled. 
+ Set STREAMDIFFUSION_PROFILE_TRT=1 — this disables CUDA graphs automatically + (TRTProfiler IProfiler hooks are incompatible with graph replay). + - --launch-skip N: skip the first N kernel launches (skip warmup/compilation). + - --launch-count N: capture the next N launches after the skip. + - Use --dry-run to see the exact ncu command without executing. +""" + +import argparse +import os +import shutil +import subprocess +import sys +import time + + +# ── CLI args ─────────────────────────────────────────────────────────────────── +parser = argparse.ArgumentParser(description="StreamDiffusion Nsight Compute launcher") +parser.add_argument( + "--target", + default="benchmark", + choices=["benchmark", "infer"], + help="Target script: benchmark (single.py, 1 iter) or infer (minimal 1-frame driver)", +) +parser.add_argument( + "--set", + dest="metric_set", + default="basic", + choices=["basic", "full", "roofline", "memoryworkload", "source"], + help="ncu metric preset (default: basic)", +) +parser.add_argument( + "--kernel-regex", + default="", + help="Filter captured kernels by name regex (empty = all kernels)", +) +parser.add_argument( + "--launch-skip", + type=int, + default=0, + help="Skip first N kernel launches (default: 0)", +) +parser.add_argument( + "--launch-count", + type=int, + default=50, + help="Capture N kernel launches after skip (default: 50)", +) +parser.add_argument( + "--csv", + action="store_true", + help="After capture, export details as CSV to logs/", +) +parser.add_argument( + "--dry-run", + action="store_true", + help="Print the ncu command without executing", +) +args = parser.parse_args() + +# ── Paths ────────────────────────────────────────────────────────────────────── +_HERE = os.path.dirname(os.path.abspath(__file__)) +_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..", "..")) +_LOGS_DIR = os.path.join(_PROJECT_ROOT, "logs") +_TIMESTAMP = time.strftime("%Y%m%d_%H%M%S") +_PYTHON = sys.executable + +os.makedirs(_LOGS_DIR, exist_ok=True) + +# ── Locate ncu ───────────────────────────────────────────────────────────────── +# NOTE: We resolve the raw .exe path to bypass the ncu.bat shim, which +# interprets '|' in --kernel-name as a pipe and corrupts the regex. +_NCU_CANDIDATES = [ + os.environ.get("NCU", ""), + r"C:\Program Files\NVIDIA Corporation\Nsight Compute 2025.1.1\target\windows-desktop-win7-x64\ncu.exe", + r"C:\Program Files\NVIDIA Corporation\Nsight Compute 2024.3.2\target\windows-desktop-win7-x64\ncu.exe", + r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\ncu.exe", + "ncu", +] +_NCU = next((p for p in _NCU_CANDIDATES if p and shutil.which(p)), None) +if _NCU is None: + sys.exit("[profile_ncu] ERROR: ncu not found. 
Set NCU= or install Nsight Compute.") +print(f"[profile_ncu] ncu: {_NCU}") + +# ── Target command ───────────────────────────────────────────────────────────── +_TARGETS = { + "benchmark": [ + _PYTHON, + os.path.join(_PROJECT_ROOT, "examples", "benchmark", "single.py"), + "--iterations", + "1", + "--warmup", + "0", + "--acceleration", + "tensorrt", + ], + "infer": [ + _PYTHON, + os.path.join(_PROJECT_ROOT, "examples", "benchmark", "single.py"), + "--iterations", + "1", + "--warmup", + "0", + "--acceleration", + "tensorrt", + ], +} + +target_cmd = _TARGETS[args.target] + +# ── Output paths ─────────────────────────────────────────────────────────────── +rep_name = f"ncu_{args.target}_{args.metric_set}_{_TIMESTAMP}" +rep_path = os.path.join(_LOGS_DIR, rep_name + ".ncu-rep") +csv_path = os.path.join(_LOGS_DIR, rep_name + ".csv") + +# ── Build ncu command ────────────────────────────────────────────────────────── +ncu_cmd = [ + str(_NCU), + "--target-processes", + "all", + "--set", + args.metric_set, + "--kernel-name-base", + "demangled", + "--launch-skip", + str(args.launch_skip), + "--launch-count", + str(args.launch_count), + "--import-source", + "yes", + "--source-folders", + os.path.join(_PROJECT_ROOT, "src", "streamdiffusion"), + "--export", + rep_path, +] +if args.kernel_regex: + ncu_cmd += ["--kernel-name", args.kernel_regex] + +ncu_cmd += target_cmd + +# ── Environment ──────────────────────────────────────────────────────────────── +proc_env = dict(os.environ) +proc_env["CUDA_LAUNCH_BLOCKING"] = "1" # required for accurate per-kernel profiling +proc_env["GPU_PROFILER"] = "1" + +print(f"\n[profile_ncu] target: {args.target}") +print(f"[profile_ncu] metric set: {args.metric_set}") +print(f"[profile_ncu] launch skip: {args.launch_skip} count: {args.launch_count}") +print(f"[profile_ncu] output: {rep_path}") +print(f"\n[profile_ncu] Command:\n {' '.join(ncu_cmd)}\n") + +if args.dry_run: + print("[profile_ncu] --dry-run: exiting without executing.") + sys.exit(0) + +# ── Run ncu (stream output + early-abort on known fatal errors) ─────────────── +# ncu often keeps the target process running even after it can't collect metrics +# (e.g. ERR_NVGPUCTRPERM), so a 30-second TRT workload would otherwise complete +# only to produce no .ncu-rep. Stream stderr/stdout, kill the proc on first match. +_FATAL_NCU_PATTERNS = [ + ( + "ERR_NVGPUCTRPERM", + "GPU performance counter access denied. Either run this terminal as " + "Administrator OR open NVIDIA Control Panel -> Desktop menu -> Enable " + "Developer Settings -> Developer -> Manage GPU Performance Counters -> " + "'Allow access to the GPU performance counters to all users'. " + "See https://developer.nvidia.com/ERR_NVGPUCTRPERM", + ), +] + +proc = subprocess.Popen( + ncu_cmd, + env=proc_env, + cwd=_PROJECT_ROOT, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, +) +abort_msg = None +assert proc.stdout is not None +for line in iter(proc.stdout.readline, ""): + sys.stdout.write(line) + sys.stdout.flush() + for pat, msg in _FATAL_NCU_PATTERNS: + if pat in line: + abort_msg = msg + break + if abort_msg: + break + +if abort_msg: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() +else: + proc.wait() + +if abort_msg: + print(f"\n[profile_ncu] ABORTED: {abort_msg}", file=sys.stderr) + sys.exit(2) + +print(f"\n[profile_ncu] ncu exited (code {proc.returncode})") + +# Post-run sanity check: ncu can exit 0 even when no report is produced +# (e.g. 
zero matched kernels under --kernel-name). Surface that explicitly. +if not os.path.exists(rep_path): + print( + f"[profile_ncu] WARNING: expected report at {rep_path} but file is missing.\n" + " Likely causes: --kernel-name regex matched zero kernels, or ncu " + "encountered an error not in _FATAL_NCU_PATTERNS. Inspect the output above " + "for ==ERROR== / ==WARNING== lines.", + file=sys.stderr, + ) + sys.exit(2) + +print(f"[profile_ncu] Report -> {rep_path}") +print(f"[profile_ncu] Open with: Nsight Compute UI -> File -> Open -> {rep_name}.ncu-rep") + +# ── Optional CSV export ──────────────────────────────────────────────────────── +if args.csv and os.path.exists(rep_path): + csv_cmd = [str(_NCU), "--csv", "--page", "details", "--import", rep_path] + print(f"\n[profile_ncu] Exporting CSV -> {csv_path}") + with open(csv_path, "w") as csv_fh: + subprocess.run(csv_cmd, stdout=csv_fh, env=proc_env) + print(f"[profile_ncu] CSV -> {csv_path}") diff --git a/scripts/profiling/profile_nsys.py b/scripts/profiling/profile_nsys.py new file mode 100644 index 00000000..c73107a0 --- /dev/null +++ b/scripts/profiling/profile_nsys.py @@ -0,0 +1,324 @@ +""" +StreamDiffusion Nsight Systems profiling launcher. + +Supports two targets: + --target benchmark In-process wrapper loop (clean, deterministic, no TouchDesigner needed) + --target td_main Subprocess td_main.py with deferred-capture env vars (real production path) + +── benchmark target ─────────────────────────────────────────────────────────── +Run standalone (torch.profiler only, no nsys): + .venv/Scripts/python scripts/profiling/profile_nsys.py --target benchmark + +Run under nsys for GPU timeline: + set NSYS="C:/Program Files/NVIDIA Corporation/Nsight Systems 2025.3.2/target-windows-x64/nsys.exe" + %NSYS% profile --trace=cuda,nvtx,cublas --cuda-memory-usage=true ^ + -o profiles/sdtd_benchmark --force-overwrite true ^ + .venv/Scripts/python scripts/profiling/profile_nsys.py --target benchmark + +── td_main target ───────────────────────────────────────────────────────────── +Launches StreamDiffusionTD/td_main.py with deferred-capture env vars so +cudaProfilerStart fires after --warmup frames (default 20) and stop + exit +fires after warmup + --frames (default 500). The launcher script itself does +NOT need to run under nsys; td_main.py runs nsys-attached: + + set NSYS="C:/Program Files/NVIDIA Corporation/Nsight Systems 2025.3.2/target-windows-x64/nsys.exe" + %NSYS% profile --trace=cuda,nvtx,cublas --cuda-memory-usage=true ^ + --capture-range cudaProfilerApi ^ + -o profiles/sdtd_td_main --force-overwrite true ^ + .venv/Scripts/python StreamDiffusionTD/td_main.py + + # Or let this script manage the subprocess (SDTD_NSYS_CAPTURE=1 deferred handshake): + .venv/Scripts/python scripts/profiling/profile_nsys.py --target td_main --warmup 20 --frames 500 + +── Post-capture analysis ────────────────────────────────────────────────────── + %NSYS% stats --report cuda_kern_exec_trace profiles/sdtd_benchmark.nsys-rep > kernel_trace.txt + %NSYS% stats --report nvtx_pushpop_trace profiles/sdtd_benchmark.nsys-rep > nvtx_trace.txt + "C:/Program Files/NVIDIA Corporation/Nsight Systems 2025.3.2/host-windows-x64/nsys-ui.exe" profiles/sdtd_benchmark.nsys-rep + +── CUDA graph + NVTX note ───────────────────────────────────────────────────── + NVTX push/pop calls embedded in a CUDA graph fire only at capture time + (3-warmup passes), not on each replay step. For events-only mode (graph-safe) + set GPU_PROFILER_NVTX=0. 
The CUDA-event timings in profiler_logs/*.json are + always accurate because they use CUDA events, not NVTX. +""" + +import argparse +import json +import os +import shutil +import subprocess +import sys +import time + + +# ── CLI args ─────────────────────────────────────────────────────────────────── +parser = argparse.ArgumentParser(description="StreamDiffusion Nsight Systems profiling launcher") +parser.add_argument( + "--target", + default="benchmark", + choices=["benchmark", "td_main"], + help="Profiling target: 'benchmark' (in-process) or 'td_main' (subprocess)", +) +parser.add_argument( + "--frames", + type=int, + default=500, + help="[td_main] Frames to capture after warmup (default: 500)", +) +parser.add_argument( + "--warmup", + type=int, + default=20, + help="[td_main] Frames to skip before cudaProfilerStart (default: 20)", +) +parser.add_argument( + "--dry-run", + action="store_true", + help="Print subprocess commands without executing (td_main target only)", +) +parser.add_argument( + "--config", + default="", + metavar="PATH", + help="[benchmark] YAML/JSON config file (e.g. StreamDiffusionTD/td_config.yaml). " + "Loads wrapper kwargs from file so an existing cached engine is reused. " + "When empty, uses inline defaults (KBlueLeaf/kohaku-v2.1).", +) +args = parser.parse_args() + +# ── Common paths ─────────────────────────────────────────────────────────────── +_HERE = os.path.dirname(os.path.abspath(__file__)) +_PROJECT_ROOT = os.path.normpath(os.path.join(_HERE, "..", "..")) +_PROFILES_DIR = os.path.join(_PROJECT_ROOT, "profiles") +_PROFILER_LOGS_DIR = os.path.join(_PROJECT_ROOT, "profiler_logs") +_TIMESTAMP = time.strftime("%Y%m%d_%H%M%S") +_PYTHON = sys.executable + +os.makedirs(_PROFILES_DIR, exist_ok=True) +os.makedirs(_PROFILER_LOGS_DIR, exist_ok=True) + +# ── Locate nsys ──────────────────────────────────────────────────────────────── +_NSYS_CANDIDATES = [ + os.environ.get("NSYS", ""), + "nsys", + r"C:\Program Files\NVIDIA Corporation\Nsight Systems 2025.3.2\target-windows-x64\nsys.exe", + r"C:\Program Files\NVIDIA Corporation\Nsight Systems 2025.1.3\target-windows-x64\nsys.exe", + r"C:\Program Files\NVIDIA Corporation\Nsight Systems 2024.6.2\target-windows-x64\nsys.exe", +] +_NSYS = next((p for p in _NSYS_CANDIDATES if p and shutil.which(p)), None) +if _NSYS: + print(f"[profile] nsys: {_NSYS}") +else: + print("[profile] WARNING: nsys not found — torch.profiler only (no GPU kernel timeline).") + print(" Set NSYS= or add nsys to PATH to enable GPU tracing.") + +# ═══════════════════════════════════════════════════════════════════════════════ +# td_main target: deferred-capture subprocess +# ═══════════════════════════════════════════════════════════════════════════════ +if args.target == "td_main": + profile_base = os.path.join(_PROFILES_DIR, f"sdtd_td_main_{_TIMESTAMP}") + json_path = os.path.join(_PROFILER_LOGS_DIR, f"sdtd_td_main_{_TIMESTAMP}_stats.json") + md_path = os.path.join(_PROFILER_LOGS_DIR, f"sdtd_td_main_{_TIMESTAMP}_report.md") + + td_main_script = os.path.join(_PROJECT_ROOT, "StreamDiffusionTD", "td_main.py") + if not os.path.exists(td_main_script): + sys.exit(f"[profile] ERROR: td_main.py not found at {td_main_script}") + + # Build the subprocess environment — deferred-capture handshake + profiler activation + proc_env = dict(os.environ) + proc_env["GPU_PROFILER"] = "1" + proc_env.setdefault("GPU_PROFILER_NVTX", "1") + proc_env["GPU_PROFILER_EVENTS"] = "1" + proc_env["SDTD_NSYS_CAPTURE"] = "1" + proc_env["SDTD_NSYS_WARMUP_FRAMES"] = str(args.warmup) + 
proc_env["SDTD_NSYS_CAPTURE_FRAMES"] = str(args.frames) + proc_env["SDTD_PROFILE_JSON"] = json_path + + td_cmd = [_PYTHON, td_main_script] + + print("\n[profile] td_main deferred-capture run") + print(f" warmup: {args.warmup} frames capture: {args.frames} frames") + print(f" profile out: {profile_base}.nsys-rep (wrap with nsys manually)") + print(f" stats json: {json_path}") + print(f"\n Command: {' '.join(td_cmd)}") + print("\n Tip: wrap td_main.py directly with nsys for GPU kernel capture:") + print(" nsys profile --trace=cuda,nvtx,cublas --capture-range cudaProfilerApi \\") + print(f" -o {profile_base} .venv/Scripts/python StreamDiffusionTD/td_main.py") + + if args.dry_run: + print("\n[profile] --dry-run: exiting without launching.") + sys.exit(0) + + print("\n[profile] Launching td_main.py ...") + proc = subprocess.Popen(td_cmd, env=proc_env, cwd=_PROJECT_ROOT) + print(f"[profile] Waiting for {args.frames} capture frames + {args.warmup} warmup frames ...") + proc.wait() + print(f"[profile] td_main.py exited (code {proc.returncode})") + + # ── Render Markdown report ───────────────────────────────────────────────── + if os.path.exists(json_path): + with open(json_path) as fh: + data = json.load(fh) + regions = sorted(data.get("regions", []), key=lambda r: r["total_ms"], reverse=True) + if regions: + with open(md_path, "w") as rpt: + rpt.write(f"# TD Main Profile — {_TIMESTAMP}\n\n") + rpt.write(f"**Target**: `td_main` **Warmup**: {args.warmup} **Frames**: {args.frames}\n\n") + rpt.write("## Per-region timing (sorted by total ms)\n\n") + rpt.write("| Region | Count | Mean ms | P50 ms | P95 ms | P99 ms | Min ms | Max ms | Total ms |\n") + rpt.write("|---|---:|---:|---:|---:|---:|---:|---:|---:|\n") + for r in regions: + rpt.write( + f"| `{r['name']}` | {r['count']} " + f"| {r['mean_ms']:.3f} | {r['p50_ms']:.3f} | {r['p95_ms']:.3f} " + f"| {r['p99_ms']:.3f} | {r['min_ms']:.3f} | {r['max_ms']:.3f} " + f"| {r['total_ms']:.1f} |\n" + ) + print(f"\n[profile] Report -> {md_path}") + print(f"\n{'Region':<30} {'Mean ms':>8} {'P50 ms':>8} {'P95 ms':>8} {'Total ms':>10}") + print("-" * 70) + for r in regions[:15]: + print( + f" {r['name']:<28} {r['mean_ms']:>8.3f} {r['p50_ms']:>8.3f} {r['p95_ms']:>8.3f} {r['total_ms']:>10.1f}" + ) + else: + print(f"[profile] WARNING: stats JSON not found at {json_path} — td_main may have crashed.") + + print("\n[profile] Complete.") + sys.exit(0) + +# ═══════════════════════════════════════════════════════════════════════════════ +# benchmark target: in-process wrapper loop +# ═══════════════════════════════════════════════════════════════════════════════ +import torch +from torch.profiler import ProfilerActivity, profile, schedule + +from streamdiffusion.tools.gpu_profiler import profiler + + +os.environ.setdefault("GPU_PROFILER", "1") # wrapper.__init__ reads this to activate + +WARMUP_RUNS = 3 # extra warmup before torch.profiler + nsys capture window +PROFILE_RUNS = 10 # inferences captured by nsys / torch.profiler + +sys.path.insert(0, _PROJECT_ROOT) + +# ── Load pipeline ────────────────────────────────────────────────────────────── +if args.config: + # Config-file path: identical wrapper kwargs to td_main.py → cache hit, no rebuild. 
+ from streamdiffusion import create_wrapper_from_config, load_config + + cfg_path = args.config if os.path.isabs(args.config) else os.path.join(_PROJECT_ROOT, args.config) + print(f"[profile] benchmark target — loading wrapper from {cfg_path}") + cfg = load_config(cfg_path) + _WIDTH = cfg.get("width", 512) + _HEIGHT = cfg.get("height", 512) + t0 = time.perf_counter() + stream = create_wrapper_from_config(cfg) # also calls .prepare() +else: + # Inline-defaults path: useful for experiments; may trigger an engine build. + from streamdiffusion import StreamDiffusionWrapper + + print("[profile] benchmark target — loading StreamDiffusionWrapper (inline defaults) ...") + print(" Tip: pass --config StreamDiffusionTD/td_config.yaml to hit an existing cached engine.") + _MODEL_ID = os.environ.get("SDTD_MODEL_ID", "KBlueLeaf/kohaku-v2.1") + _ACCELERATION = os.environ.get("SDTD_ACCELERATION", "tensorrt") + _WIDTH = int(os.environ.get("SDTD_WIDTH", "512")) + _HEIGHT = int(os.environ.get("SDTD_HEIGHT", "512")) + t0 = time.perf_counter() + stream = StreamDiffusionWrapper( + model_id_or_path=_MODEL_ID, + t_index_list=[32, 45], + mode="img2img", + frame_buffer_size=1, + width=_WIDTH, + height=_HEIGHT, + warmup=WARMUP_RUNS, + acceleration=_ACCELERATION, + use_lcm_lora=True, + use_tiny_vae=True, + use_denoising_batch=True, + cfg_type="initialize", + seed=42, + ) + stream.prepare( + prompt="abstract flowing colorful pattern", + negative_prompt="bad quality", + num_inference_steps=50, + guidance_scale=1.4, + delta=0.5, + ) +print(f"[profile] Pipeline ready in {time.perf_counter() - t0:.1f}s\n") + +# ── Dummy input image ────────────────────────────────────────────────────────── +import PIL.Image + + +dummy_img = PIL.Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)) + +# ── Preprocess once ──────────────────────────────────────────────────────────── +image_tensor = stream.preprocess_image(dummy_img) + + +def run_inference(label: str = ""): + """One inference step with NVTX frame label.""" + torch.cuda.nvtx.range_push(f"frame{label}") + result = stream(image=image_tensor) + torch.cuda.nvtx.range_pop() + return result + + +# ── Extra warmup (not captured) ──────────────────────────────────────────────── +print(f"[profile] Extra warmup ({WARMUP_RUNS} runs)...") +for i in range(WARMUP_RUNS): + run_inference(f"_warmup{i}") +torch.cuda.synchronize() +print("[profile] Warmup done.\n") + +# ── torch.profiler capture ───────────────────────────────────────────────────── +TRACE_PATH = os.path.join(_PROFILER_LOGS_DIR, f"sdtd_benchmark_{_TIMESTAMP}_trace.json") +TOTAL_STEPS = 1 + PROFILE_RUNS + +print(f"[profile] torch.profiler capture ({PROFILE_RUNS} active steps)...") +# CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED is benign when running under nsys; +# both register CUPTI subscribers but CUDA-event timings in *_stats.json are unaffected. 
+with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=1, active=PROFILE_RUNS, repeat=1), + record_shapes=True, + with_stack=True, +) as prof: + for step in range(TOTAL_STEPS): + run_inference(f"_prof{step}") + prof.step() + +torch.cuda.synchronize() +print("\n" + "=" * 80) +print("TORCH.PROFILER [benchmark] — Top 30 ops by CUDA time") +print("=" * 80) +print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) +prof.export_chrome_trace(TRACE_PATH) +print(f"\n[profile] Chrome trace -> {TRACE_PATH}") + +# ── nsys-gated capture window (cudaProfilerStart / Stop) ────────────────────── +print(f"\n[profile] nsys-gated capture ({PROFILE_RUNS} inferences)...") +torch.cuda.cudart().cudaProfilerStart() + +for i in range(PROFILE_RUNS): + t0 = time.perf_counter() + run_inference(f"_nsys{i}") + torch.cuda.synchronize() + ms = (time.perf_counter() - t0) * 1000 + print(f" Step {i}: {ms:.1f} ms ({1000 / ms:.2f} FPS)") + +torch.cuda.cudart().cudaProfilerStop() + +# ── CUDA-event stats report ──────────────────────────────────────────────────── +profiler.report() +_STATS_PATH = os.path.join(_PROFILER_LOGS_DIR, f"sdtd_benchmark_{_TIMESTAMP}_stats.json") +profiler.export_stats(_STATS_PATH) + +print("\n[profile] Complete.") +print("If running under nsys, analyze with:") +print(" nsys stats --report cuda_kern_exec_trace profiles/sdtd_benchmark.nsys-rep > kernel_trace.txt") +print(" nsys stats --report nvtx_pushpop_trace profiles/sdtd_benchmark.nsys-rep > nvtx_trace.txt") diff --git a/setup.py b/setup.py index 78d4aaff..da303155 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ from setuptools import find_packages, setup + # Copied from pip_utils.py to avoid import def _check_torch_installed(): try: @@ -18,16 +19,18 @@ def _check_torch_installed(): raise RuntimeError(msg) if not torch.version.cuda: - raise RuntimeError("Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package.") + raise RuntimeError( + "Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package." 
+ ) def get_cuda_constraint(): - cuda_version = os.environ.get("STREAMDIFFUSION_CUDA_VERSION") or \ - os.environ.get("CUDA_VERSION") + cuda_version = os.environ.get("STREAMDIFFUSION_CUDA_VERSION") or os.environ.get("CUDA_VERSION") if not cuda_version: try: import torch + cuda_version = torch.version.cuda except Exception: # might not be available during wheel build, so we have to ignore @@ -55,13 +58,16 @@ def get_cuda_constraint(): "Pillow>=12.2.0", # CVE-2026-25990: out-of-bounds write in PSD loading; 12.2.0 verified "fire==0.7.1", "omegaconf==2.3.0", - "onnx==1.18.0", # IR 11 — modelopt needs FLOAT4E2M1 (added in 1.18); float32_to_bfloat16 present (removed in 1.19+) + "onnx==1.19.1", # IR 11; modelopt needs FLOAT4E2M1 (added in 1.18); onnx-gs 0.6.1 no longer needs float32_to_bfloat16 "onnxruntime-gpu==1.24.4", # TRT EP, supports IR 11; never co-install CPU onnxruntime — shared files conflict + "onnxoptimizer==0.4.2", + "onnxslim==0.1.91", + "onnxscript==0.6.2", "polygraphy==0.49.26", "protobuf>=4.25.8,<5", # mediapipe 0.10.21 requires protobuf 4.x; 4.25.8 fixes CVE-2025-4565; CVE-2026-0994 (JSON DoS) accepted risk for local pipeline "colored==2.3.2", "pywin32==311;sys_platform == 'win32'", - "onnx-graphsurgeon==0.5.8", + "onnx-graphsurgeon==0.6.1", "controlnet-aux==0.0.10", "diffusers-ipadapter @ git+https://github.com/livepeer/Diffusers_IPAdapter.git@405f87da42932e30bd55ee8dca3ce502d7834a99", "mediapipe==0.10.21", @@ -80,7 +86,18 @@ def deps_list(*pkgs): extras = {} extras["xformers"] = deps_list("xformers") extras["torch"] = deps_list("torch", "accelerate") -extras["tensorrt"] = deps_list("protobuf", "cuda-python", "onnx", "onnxruntime-gpu", "colored", "polygraphy", "onnx-graphsurgeon") +extras["tensorrt"] = deps_list( + "protobuf", + "cuda-python", + "onnx", + "onnxruntime-gpu", + "onnxoptimizer", + "onnxslim", + "onnxscript", + "colored", + "polygraphy", + "onnx-graphsurgeon", +) extras["controlnet"] = deps_list("onnx-graphsurgeon", "controlnet-aux") extras["ipadapter"] = deps_list("diffusers-ipadapter", "mediapipe", "insightface") diff --git a/src/streamdiffusion/__init__.py b/src/streamdiffusion/__init__.py index 8ff48ea6..88bfdb89 100644 --- a/src/streamdiffusion/__init__.py +++ b/src/streamdiffusion/__init__.py @@ -1,13 +1,14 @@ +from .config import create_wrapper_from_config, load_config, save_config from .pipeline import StreamDiffusion -from .wrapper import StreamDiffusionWrapper -from .config import load_config, save_config, create_wrapper_from_config from .preprocessing.processors import list_preprocessors +from .wrapper import StreamDiffusionWrapper + __all__ = [ - "StreamDiffusion", - "StreamDiffusionWrapper", - "load_config", - "list_preprocessors", - "save_config", - "create_wrapper_from_config", - ] \ No newline at end of file + "StreamDiffusion", + "StreamDiffusionWrapper", + "load_config", + "list_preprocessors", + "save_config", + "create_wrapper_from_config", +] diff --git a/src/streamdiffusion/_hf_tracing_patches.py b/src/streamdiffusion/_hf_tracing_patches.py index d2d611f3..422b6bf6 100644 --- a/src/streamdiffusion/_hf_tracing_patches.py +++ b/src/streamdiffusion/_hf_tracing_patches.py @@ -5,7 +5,10 @@ import torch + _ALREADY = False # idempotence guard + + # --------------------------------------------------------------------------- # # 1. 
UNet2DConditionModel: guard in_channels % up_factor # --------------------------------------------------------------------------- # @@ -16,11 +19,11 @@ def _patch_unet(): def patched(self, sample, *args, **kwargs): if torch.jit.is_tracing(): - dim = torch.as_tensor(getattr(self.config, "in_channels", self.in_channels)) + dim = torch.as_tensor(getattr(self.config, "in_channels", self.in_channels)) up_factor = torch.as_tensor(getattr(self.config, "default_overall_up_factor", 1)) torch._assert( torch.remainder(dim, up_factor) == 0, - f"in_channels={dim} not divisible by default_overall_up_factor={up_factor}" + f"in_channels={dim} not divisible by default_overall_up_factor={up_factor}", ) return orig_fwd(self, sample, *args, **kwargs) @@ -32,12 +35,13 @@ def patched(self, sample, *args, **kwargs): # --------------------------------------------------------------------------- # def _patch_downsample(): import diffusers.models.downsampling as d + orig_fwd = d.Downsample2D.forward def patched(self, hidden_states, *args, **kwargs): torch._assert( hidden_states.shape[1] == self.channels, - f"[Downsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}" + f"[Downsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}", ) return orig_fwd(self, hidden_states, *args, **kwargs) @@ -49,12 +53,13 @@ def patched(self, hidden_states, *args, **kwargs): # --------------------------------------------------------------------------- # def _patch_upsample(): import diffusers.models.upsampling as u + orig_fwd = u.Upsample2D.forward def patched(self, hidden_states, *args, **kwargs): torch._assert( hidden_states.shape[1] == self.channels, - f"[Upsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}" + f"[Upsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}", ) return orig_fwd(self, hidden_states, *args, **kwargs) diff --git a/src/streamdiffusion/acceleration/tensorrt/__init__.py b/src/streamdiffusion/acceleration/tensorrt/__init__.py index 52bbcb3c..20b8f5da 100644 --- a/src/streamdiffusion/acceleration/tensorrt/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/__init__.py @@ -1,22 +1,31 @@ +import os +import warnings + import torch import torch.nn as nn -from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel + + +os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY") +from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( retrieve_latents, ) from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + from .builder import EngineBuilder from .models.models import BaseModel + def cosine_distance(image_embeds, text_embeds): normalized_image_embeds = nn.functional.normalize(image_embeds) normalized_text_embeds = nn.functional.normalize(text_embeds) return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + class StableDiffusionSafetyCheckerWrapper(StableDiffusionSafetyChecker): def __init__(self, config): super().__init__(config) - + @torch.no_grad() def forward(self, clip_input): pooled_output = self.vision_model(clip_input)[1] @@ -37,6 +46,7 @@ def forward(self, clip_input): return has_nsfw_concepts + class TorchVAEEncoder(torch.nn.Module): def __init__(self, vae: AutoencoderKL): super().__init__() @@ -45,6 +55,7 @@ def __init__(self, vae: AutoencoderKL): def forward(self, x: torch.Tensor): return retrieve_latents(self.vae.encode(x)) + def compile_vae_encoder( vae: 
TorchVAEEncoder, model_data: BaseModel, @@ -84,6 +95,7 @@ def compile_vae_decoder( **engine_build_options, ) + def compile_safety_checker( safety_checker: StableDiffusionSafetyCheckerWrapper, model_data: BaseModel, @@ -117,7 +129,22 @@ def compile_unet( # These are not valid kwargs for build_engine() and must be handled here. build_options = dict(engine_build_options) fp8 = build_options.pop("fp8", False) - calibration_data_fn = build_options.pop("calibration_data_fn", None) + pipe_ref = build_options.pop("pipe_ref", None) + calibration_prompts = build_options.pop("calibration_prompts", None) + calibration_steps = build_options.pop("calibration_steps", 20) + fp8_allow_fp16_fallback = build_options.pop("fp8_allow_fp16_fallback", False) + fp8_use_cached_attn = build_options.pop("fp8_use_cached_attn", False) + fp8_use_controlnet = build_options.pop("fp8_use_controlnet", False) + fp8_num_ip_layers = build_options.pop("fp8_num_ip_layers", 0) + for _legacy in ("calibration_data_fn", "amax_save_path", "fp8_alpha"): + if _legacy in build_options: + warnings.warn( + f"engine_build_options['{_legacy}'] is deprecated and ignored — the FP8 path " + "switched to ONNX-level quantization. Remove this kwarg from your config.", + DeprecationWarning, + stacklevel=2, + ) + build_options.pop(_legacy) unet = unet.to(torch.device("cuda"), dtype=torch.float16) builder = EngineBuilder(model_data, unet, device=torch.device("cuda")) @@ -127,7 +154,13 @@ def compile_unet( engine_path, opt_batch_size=opt_batch_size, fp8=fp8, - calibration_data_fn=calibration_data_fn, + pipe_ref=pipe_ref, + calibration_prompts=calibration_prompts, + calibration_steps=calibration_steps, + fp8_allow_fp16_fallback=fp8_allow_fp16_fallback, + fp8_use_cached_attn=fp8_use_cached_attn, + fp8_use_controlnet=fp8_use_controlnet, + fp8_num_ip_layers=fp8_num_ip_layers, **build_options, ) @@ -149,4 +182,4 @@ def compile_controlnet( engine_path, opt_batch_size=opt_batch_size, **engine_build_options, - ) \ No newline at end of file + ) diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py index 32d10f87..dd3a97e0 100644 --- a/src/streamdiffusion/acceleration/tensorrt/builder.py +++ b/src/streamdiffusion/acceleration/tensorrt/builder.py @@ -20,6 +20,12 @@ _build_logger = logging.getLogger(__name__) +class StageStatus: + BUILT = "built" + CACHED = "cached" + FAILED = "failed" + + def _write_build_stats(engine_path: str, stats: dict): """Append build stats to a JSON-lines file next to the engine directory.""" try: @@ -37,6 +43,32 @@ def _write_build_stats(engine_path: str, stats: dict): _build_logger.warning(f"Failed to write build stats: {e}") +def _run_fp8_stage(name: str, fn, stats: dict, allow_fallback: bool, engine_filename: str) -> bool: + """Run an FP8 build stage with timing + fallback handling. Returns True on success.""" + t0 = time.perf_counter() + try: + fn() + elapsed = time.perf_counter() - t0 + stats["stages"][name] = {"status": StageStatus.BUILT, "elapsed_s": round(elapsed, 2)} + _build_logger.info(f"[BUILD] {name} ({engine_filename}): {elapsed:.1f}s") + return True + except Exception as err: + elapsed = time.perf_counter() - t0 + stats["stages"][name] = { + "status": StageStatus.FAILED, + "elapsed_s": round(elapsed, 2), + "error": str(err), + } + if allow_fallback: + _build_logger.warning(f"[BUILD] {name} failed after {elapsed:.1f}s: {err}. 
Falling back to FP16.") + return False + raise RuntimeError( + f"{name} failed: {err}.\n" + "Set fp8_allow_fp16_fallback=True in TRT_PROFILES to silently fall back to FP16, " + "or fix the error above." + ) from err + + def create_onnx_path(name, onnx_dir, opt=True): return os.path.join(onnx_dir, name + (".opt" if opt else "") + ".onnx") @@ -72,7 +104,15 @@ def build( force_onnx_export: bool = False, force_onnx_optimize: bool = False, fp8: bool = False, - calibration_data_fn=None, + pipe_ref=None, + calibration_prompts=None, + calibration_steps: int = 20, + fp8_guidance_scale: float = 7.5, + fp8_allow_fp16_fallback: bool = False, + fp8_use_cached_attn: bool = False, + fp8_use_controlnet: bool = False, + fp8_num_ip_layers: int = 0, + builder_optimization_level: Optional[int] = None, ): build_total_start = time.perf_counter() engine_name = Path(engine_path).parent.name @@ -88,6 +128,13 @@ def build( "stages": {}, } + # FP8 paths are resolved relative to the engine directory. + # calib_data.npz: cached UNet activations (survives engine rebuilds). + # unet.fp8.onnx: ONNX with native FLOAT8E4M3FN Q/DQ (also cached). + engine_dir_early = os.path.dirname(engine_path) + _calib_data_path = os.path.join(engine_dir_early, "calib_data.npz") + _fp8_onnx_path = os.path.join(engine_dir_early, "unet.fp8.onnx") + # --- ONNX Export --- if not force_onnx_export and os.path.exists(onnx_path): print(f"Found cached model: {onnx_path}") @@ -95,8 +142,7 @@ def build( else: print(f"Exporting model: {onnx_path}") t0 = time.perf_counter() - export_onnx( - self.network, + _export_kwargs = dict( onnx_path=onnx_path, model_data=self.model, opt_image_height=opt_image_height, @@ -104,9 +150,10 @@ def build( opt_batch_size=opt_batch_size, onnx_opset=onnx_opset, ) + export_onnx(self.network, **_export_kwargs) elapsed = time.perf_counter() - t0 stats["stages"]["onnx_export"] = {"status": "built", "elapsed_s": round(elapsed, 2)} - _build_logger.warning(f"[BUILD] ONNX export ({engine_filename}): {elapsed:.1f}s") + _build_logger.info(f"[BUILD] ONNX export ({engine_filename}): {elapsed:.1f}s") self.network = self.network.to("cpu") del self.network gc.collect() @@ -126,7 +173,7 @@ def build( ) elapsed = time.perf_counter() - t0 stats["stages"]["onnx_optimize"] = {"status": "built", "elapsed_s": round(elapsed, 2)} - _build_logger.warning(f"[BUILD] ONNX optimize ({engine_filename}): {elapsed:.1f}s") + _build_logger.info(f"[BUILD] ONNX optimize ({engine_filename}): {elapsed:.1f}s") self.model.min_latent_shape = min_image_resolution // 8 self.model.max_latent_shape = max_image_resolution // 8 @@ -148,49 +195,69 @@ def build( ) _build_logger.info(f"Verified ONNX opt file: {onnx_opt_path} ({opt_file_size / (1024**2):.1f} MB)") - # --- FP8 Quantization (if enabled) --- - # Inserts Q/DQ nodes into the optimized ONNX and replaces onnx_opt_path with - # the FP8-annotated ONNX for the TRT build step below. 
- onnx_trt_input = onnx_opt_path # default: use FP16 opt ONNX - fp8_trt = fp8 # may be set to False below if FP8 quantization fails - if fp8: - onnx_fp8_path = onnx_opt_path.replace(".opt.onnx", ".fp8.onnx") - if not os.path.exists(onnx_fp8_path): - _build_logger.warning("[BUILD] FP8 quantization starting...") - t0 = time.perf_counter() - from .fp8_quantize import quantize_onnx_fp8 + # --- FP8: Capture calibration tensors (once, cached in calib_data.npz) --- + if fp8 and pipe_ref is not None: + if os.path.exists(_calib_data_path): + _build_logger.info(f"[BUILD] FP8 calibration data cached: {_calib_data_path}") + stats["stages"]["fp8_calib_capture"] = {"status": StageStatus.CACHED} + else: - try: - quantize_onnx_fp8( - onnx_opt_path, - onnx_fp8_path, - model_data=self.model, - opt_batch_size=opt_batch_size, - opt_image_height=opt_image_height, - opt_image_width=opt_image_width, + def _calib_fn(): + from .fp8_quantize import _load_calibration_prompts, capture_calibration_data + + prompts = calibration_prompts or _load_calibration_prompts() + _build_logger.info( + f"[BUILD] FP8 activation capture: {len(prompts)} prompts × " + f"{calibration_steps} steps, guidance_scale={fp8_guidance_scale}" ) - elapsed = time.perf_counter() - t0 - stats["stages"]["fp8_quantize"] = {"status": "built", "elapsed_s": round(elapsed, 2)} - _build_logger.warning(f"[BUILD] FP8 quantization ({engine_filename}): {elapsed:.1f}s") - onnx_trt_input = onnx_fp8_path - except Exception as fp8_err: - elapsed = time.perf_counter() - t0 - _build_logger.warning( - f"[BUILD] FP8 quantization failed after {elapsed:.1f}s: {fp8_err}. " - f"Falling back to FP16 TensorRT engine (onnx_trt_input unchanged)." + capture_calibration_data( + pipe_ref, + prompts, + num_inference_steps=calibration_steps, + save_path=_calib_data_path, + guidance_scale=fp8_guidance_scale, + onnx_path=onnx_opt_path, + use_cached_attn=fp8_use_cached_attn, + use_controlnet=fp8_use_controlnet, + num_ip_layers=fp8_num_ip_layers, ) - stats["stages"]["fp8_quantize"] = { - "status": "failed_fallback_fp16", - "elapsed_s": round(elapsed, 2), - "error": str(fp8_err), - } - # onnx_trt_input remains onnx_opt_path (FP16 ONNX) - # Disable FP8 engine build path (avoids STRONGLY_TYPED flag) - fp8_trt = False + + if not _run_fp8_stage("fp8_calib_capture", _calib_fn, stats, fp8_allow_fp16_fallback, engine_filename): + fp8 = False + elif fp8 and pipe_ref is None: + _build_logger.warning( + "[BUILD] fp8=True but pipe_ref not provided — FP8 calibration skipped. " + "Pass pipe_ref in engine_build_options for proper activation capture." 
+            )
+            fp8 = False
+
+    # --- FP8: Inject native FLOAT8E4M3FN Q/DQ into the ONNX (cached in unet.fp8.onnx) ---
+    if fp8:
+        if os.path.exists(_fp8_onnx_path + ".ok"):
+            _build_logger.info(f"[BUILD] FP8 ONNX cached: {_fp8_onnx_path}")
+            stats["stages"]["fp8_onnx_quantize"] = {"status": StageStatus.CACHED}
         else:
-            _build_logger.info(f"[BUILD] Found cached FP8 ONNX: {onnx_fp8_path}")
-            stats["stages"]["fp8_quantize"] = {"status": "cached"}
-            onnx_trt_input = onnx_fp8_path
+
+            def _quant_fn():
+                from .fp8_quantize import load_calibration_data, quantize_onnx_fp8
+
+                calib_data = load_calibration_data(_calib_data_path)
+                if calib_data is None:
+                    raise RuntimeError(f"Calibration data missing after capture step: {_calib_data_path}")
+                quantize_onnx_fp8(
+                    onnx_path=onnx_opt_path,
+                    output_path=_fp8_onnx_path,
+                    calibration_data=calib_data,
+                    use_cached_attn=fp8_use_cached_attn,
+                    use_controlnet=fp8_use_controlnet,
+                    num_ip_layers=fp8_num_ip_layers,
+                )
+
+            if not _run_fp8_stage("fp8_onnx_quantize", _quant_fn, stats, fp8_allow_fp16_fallback, engine_filename):
+                fp8 = False
+
+    # Select the ONNX to feed into TRT: FP8-quantized when available, else plain opt.
+    _trt_onnx_path = _fp8_onnx_path if (fp8 and os.path.exists(_fp8_onnx_path + ".ok")) else onnx_opt_path

     # --- TRT Engine Build ---
     if not force_engine_build and os.path.exists(engine_path):
@@ -200,7 +267,7 @@
         t0 = time.perf_counter()
         build_engine(
             engine_path=engine_path,
-            onnx_opt_path=onnx_trt_input,
+            onnx_opt_path=_trt_onnx_path,
             model_data=self.model,
             opt_image_height=opt_image_height,
             opt_image_width=opt_image_width,
@@ -209,11 +276,32 @@
             build_dynamic_shape=build_dynamic_shape,
             build_all_tactics=build_all_tactics,
             build_enable_refit=build_enable_refit,
-            fp8=fp8_trt,
+            fp8=fp8,
+            builder_optimization_level=builder_optimization_level,
         )
         elapsed = time.perf_counter() - t0
         stats["stages"]["trt_build"] = {"status": "built", "elapsed_s": round(elapsed, 2)}
-        _build_logger.warning(f"[BUILD] TRT engine build ({engine_filename}): {elapsed:.1f}s")
+        _build_logger.info(f"[BUILD] TRT engine build ({engine_filename}): {elapsed:.1f}s")
+
+        # --- FP8 Q/DQ layer count (sanity gate: < 500 means quantization is inactive) ---
+        if fp8 and os.path.exists(engine_path):
+            try:
+                import tensorrt as trt
+
+                _rt = trt.Runtime(trt.Logger(trt.Logger.WARNING))
+                with open(engine_path, "rb") as _f:
+                    _eng = _rt.deserialize_cuda_engine(_f.read())
+                _insp = _eng.create_engine_inspector()
+                _info = _insp.get_engine_information(trt.LayerInformationFormat.JSON)
+                _qdq = _info.count("QuantizeLinear") + _info.count("DequantizeLinear")
+                stats["fp8_qdq_layers"] = _qdq
+                _build_logger.info(f"[BUILD] FP8 engine Q/DQ layer count: {_qdq}")
+                if _qdq < 500:
+                    _build_logger.warning(
+                        f"[BUILD] Low Q/DQ count ({_qdq} < 500) — FP8 quantization likely inactive or incomplete"
+                    )
+            except Exception as _e:
+                _build_logger.warning(f"[BUILD] FP8 inspector check skipped: {_e}")

     # Record totals (before cleanup so build_stats.json is preserved)
     total_elapsed = time.perf_counter() - build_total_start
@@ -224,16 +312,20 @@
     if os.path.exists(engine_path):
         stats["engine_size_mb"] = round(os.path.getsize(engine_path) / (1024 * 1024), 1)
-    _build_logger.warning(f"[BUILD] {engine_filename} complete: {total_elapsed:.1f}s total")
+    _build_logger.info(f"[BUILD] {engine_filename} complete: {total_elapsed:.1f}s total")
     _write_build_stats(engine_path, stats)
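# Illustrative sketch (not part of this diff): build_stats.json is written one
# JSON object per line by _write_build_stats(), with per-stage results under
# "stages" and, for FP8 builds, the Q/DQ count recorded above as
# "fp8_qdq_layers". A stdlib-only reader for comparing builds might look like
# this; the record keys used below are the ones visible in this diff, anything
# else would be an assumption.
import json

def summarize_build_stats(stats_path: str) -> None:
    # One JSON record per build, appended over time.
    with open(stats_path, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            print(f"engine_size_mb={rec.get('engine_size_mb', '?')} "
                  f"fp8_qdq_layers={rec.get('fp8_qdq_layers', 'n/a')}")
            for name, stage in rec.get("stages", {}).items():
                print(f"  {name}: {stage.get('status')} ({stage.get('elapsed_s', '-')}s)")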
-    # Cleanup ONNX artifacts — preserve .engine, .fp8.onnx, timing.cache, and build_stats.json
+    # Cleanup ONNX artifacts — preserve .engine, calib_data.npz, unet.fp8.onnx* files, timing.cache, build_stats.json
     # Two-pass deletion to handle Windows file locks (gc.collect releases Python handles)
-    _keep_suffixes = (".engine", ".fp8.onnx", ".cache")
-    _keep_exact = {"build_stats.json", "timing.cache"}
+    _keep_suffixes = (".engine", ".cache")
+    _keep_exact = {"build_stats.json", "timing.cache", "calib_data.npz"}
     engine_dir = os.path.dirname(engine_path)
     _to_delete = []
     for file in os.listdir(engine_dir):
+        # Keep all files that are part of the FP8 quantized ONNX artifact
+        # (unet.fp8.onnx and any external data companion like unet.fp8.onnx.data)
+        if "fp8.onnx" in file:
+            continue
         if file in _keep_exact or any(file.endswith(s) for s in _keep_suffixes):
             continue
         _to_delete.append(os.path.join(engine_dir, file))
@@ -246,29 +338,34 @@
             except OSError:
                 _failed.append(fpath)

-        # Release Python-held file handles (ONNX model refs), retry failures
+        # Release Python-held file handles (ONNX model refs), retry locked files.
+        # Per-file poll with 50ms backoff instead of a single global sleep — most
+        # handles release within 1-2 retries on Windows; worst case ~0.5s, same as before.
         if _failed:
             gc.collect()
             torch.cuda.empty_cache()
-            time.sleep(0.5)
             _still_failed = []
             for fpath in _failed:
-                try:
-                    os.remove(fpath)
-                except OSError as cleanup_err:
+                _last_err = None
+                for _attempt in range(10):
+                    try:
+                        os.remove(fpath)
+                        _last_err = None
+                        break
+                    except OSError as _e:
+                        _last_err = _e
+                        time.sleep(0.05)
+                if _last_err is not None:
                     _still_failed.append(os.path.basename(fpath))
                     _build_logger.warning(
-                        f"[BUILD] Could not delete temp file {os.path.basename(fpath)}: {cleanup_err}"
+                        f"[BUILD] Could not delete temp file {os.path.basename(fpath)}: {_last_err}"
                     )
             if _still_failed:
                 _build_logger.warning(
                     f"[BUILD] {len(_still_failed)} intermediate files could not be cleaned. "
-                    f"Manual cleanup: delete all files except *.engine and *.fp8.onnx from {engine_dir}"
+                    f"Manual cleanup: delete all files except *.engine, calib_data.npz, unet.fp8.onnx from {engine_dir}"
                 )
             cleaned = len(_to_delete) - len(_still_failed)
         else:
             cleaned = len(_to_delete)
         _build_logger.info(f"[BUILD] Cleaned {cleaned}/{len(_to_delete)} intermediate files")
-    else:
-        gc.collect()
-        torch.cuda.empty_cache()
diff --git a/src/streamdiffusion/acceleration/tensorrt/calibration_prompts_sdxl.txt b/src/streamdiffusion/acceleration/tensorrt/calibration_prompts_sdxl.txt
new file mode 100644
index 00000000..4affcdeb
--- /dev/null
+++ b/src/streamdiffusion/acceleration/tensorrt/calibration_prompts_sdxl.txt
@@ -0,0 +1,36 @@
+# Calibration prompts for FP8 UNet quantization.
+# These are loaded by _load_calibration_prompts() and consumed by
+# capture_calibration_data() to collect activation ranges across a diverse
+# set of visual concepts and lighting conditions.
+# 32 prompts — covers portraits, landscapes, abstract, stylized, lit/dark.
+a portrait of a young woman with long brown hair, soft studio lighting, photorealistic +a majestic snow-capped mountain peak at golden hour, landscape photography +an oil painting of a dense forest with sunlight filtering through the trees +close-up of a red rose with water droplets, macro photography, shallow depth of field +a futuristic cityscape at night with neon lights and rain-slick streets +abstract digital art with swirling blue and gold geometric patterns +a cozy log cabin interior with a fireplace, warm tungsten light +a black and white portrait of an elderly man with deep wrinkles +a tropical beach with turquoise water and white sand, aerial view +impressionist painting of a flower market with colorful blooms +a dark dramatic fantasy warrior in silver armor, cinematic lighting +a flat lay of assorted autumn leaves on a wooden surface +an astronaut floating in deep space with Earth visible behind them +a vintage leather-bound book open on a wooden desk, moody low key lighting +a smooth glass sphere reflecting a sunset over the ocean +wide-angle shot of the Milky Way over a desert landscape at night +a white ceramic bowl of ramen with soft-boiled egg and green onions +an animated character with oversized round eyes and pastel blue hair +a fox sitting in a field of tall grass at dusk, soft backlighting +birds-eye view of a dense rainforest canopy with morning mist +a steampunk mechanical clockwork heart with glowing orange embers +a close-up of a human eye reflecting a burning candle flame +a chessboard with a dramatic single key light casting long shadows +minimalist white room with a single window and a beam of sunlight +a lightning bolt over the sea, long exposure, dark stormy sky +a watercolor painting of a Japanese pagoda in autumn +a pair of hands kneading bread dough on a floured wooden surface +a dragon coiled around a mountain summit, digital painting, epic scale +a 1960s diner interior with red vinyl booths and a jukebox +blurry bokeh lights of a city street at night through a rainy window +a child blowing dandelion seeds in a summer meadow, shallow focus +a hyperrealistic oil painting of a golden retriever puppy looking up diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index f59da463..ca01cc5c 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -113,6 +113,7 @@ def get_engine_path( use_controlnet: bool = False, fp8: bool = False, resolution: Optional[tuple] = None, + builder_optimization_level: Optional[int] = None, ) -> Path: """ Generate engine path using wrapper.py's current logic. @@ -121,6 +122,7 @@ def get_engine_path( Special handling for ControlNet engines which use model_id-based directories. 
""" filename = self._configs[engine_type]["filename"] + optlvl_suffix = f"--optlvl{builder_optimization_level}" if builder_optimization_level is not None else "" if engine_type == EngineType.CONTROLNET: # ControlNet engines use special model_id-based directory structure @@ -134,7 +136,7 @@ def get_engine_path( prefix = f"controlnet_{model_dir_name}--min_batch-{min_batch_size}--max_batch-{max_batch_size}--res-{resolution[0]}x{resolution[1]}" else: prefix = f"controlnet_{model_dir_name}--min_batch-{min_batch_size}--max_batch-{max_batch_size}--dyn-384-1024" - return self.engine_dir / prefix / filename + return self.engine_dir / (prefix + optlvl_suffix) / filename else: # Standard engines use the unified prefix format # Extract base name (from wrapper.py lines 1002-1003) @@ -160,10 +162,25 @@ def get_engine_path( if use_controlnet: prefix += "--controlnet" if fp8: - prefix += "--fp8" + prefix += "--fp8v3" + + prefix += optlvl_suffix prefix += f"--mode-{mode}" + # Embed TRT version + compute capability so upgrading TRT invalidates + # stale engines automatically. Old engine dirs are orphaned (not deleted), + # keeping them available for rollback. Fails silently if tensorrt isn't + # installed yet (e.g. during a partial install). + try: + import tensorrt as _trt + import torch as _torch + + _cc = _torch.cuda.get_device_capability(0) + prefix += f"--trt{_trt.__version__}--cc{_cc[0]}{_cc[1]}" + except Exception: + pass + if resolution is not None: prefix += f"--res-{resolution[0]}x{resolution[1]}" @@ -224,6 +241,7 @@ def _get_default_controlnet_build_options( opt_image_height: int = 704, opt_image_width: int = 704, build_dynamic_shape: bool = False, + builder_optimization_level: Optional[int] = None, ) -> Dict: """Get default engine build options for ControlNet engines.""" opts = { @@ -235,6 +253,8 @@ def _get_default_controlnet_build_options( if build_dynamic_shape: opts["min_image_resolution"] = 384 opts["max_image_resolution"] = 1024 + if builder_optimization_level is not None: + opts["builder_optimization_level"] = builder_optimization_level return opts def compile_and_load_engine( @@ -334,6 +354,7 @@ def get_or_load_controlnet_engine( conditioning_channels: int = 3, opt_image_height: int = 704, opt_image_width: int = 704, + builder_optimization_level: Optional[int] = None, ) -> Any: """ Get or load ControlNet engine, providing unified interface for ControlNet management. 
@@ -350,6 +371,7 @@ def get_or_load_controlnet_engine( use_tiny_vae=False, # Not used for ControlNet controlnet_model_id=model_id, resolution=(opt_image_height, opt_image_width), + builder_optimization_level=builder_optimization_level, ) # Compile and load ControlNet engine @@ -370,5 +392,6 @@ def get_or_load_controlnet_engine( engine_build_options=self._get_default_controlnet_build_options( opt_image_height=opt_image_height, opt_image_width=opt_image_width, + builder_optimization_level=builder_optimization_level, ), ) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py index 13532a0c..be540807 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py @@ -1,15 +1,16 @@ from .controlnet_export import SDXLControlNetExportWrapper from .unet_controlnet_export import ControlNetUNetExportWrapper, MultiControlNetUNetExportWrapper from .unet_ipadapter_export import IPAdapterUNetExportWrapper -from .unet_sdxl_export import SDXLExportWrapper, SDXLConditioningHandler +from .unet_sdxl_export import SDXLConditioningHandler, SDXLExportWrapper from .unet_unified_export import UnifiedExportWrapper + __all__ = [ "SDXLControlNetExportWrapper", "ControlNetUNetExportWrapper", - "MultiControlNetUNetExportWrapper", + "MultiControlNetUNetExportWrapper", "IPAdapterUNetExportWrapper", "SDXLExportWrapper", "SDXLConditioningHandler", "UnifiedExportWrapper", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py index 946917b1..43809a1a 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py @@ -1,23 +1,24 @@ import torch + class SDXLControlNetExportWrapper(torch.nn.Module): """Wrapper for SDXL ControlNet models to handle added_cond_kwargs properly during ONNX export""" - + def __init__(self, controlnet_model): super().__init__() self.controlnet = controlnet_model - + # Get device and dtype from model - if hasattr(controlnet_model, 'device'): + if hasattr(controlnet_model, "device"): self.device = controlnet_model.device else: # Try to infer from first parameter try: self.device = next(controlnet_model.parameters()).device except: - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - if hasattr(controlnet_model, 'dtype'): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if hasattr(controlnet_model, "dtype"): self.dtype = controlnet_model.dtype else: # Try to infer from first parameter @@ -25,15 +26,14 @@ def __init__(self, controlnet_model): self.dtype = next(controlnet_model.parameters()).dtype except: self.dtype = torch.float16 - - def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale, text_embeds, time_ids): + + def forward( + self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale, text_embeds, time_ids + ): """Forward pass that handles SDXL ControlNet requirements and produces 9 down blocks""" # Use the provided SDXL conditioning - added_cond_kwargs = { - 'text_embeds': text_embeds, - 'time_ids': time_ids - } - + added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} + # Call the ControlNet with proper 
arguments including conditioning_scale result = self.controlnet( sample=sample, @@ -42,40 +42,49 @@ def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, cond controlnet_cond=controlnet_cond, conditioning_scale=conditioning_scale, added_cond_kwargs=added_cond_kwargs, - return_dict=False + return_dict=False, ) - + # Extract down blocks and mid block from result if isinstance(result, tuple) and len(result) >= 2: down_block_res_samples, mid_block_res_sample = result[0], result[1] - elif hasattr(result, 'down_block_res_samples') and hasattr(result, 'mid_block_res_sample'): + elif hasattr(result, "down_block_res_samples") and hasattr(result, "mid_block_res_sample"): down_block_res_samples = result.down_block_res_samples mid_block_res_sample = result.mid_block_res_sample else: raise ValueError(f"Unexpected ControlNet output format: {type(result)}") - + # SDXL ControlNet should have exactly 9 down blocks if len(down_block_res_samples) != 9: raise ValueError(f"SDXL ControlNet expected 9 down blocks, got {len(down_block_res_samples)}") - + # Return 9 down blocks + 1 mid block with explicit names matching UNet pattern # Following the pattern from controlnet_wrapper.py and models.py: # down_block_00: Initial sample (320 channels) - # down_block_01-03: Block 0 residuals (320 channels) + # down_block_01-03: Block 0 residuals (320 channels) # down_block_04-06: Block 1 residuals (640 channels) # down_block_07-08: Block 2 residuals (1280 channels) down_block_00 = down_block_res_samples[0] # Initial: 320 channels, 88x88 down_block_01 = down_block_res_samples[1] # Block0: 320 channels, 88x88 - down_block_02 = down_block_res_samples[2] # Block0: 320 channels, 88x88 + down_block_02 = down_block_res_samples[2] # Block0: 320 channels, 88x88 down_block_03 = down_block_res_samples[3] # Block0: 320 channels, 44x44 down_block_04 = down_block_res_samples[4] # Block1: 640 channels, 44x44 down_block_05 = down_block_res_samples[5] # Block1: 640 channels, 44x44 down_block_06 = down_block_res_samples[6] # Block1: 640 channels, 22x22 down_block_07 = down_block_res_samples[7] # Block2: 1280 channels, 22x22 down_block_08 = down_block_res_samples[8] # Block2: 1280 channels, 22x22 - mid_block = mid_block_res_sample # Mid: 1280 channels, 22x22 - + mid_block = mid_block_res_sample # Mid: 1280 channels, 22x22 + # Return as individual tensors to preserve names in ONNX - return (down_block_00, down_block_01, down_block_02, down_block_03, - down_block_04, down_block_05, down_block_06, down_block_07, - down_block_08, mid_block) \ No newline at end of file + return ( + down_block_00, + down_block_01, + down_block_02, + down_block_03, + down_block_04, + down_block_05, + down_block_06, + down_block_07, + down_block_08, + mid_block, + ) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py index e7a834bc..fb6b5733 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py @@ -1,30 +1,32 @@ """ControlNet-aware UNet wrapper for ONNX export""" +from typing import Dict, List, Optional + import torch -from typing import List, Optional, Dict, Any from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + from ..models.utils import convert_list_to_structure class ControlNetUNetExportWrapper(torch.nn.Module): """Wrapper that combines UNet with ControlNet 
inputs for ONNX export""" - + def __init__(self, unet: UNet2DConditionModel, control_input_names: List[str], kvo_cache_structure: List[int]): super().__init__() self.unet = unet self.control_input_names = control_input_names self.kvo_cache_structure = kvo_cache_structure - + self.control_names = [] for name in control_input_names: if "input_control" in name or "output_control" in name or "middle_control" in name: self.control_names.append(name) - + self.num_controlnet_args = len(self.control_names) - + # Detect if this is SDXL based on UNet config self.is_sdxl = self._detect_sdxl_architecture(unet) - + # SDXL ControlNet has different structure than SD1.5 if self.is_sdxl: # SDXL has 1 initial + 3 down blocks producing 9 control tensors total @@ -32,30 +34,30 @@ def __init__(self, unet: UNet2DConditionModel, control_input_names: List[str], k else: # SD1.5 has 12 down blocks self.expected_down_blocks = 12 - + def _detect_sdxl_architecture(self, unet): """Detect if UNet is SDXL based on architecture""" - if hasattr(unet, 'config'): + if hasattr(unet, "config"): config = unet.config # SDXL has 3 down blocks vs SD1.5's 4 down blocks - block_out_channels = getattr(config, 'block_out_channels', None) + block_out_channels = getattr(config, "block_out_channels", None) if block_out_channels and len(block_out_channels) == 3: return True return False - + def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): """Forward pass that organizes control inputs and calls UNet""" - - control_args = args[:self.num_controlnet_args] - kvo_cache = args[self.num_controlnet_args:] - + + control_args = args[: self.num_controlnet_args] + kvo_cache = args[self.num_controlnet_args :] + down_block_controls = [] mid_block_control = None - + if control_args: all_control_tensors = [] middle_tensor = None - + for tensor, name in zip(control_args, self.control_names): if "input_control" in name: if "middle" in name: @@ -64,7 +66,7 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): all_control_tensors.append(tensor) elif "middle_control" in name: middle_tensor = tensor - + if len(all_control_tensors) == self.expected_down_blocks: down_block_controls = all_control_tensors mid_block_control = middle_tensor @@ -73,7 +75,7 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): if len(all_control_tensors) > 0: if len(all_control_tensors) > self.expected_down_blocks: # Too many tensors - take the first expected_down_blocks - down_block_controls = all_control_tensors[:self.expected_down_blocks] + down_block_controls = all_control_tensors[: self.expected_down_blocks] else: # Too few tensors - use what we have down_block_controls = all_control_tensors @@ -82,30 +84,29 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): # No control tensors available - skip ControlNet down_block_controls = None mid_block_control = None - + formatted_kvo_cache = [] if len(kvo_cache) > 0: formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure) unet_kwargs = { - 'sample': sample, - 'timestep': timestep, - 'encoder_hidden_states': encoder_hidden_states, - 'kvo_cache': formatted_kvo_cache, - 'return_dict': False, + "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "kvo_cache": formatted_kvo_cache, + "return_dict": False, } - + # Pass through all additional kwargs (for SDXL models) unet_kwargs.update(kwargs) # Auto-generate SDXL conditioning if missing and UNet requires it - if 
'added_cond_kwargs' not in unet_kwargs or unet_kwargs.get('added_cond_kwargs') is None: - if (hasattr(self.unet, 'config') and - getattr(self.unet.config, 'addition_embed_type', None) == 'text_time'): + if "added_cond_kwargs" not in unet_kwargs or unet_kwargs.get("added_cond_kwargs") is None: + if hasattr(self.unet, "config") and getattr(self.unet.config, "addition_embed_type", None) == "text_time": batch_size = sample.shape[0] - unet_kwargs['added_cond_kwargs'] = { - 'text_embeds': torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), + unet_kwargs["added_cond_kwargs"] = { + "text_embeds": torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), } if down_block_controls: @@ -115,30 +116,30 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): # Control tensors are now generated in the correct order to match UNet's down_block_res_samples # For SDXL: [88x88, 88x88, 88x88, 44x44, 44x44, 44x44, 22x22, 22x22, 22x22] # This directly aligns with UNet's: [initial_sample] + [block0_residuals] + [block1_residuals] + [block2_residuals] - unet_kwargs['down_block_additional_residuals'] = adapted_controls - + unet_kwargs["down_block_additional_residuals"] = adapted_controls + if mid_block_control is not None: # Adapt middle control tensor shape if needed adapted_mid_control = self._adapt_middle_control_tensor(mid_block_control, sample) - unet_kwargs['mid_block_additional_residual'] = adapted_mid_control - + unet_kwargs["mid_block_additional_residual"] = adapted_mid_control + try: res = self.unet(**unet_kwargs) if len(kvo_cache) > 0: return res else: return res[0] - except Exception as e: + except Exception: raise - + def _adapt_control_tensors(self, control_tensors, sample): """Adapt control tensor shapes to match UNet expectations""" if not control_tensors: return control_tensors - + adapted_tensors = [] sample_height, sample_width = sample.shape[-2:] - + # Updated factors to match the corrected control tensor generation # SDXL: 9 tensors [88x88, 88x88, 88x88, 44x44, 44x44, 44x44, 22x22, 22x22, 22x22] # Factors: [1, 1, 1, 2, 2, 2, 4, 4, 4] to match UNet down_block_res_samples structure @@ -146,30 +147,31 @@ def _adapt_control_tensors(self, control_tensors, sample): expected_downsample_factors = [1, 1, 1, 2, 2, 2, 4, 4, 4] # 9 tensors for SDXL else: expected_downsample_factors = [1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8] # 12 tensors for SD1.5 - + for i, control_tensor in enumerate(control_tensors): if control_tensor is None: adapted_tensors.append(control_tensor) continue - + # Check if tensor needs spatial adaptation if len(control_tensor.shape) >= 4: control_height, control_width = control_tensor.shape[-2:] - + # Use the correct downsampling factor for this tensor index if i < len(expected_downsample_factors): downsample_factor = expected_downsample_factors[i] expected_height = sample_height // downsample_factor expected_width = sample_width // downsample_factor - + if control_height != expected_height or control_width != expected_width: # Use interpolation to adapt size import torch.nn.functional as F + adapted_tensor = F.interpolate( - control_tensor, + control_tensor, size=(expected_height, expected_width), - mode='bilinear', - align_corners=False + mode="bilinear", + align_corners=False, ) adapted_tensors.append(adapted_tensor) else: @@ -179,94 +181,94 @@ def 
_adapt_control_tensors(self, control_tensors, sample): adapted_tensors.append(control_tensor) else: adapted_tensors.append(control_tensor) - + return adapted_tensors - + def _adapt_middle_control_tensor(self, mid_control, sample): """Adapt middle control tensor shape to match UNet expectations""" if mid_control is None: return mid_control - + # Middle control is typically at the bottleneck, so heavily downsampled if len(mid_control.shape) >= 4 and len(sample.shape) >= 4: sample_height, sample_width = sample.shape[-2:] control_height, control_width = mid_control.shape[-2:] - + # For SDXL: middle block is at 4x downsampling (22x22 from 88x88) # For SD1.5: middle block is at 8x downsampling expected_factor = 4 if self.is_sdxl else 8 expected_height = sample_height // expected_factor expected_width = sample_width // expected_factor - + if control_height != expected_height or control_width != expected_width: import torch.nn.functional as F + adapted_tensor = F.interpolate( - mid_control, - size=(expected_height, expected_width), - mode='bilinear', - align_corners=False + mid_control, size=(expected_height, expected_width), mode="bilinear", align_corners=False ) return adapted_tensor - + return mid_control class MultiControlNetUNetExportWrapper(torch.nn.Module): """Advanced wrapper for multiple ControlNets with different scales""" - - def __init__(self, - unet: UNet2DConditionModel, - control_input_names: List[str], - kvo_cache_structure: List[int], - num_controlnets: int = 1, - conditioning_scales: Optional[List[float]] = None): + + def __init__( + self, + unet: UNet2DConditionModel, + control_input_names: List[str], + kvo_cache_structure: List[int], + num_controlnets: int = 1, + conditioning_scales: Optional[List[float]] = None, + ): super().__init__() self.unet = unet self.control_input_names = control_input_names self.num_controlnets = num_controlnets self.conditioning_scales = conditioning_scales or [1.0] * num_controlnets self.kvo_cache_structure = kvo_cache_structure - + self.control_names = [] for name in control_input_names: if "input_control" in name or "output_control" in name or "middle_control" in name: self.control_names.append(name) - + self.num_controlnet_args = len(self.control_names) self.controlnet_indices = [] controls_per_net = self.num_controlnet_args // num_controlnets - + for cn_idx in range(num_controlnets): start_idx = cn_idx * controls_per_net end_idx = start_idx + controls_per_net self.controlnet_indices.append(list(range(start_idx, end_idx))) - + def forward(self, sample, timestep, encoder_hidden_states, *args): """Forward pass for multiple ControlNets""" - control_args = args[:self.num_controlnet_args] - kvo_cache = args[self.num_controlnet_args:] + control_args = args[: self.num_controlnet_args] + kvo_cache = args[self.num_controlnet_args :] combined_down_controls = None combined_mid_control = None - + for cn_idx, indices in enumerate(self.controlnet_indices): scale = self.conditioning_scales[cn_idx] if scale == 0: continue - + cn_controls = [control_args[i] for i in indices if i < len(control_args)] - + if not cn_controls: continue - + num_down = len(cn_controls) - 1 down_controls = cn_controls[:num_down] mid_control = cn_controls[num_down] if num_down < len(cn_controls) else None - + scaled_down = [ctrl * scale for ctrl in down_controls] scaled_mid = mid_control * scale if mid_control is not None else None - + if combined_down_controls is None: combined_down_controls = scaled_down combined_mid_control = scaled_mid @@ -275,24 +277,24 @@ def forward(self, sample, 
timestep, encoder_hidden_states, *args): combined_down_controls[i] += scaled_down[i] if scaled_mid is not None and combined_mid_control is not None: combined_mid_control += scaled_mid - + formatted_kvo_cache = [] if len(kvo_cache) > 0: formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure) unet_kwargs = { - 'sample': sample, - 'timestep': timestep, - 'encoder_hidden_states': encoder_hidden_states, - 'kvo_cache': formatted_kvo_cache, - 'return_dict': False, + "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "kvo_cache": formatted_kvo_cache, + "return_dict": False, } - + if combined_down_controls: - unet_kwargs['down_block_additional_residuals'] = list(reversed(combined_down_controls)) + unet_kwargs["down_block_additional_residuals"] = list(reversed(combined_down_controls)) if combined_mid_control is not None: - unet_kwargs['mid_block_additional_residual'] = combined_mid_control - + unet_kwargs["mid_block_additional_residual"] = combined_mid_control + res = self.unet(**unet_kwargs) if len(kvo_cache) > 0: return res @@ -301,11 +303,13 @@ def forward(self, sample, timestep, encoder_hidden_states, *args): return res -def create_controlnet_wrapper(unet: UNet2DConditionModel, - control_input_names: List[str], - kvo_cache_structure: List[int], - num_controlnets: int = 1, - conditioning_scales: Optional[List[float]] = None) -> torch.nn.Module: +def create_controlnet_wrapper( + unet: UNet2DConditionModel, + control_input_names: List[str], + kvo_cache_structure: List[int], + num_controlnets: int = 1, + conditioning_scales: Optional[List[float]] = None, +) -> torch.nn.Module: """Factory function to create appropriate ControlNet wrapper""" if num_controlnets == 1: return ControlNetUNetExportWrapper(unet, control_input_names, kvo_cache_structure) @@ -315,17 +319,18 @@ def create_controlnet_wrapper(unet: UNet2DConditionModel, ) -def organize_control_tensors(control_tensors: List[torch.Tensor], - control_input_names: List[str]) -> Dict[str, List[torch.Tensor]]: +def organize_control_tensors( + control_tensors: List[torch.Tensor], control_input_names: List[str] +) -> Dict[str, List[torch.Tensor]]: """Organize control tensors by type (input, output, middle)""" - organized = {'input': [], 'output': [], 'middle': []} - + organized = {"input": [], "output": [], "middle": []} + for tensor, name in zip(control_tensors, control_input_names): if "input_control" in name: - organized['input'].append(tensor) + organized["input"].append(tensor) elif "output_control" in name: - organized['output'].append(tensor) + organized["output"].append(tensor) elif "middle_control" in name: - organized['middle'].append(tensor) - - return organized \ No newline at end of file + organized["middle"].append(tensor) + + return organized diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py index f7eb1861..93310ad8 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py @@ -1,30 +1,37 @@ +from typing import List + import torch from diffusers import UNet2DConditionModel -from typing import Optional, Dict, Any, List - -from ....model_detection import detect_model, detect_model_from_diffusers_unet from diffusers_ipadapter.ip_adapter.attention_processor import TRTIPAttnProcessor, TRTIPAttnProcessor2_0 +from 
....model_detection import detect_model_from_diffusers_unet + class IPAdapterUNetExportWrapper(torch.nn.Module): """ Wrapper that bakes IPAdapter attention processors into the UNet for ONNX export. - + This approach installs IPAdapter attention processors before ONNX export, allowing the specialized attention logic to be compiled into TensorRT. The UNet expects concatenated embeddings (text + image) as encoder_hidden_states. """ - - def __init__(self, unet: UNet2DConditionModel, cross_attention_dim: int, num_tokens: int = 4, install_processors: bool = True): + + def __init__( + self, + unet: UNet2DConditionModel, + cross_attention_dim: int, + num_tokens: int = 4, + install_processors: bool = True, + ): super().__init__() self.unet = unet self.num_image_tokens = num_tokens # 4 for standard, 16 for plus self.cross_attention_dim = cross_attention_dim # 768 for SD1.5, 2048 for SDXL self.install_processors = install_processors - + # Convert to float32 BEFORE installing processors (to avoid resetting them) self.unet = self.unet.to(dtype=torch.float32) - + # Track installed TRT processors self._ip_trt_processors: List[torch.nn.Module] = [] self.num_ip_layers: int = 0 @@ -36,8 +43,10 @@ def __init__(self, unet: UNet2DConditionModel, cross_attention_dim: int, num_tok # Install IPAdapter processors AFTER dtype conversion self._install_ipadapter_processors() else: - print("IPAdapterUNetExportWrapper: WARNING - UNet will not have IPAdapter functionality without processors!") - + print( + "IPAdapterUNetExportWrapper: WARNING - UNet will not have IPAdapter functionality without processors!" + ) + def _has_ipadapter_processors(self) -> bool: """Check if the UNet already has IPAdapter processors installed""" try: @@ -45,44 +54,48 @@ def _has_ipadapter_processors(self) -> bool: for name, processor in processors.items(): # Check for IPAdapter processor class names processor_class = processor.__class__.__name__ - if 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class: + if "IPAttn" in processor_class or "IPAttnProcessor" in processor_class: return True return False except Exception as e: print(f"IPAdapterUNetExportWrapper: Error checking existing processors: {e}") return False - + def _ensure_processor_dtype_consistency(self): """Ensure existing IPAdapter processors have correct dtype for ONNX export""" if hasattr(torch.nn.functional, "scaled_dot_product_attention"): from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor + IPProcClass = TRTIPAttnProcessor2_0 else: from diffusers.models.attention_processor import AttnProcessor + IPProcClass = TRTIPAttnProcessor try: processors = self.unet.attn_processors updated_processors = {} self._ip_trt_processors = [] ip_layer_index = 0 - + for name, processor in processors.items(): processor_class = processor.__class__.__name__ - if 'TRTIPAttn' in processor_class: + if "TRTIPAttn" in processor_class: # Already TRT processors: ensure dtype and record proc = processor.to(dtype=torch.float32) proc._scale_index = ip_layer_index self._ip_trt_processors.append(proc) ip_layer_index += 1 updated_processors[name] = proc - elif 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class: + elif "IPAttn" in processor_class or "IPAttnProcessor" in processor_class: # Replace standard processors with TRT variants, preserving weights where applicable - hidden_size = getattr(processor, 'hidden_size', None) - cross_attention_dim = getattr(processor, 'cross_attention_dim', None) - num_tokens = getattr(processor, 'num_tokens', 
self.num_image_tokens) - proc = IPProcClass(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens) + hidden_size = getattr(processor, "hidden_size", None) + cross_attention_dim = getattr(processor, "cross_attention_dim", None) + num_tokens = getattr(processor, "num_tokens", self.num_image_tokens) + proc = IPProcClass( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens + ) # Copy IP projection weights if present - if hasattr(processor, 'to_k_ip') and hasattr(processor, 'to_v_ip') and hasattr(proc, 'to_k_ip'): + if hasattr(processor, "to_k_ip") and hasattr(processor, "to_v_ip") and hasattr(proc, "to_k_ip"): with torch.no_grad(): proc.to_k_ip.weight.copy_(processor.to_k_ip.weight.to(dtype=torch.float32)) proc.to_v_ip.weight.copy_(processor.to_v_ip.weight.to(dtype=torch.float32)) @@ -93,16 +106,17 @@ def _ensure_processor_dtype_consistency(self): updated_processors[name] = proc else: updated_processors[name] = AttnProcessor() - + # Update all processors to ensure consistency self.unet.set_attn_processor(updated_processors) self.num_ip_layers = len(self._ip_trt_processors) - + except Exception as e: print(f"IPAdapterUNetExportWrapper: Error updating processor dtypes: {e}") import traceback + traceback.print_exc() - + def _install_ipadapter_processors(self): """ Install IPAdapter attention processors that will be baked into ONNX. @@ -112,19 +126,23 @@ def _install_ipadapter_processors(self): try: if hasattr(torch.nn.functional, "scaled_dot_product_attention"): from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor + IPProcClass = TRTIPAttnProcessor2_0 else: from diffusers.models.attention_processor import AttnProcessor + IPProcClass = TRTIPAttnProcessor - + # Install attention processors with proper configuration processor_names = list(self.unet.attn_processors.keys()) - + attn_procs = {} ip_layer_index = 0 for name in processor_names: - cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim - + cross_attention_dim = ( + None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + ) + # Determine hidden_size based on processor location hidden_size = None if name.startswith("mid_block"): @@ -138,7 +156,7 @@ def _install_ipadapter_processors(self): else: # Fallback for any unexpected processor names hidden_size = self.unet.config.block_out_channels[0] # Use first block size as fallback - + if cross_attention_dim is None: # Self-attention layers use standard processors attn_procs[name] = AttnProcessor() @@ -154,38 +172,46 @@ def _install_ipadapter_processors(self): self._ip_trt_processors.append(proc) ip_layer_index += 1 attn_procs[name] = proc - + self.unet.set_attn_processor(attn_procs) self.num_ip_layers = len(self._ip_trt_processors) - - except Exception as e: print(f"IPAdapterUNetExportWrapper: ERROR - Could not install IPAdapter processors: {e}") print(f"IPAdapterUNetExportWrapper: Exception type: {type(e).__name__}") print("IPAdapterUNetExportWrapper: IPAdapter functionality will not work without processors!") import traceback + traceback.print_exc() raise e - + def set_ipadapter_scale(self, ipadapter_scale: torch.Tensor) -> None: """Assign per-layer scale tensor to installed TRTIPAttn processors.""" if not isinstance(ipadapter_scale, torch.Tensor): import logging - logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}") + + logging.getLogger(__name__).error( 
+ f"IPAdapterUNetExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}" + ) raise TypeError("ipadapter_scale must be a torch.Tensor") if self.num_ip_layers <= 0 or not self._ip_trt_processors: raise RuntimeError("No TRTIPAttn processors installed") if ipadapter_scale.ndim != 1 or ipadapter_scale.shape[0] != self.num_ip_layers: import logging - logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale has wrong shape {tuple(ipadapter_scale.shape)}, expected=({self.num_ip_layers},)") + + logging.getLogger(__name__).error( + f"IPAdapterUNetExportWrapper: ipadapter_scale has wrong shape {tuple(ipadapter_scale.shape)}, expected=({self.num_ip_layers},)" + ) raise ValueError(f"ipadapter_scale must have shape [{self.num_ip_layers}]") # Ensure float32 for ONNX export stability scale_vec = ipadapter_scale.to(dtype=torch.float32) try: import logging - logging.getLogger(__name__).debug(f"IPAdapterUNetExportWrapper: scale_vec min={scale_vec.min()}, max={scale_vec.max()}") + + logging.getLogger(__name__).debug( + f"IPAdapterUNetExportWrapper: scale_vec min={scale_vec.min()}, max={scale_vec.max()}" + ) except Exception: pass for proc in self._ip_trt_processors: @@ -194,27 +220,27 @@ def set_ipadapter_scale(self, ipadapter_scale: torch.Tensor) -> None: def forward(self, sample, timestep, encoder_hidden_states, ipadapter_scale: torch.Tensor = None): """ Forward pass with concatenated embeddings (text + image). - + The IPAdapter processors installed in the UNet will automatically: 1. Split the concatenated embeddings into text and image parts 2. Process image tokens with separate attention computation 3. Apply scaling and blending between text and image attention - + Args: sample: Latent input tensor - timestep: Timestep tensor + timestep: Timestep tensor encoder_hidden_states: Concatenated embeddings [text_tokens + image_tokens, cross_attention_dim] - + Returns: UNet output (noise prediction) """ # Validate input shapes batch_size, seq_len, embed_dim = encoder_hidden_states.shape - + # Check that we have the expected number of image tokens if embed_dim != self.cross_attention_dim: raise ValueError(f"Embedding dimension {embed_dim} doesn't match expected {self.cross_attention_dim}") - + # Ensure dtype consistency for ONNX export if encoder_hidden_states.dtype != torch.float32: encoder_hidden_states = encoder_hidden_states.to(torch.float32) @@ -223,29 +249,28 @@ def forward(self, sample, timestep, encoder_hidden_states, ipadapter_scale: torc if ipadapter_scale is None: raise RuntimeError("IPAdapterUNetExportWrapper.forward requires ipadapter_scale tensor") self.set_ipadapter_scale(ipadapter_scale) - + # Pass concatenated embeddings to UNet with baked-in IPAdapter processors return self.unet( - sample=sample, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - return_dict=False + sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states, return_dict=False ) -def create_ipadapter_wrapper(unet: UNet2DConditionModel, num_tokens: int = 4, install_processors: bool = True) -> IPAdapterUNetExportWrapper: +def create_ipadapter_wrapper( + unet: UNet2DConditionModel, num_tokens: int = 4, install_processors: bool = True +) -> IPAdapterUNetExportWrapper: """ Create an IPAdapter wrapper with automatic architecture detection and baked-in processors. - + Handles both cases: 1. UNet with pre-loaded IPAdapter processors (preserves existing weights) 2. 
UNet without IPAdapter processors (installs new ones if install_processors=True) - + Args: unet: UNet2DConditionModel to wrap num_tokens: Number of image tokens (4 for standard, 16 for plus) install_processors: Whether to install IPAdapter processors if none exist - + Returns: IPAdapterUNetExportWrapper with baked-in IPAdapter attention processors """ @@ -253,23 +278,21 @@ def create_ipadapter_wrapper(unet: UNet2DConditionModel, num_tokens: int = 4, in try: model_type = detect_model_from_diffusers_unet(unet) cross_attention_dim = unet.config.cross_attention_dim - + # Check if UNet already has IPAdapter processors installed existing_processors = unet.attn_processors - has_ipadapter = any('IPAttn' in proc.__class__.__name__ or 'IPAttnProcessor' in proc.__class__.__name__ - for proc in existing_processors.values()) - + has_ipadapter = any( + "IPAttn" in proc.__class__.__name__ or "IPAttnProcessor" in proc.__class__.__name__ + for proc in existing_processors.values() + ) + # Validate expected dimensions - expected_dims = { - "SD15": 768, - "SDXL": 2048, - "SD21": 1024 - } - + expected_dims = {"SD15": 768, "SDXL": 2048, "SD21": 1024} + expected_dim = expected_dims.get(model_type) - + return IPAdapterUNetExportWrapper(unet, cross_attention_dim, num_tokens, install_processors) - + except Exception as e: print(f"create_ipadapter_wrapper: Error during model detection: {e}") - return IPAdapterUNetExportWrapper(unet, 768, num_tokens, install_processors) \ No newline at end of file + return IPAdapterUNetExportWrapper(unet, 768, num_tokens, install_processors) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py index fa1f0f89..078b1f91 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py @@ -4,13 +4,17 @@ conditioning parameters, and Turbo variants """ +import logging +from typing import Any, Dict + import torch -from typing import Dict, List, Optional, Tuple, Any, Union from diffusers import UNet2DConditionModel + from ....model_detection import ( detect_model, ) -import logging + + logger = logging.getLogger(__name__) # Handle different diffusers versions for CLIPTextModel import @@ -18,7 +22,7 @@ from diffusers.models.transformers.clip_text_model import CLIPTextModel except ImportError: try: - from diffusers.models.clip_text_model import CLIPTextModel + from diffusers.models.clip_text_model import CLIPTextModel except ImportError: try: from transformers import CLIPTextModel @@ -29,79 +33,81 @@ class SDXLExportWrapper(torch.nn.Module): """Wrapper for SDXL UNet to handle optional conditioning in legacy TensorRT""" - + def __init__(self, unet): super().__init__() self.unet = unet self.base_unet = self._get_base_unet(unet) self.supports_added_cond = self._test_added_cond_support() - + def _get_base_unet(self, unet): """Extract the base UNet from wrappers""" # Handle ControlNet wrapper - if hasattr(unet, 'unet_model') and hasattr(unet.unet_model, 'config'): + if hasattr(unet, "unet_model") and hasattr(unet.unet_model, "config"): return unet.unet_model - elif hasattr(unet, 'unet') and hasattr(unet.unet, 'config'): + elif hasattr(unet, "unet") and hasattr(unet.unet, "config"): return unet.unet - elif hasattr(unet, 'config'): + elif hasattr(unet, "config"): return unet else: # Fallback: try to find any attribute that has config for attr_name in dir(unet): - if not 
attr_name.startswith('_'): + if not attr_name.startswith("_"): attr = getattr(unet, attr_name, None) - if hasattr(attr, 'config') and hasattr(attr.config, 'addition_embed_type'): + if hasattr(attr, "config") and hasattr(attr.config, "addition_embed_type"): return attr return unet - + def _test_added_cond_support(self): """Test if this SDXL model supports added_cond_kwargs""" try: # Create minimal test inputs - sample = torch.randn(1, 4, 8, 8, device='cuda', dtype=torch.float16) - timestep = torch.tensor([0.5], device='cuda', dtype=torch.float32) - encoder_hidden_states = torch.randn(1, 77, 2048, device='cuda', dtype=torch.float16) - + sample = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) + timestep = torch.tensor([0.5], device="cuda", dtype=torch.float32) + encoder_hidden_states = torch.randn(1, 77, 2048, device="cuda", dtype=torch.float16) + # Test with added_cond_kwargs test_added_cond = { - 'text_embeds': torch.randn(1, 1280, device='cuda', dtype=torch.float16), - 'time_ids': torch.randn(1, 6, device='cuda', dtype=torch.float16) + "text_embeds": torch.randn(1, 1280, device="cuda", dtype=torch.float16), + "time_ids": torch.randn(1, 6, device="cuda", dtype=torch.float16), } - + with torch.no_grad(): _ = self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=test_added_cond) - + logger.info("SDXL model supports added_cond_kwargs") return True - + except Exception as e: logger.error(f"SDXL model does not support added_cond_kwargs: {e}") return False - + def forward(self, *args, **kwargs): """Forward pass that handles SDXL conditioning gracefully""" try: # Ensure added_cond_kwargs is never None to prevent TypeError - if 'added_cond_kwargs' in kwargs and kwargs['added_cond_kwargs'] is None: - kwargs['added_cond_kwargs'] = {} - + if "added_cond_kwargs" in kwargs and kwargs["added_cond_kwargs"] is None: + kwargs["added_cond_kwargs"] = {} + # Auto-generate SDXL conditioning if missing and model needs it - if (len(args) >= 3 and 'added_cond_kwargs' not in kwargs and - hasattr(self.base_unet.config, 'addition_embed_type') and - self.base_unet.config.addition_embed_type == 'text_time'): - + if ( + len(args) >= 3 + and "added_cond_kwargs" not in kwargs + and hasattr(self.base_unet.config, "addition_embed_type") + and self.base_unet.config.addition_embed_type == "text_time" + ): sample = args[0] device = sample.device batch_size = sample.shape[0] - + logger.info("Auto-generating required SDXL conditioning...") - kwargs['added_cond_kwargs'] = { - 'text_embeds': torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=device, dtype=sample.dtype) + kwargs["added_cond_kwargs"] = { + "text_embeds": torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=device, dtype=sample.dtype), } - + # If model supports added conditioning and we have the kwargs, use them - if self.supports_added_cond and 'added_cond_kwargs' in kwargs: + if self.supports_added_cond and "added_cond_kwargs" in kwargs: result = self.unet(*args, **kwargs) return result elif len(args) >= 3: @@ -110,7 +116,7 @@ def forward(self, *args, **kwargs): else: # Fallback return self.unet(*args, **kwargs) - + except (TypeError, AttributeError) as e: logger.error(f"[SDXL_WRAPPER] forward: Exception caught: {e}") if "NoneType" in str(e) or "iterable" in str(e) or "text_embeds" in str(e): @@ -120,15 +126,17 @@ def forward(self, *args, **kwargs): sample, timestep, encoder_hidden_states = args[0], args[1], args[2] 
device = sample.device batch_size = sample.shape[0] - + # Create minimal valid SDXL conditioning minimal_conditioning = { - 'text_embeds': torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=device, dtype=sample.dtype) + "text_embeds": torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=device, dtype=sample.dtype), } - + try: - return self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=minimal_conditioning) + return self.unet( + sample, timestep, encoder_hidden_states, added_cond_kwargs=minimal_conditioning + ) except Exception as final_e: logger.info(f"Final fallback to basic call: {final_e}") return self.unet(sample, timestep, encoder_hidden_states) @@ -136,181 +144,180 @@ def forward(self, *args, **kwargs): return self.unet(*args) else: raise e - + + class SDXLConditioningHandler: """Handles SDXL conditioning parameters and dual text encoders""" - + def __init__(self, unet_info: Dict[str, Any]): self.unet_info = unet_info - self.is_sdxl = unet_info['is_sdxl'] - self.has_time_cond = unet_info['has_time_cond'] - self.has_addition_embed = unet_info['has_addition_embed'] - + self.is_sdxl = unet_info["is_sdxl"] + self.has_time_cond = unet_info["has_time_cond"] + self.has_addition_embed = unet_info["has_addition_embed"] + def get_conditioning_spec(self) -> Dict[str, Any]: """Get conditioning specification for ONNX export and TensorRT""" spec = { - 'text_encoder_dim': 768, # CLIP ViT-L - 'context_dim': 768, # Default SD1.5 - 'pooled_embeds': False, - 'time_ids': False, - 'dual_encoders': False + "text_encoder_dim": 768, # CLIP ViT-L + "context_dim": 768, # Default SD1.5 + "pooled_embeds": False, + "time_ids": False, + "dual_encoders": False, } - + if self.is_sdxl: - spec.update({ - 'text_encoder_dim': 768, # CLIP ViT-L - 'text_encoder_2_dim': 1280, # OpenCLIP ViT-bigG - 'context_dim': 2048, # Concatenated 768 + 1280 - 'pooled_embeds': True, # Pooled text embeddings - 'time_ids': self.has_time_cond, # Size/crop conditioning - 'dual_encoders': True - }) - + spec.update( + { + "text_encoder_dim": 768, # CLIP ViT-L + "text_encoder_2_dim": 1280, # OpenCLIP ViT-bigG + "context_dim": 2048, # Concatenated 768 + 1280 + "pooled_embeds": True, # Pooled text embeddings + "time_ids": self.has_time_cond, # Size/crop conditioning + "dual_encoders": True, + } + ) + return spec - - def create_sample_conditioning(self, batch_size: int = 1, device: str = 'cuda') -> Dict[str, torch.Tensor]: + + def create_sample_conditioning(self, batch_size: int = 1, device: str = "cuda") -> Dict[str, torch.Tensor]: """Create sample conditioning tensors for testing/export""" spec = self.get_conditioning_spec() dtype = torch.float16 - + conditioning = { - 'encoder_hidden_states': torch.randn( - batch_size, 77, spec['context_dim'], - device=device, dtype=dtype - ) + "encoder_hidden_states": torch.randn(batch_size, 77, spec["context_dim"], device=device, dtype=dtype) } - - if spec['pooled_embeds']: - conditioning['text_embeds'] = torch.randn( - batch_size, spec['text_encoder_2_dim'], - device=device, dtype=dtype + + if spec["pooled_embeds"]: + conditioning["text_embeds"] = torch.randn( + batch_size, spec["text_encoder_2_dim"], device=device, dtype=dtype ) - - if spec['time_ids']: - conditioning['time_ids'] = torch.randn( - batch_size, 6, # [height, width, crop_h, crop_w, target_height, target_width] - device=device, dtype=dtype + + if spec["time_ids"]: + conditioning["time_ids"] = torch.randn( 
+ batch_size, + 6, # [height, width, crop_h, crop_w, target_height, target_width] + device=device, + dtype=dtype, ) - + return conditioning - + def test_unet_conditioning(self, unet: UNet2DConditionModel) -> Dict[str, bool]: """Test what conditioning the UNet actually supports""" - results = { - 'basic': False, - 'added_cond_kwargs': False, - 'separate_args': False - } - + results = {"basic": False, "added_cond_kwargs": False, "separate_args": False} + try: # Ensure model is on CUDA and in eval mode for testing - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" unet_test = unet.to(device).eval() - + # Create test inputs on the same device sample = torch.randn(1, 4, 8, 8, device=device, dtype=torch.float16) timestep = torch.tensor([0.5], device=device, dtype=torch.float32) conditioning = self.create_sample_conditioning(1, device=device) - + # Test basic call try: with torch.no_grad(): - _ = unet_test(sample, timestep, conditioning['encoder_hidden_states']) - results['basic'] = True + _ = unet_test(sample, timestep, conditioning["encoder_hidden_states"]) + results["basic"] = True except Exception: pass - + # Test added_cond_kwargs (standard SDXL) if self.is_sdxl: try: added_cond = {} - if 'text_embeds' in conditioning: - added_cond['text_embeds'] = conditioning['text_embeds'] - if 'time_ids' in conditioning: - added_cond['time_ids'] = conditioning['time_ids'] - + if "text_embeds" in conditioning: + added_cond["text_embeds"] = conditioning["text_embeds"] + if "time_ids" in conditioning: + added_cond["time_ids"] = conditioning["time_ids"] + with torch.no_grad(): - _ = unet_test(sample, timestep, conditioning['encoder_hidden_states'], - added_cond_kwargs=added_cond) - results['added_cond_kwargs'] = True + _ = unet_test( + sample, timestep, conditioning["encoder_hidden_states"], added_cond_kwargs=added_cond + ) + results["added_cond_kwargs"] = True except Exception: pass - + # Test separate arguments (some implementations) try: - args = [sample, timestep, conditioning['encoder_hidden_states']] - if 'text_embeds' in conditioning: - args.append(conditioning['text_embeds']) - if 'time_ids' in conditioning: - args.append(conditioning['time_ids']) - + args = [sample, timestep, conditioning["encoder_hidden_states"]] + if "text_embeds" in conditioning: + args.append(conditioning["text_embeds"]) + if "time_ids" in conditioning: + args.append(conditioning["time_ids"]) + with torch.no_grad(): _ = unet_test(*args) - results['separate_args'] = True + results["separate_args"] = True except Exception: pass - + except Exception as e: # If testing fails completely, provide safe defaults print(f"⚠️ UNet conditioning test setup failed: {e}") results = { - 'basic': True, # Assume basic call works - 'added_cond_kwargs': self.is_sdxl, # Assume SDXL models support this - 'separate_args': False + "basic": True, # Assume basic call works + "added_cond_kwargs": self.is_sdxl, # Assume SDXL models support this + "separate_args": False, } - + return results def get_onnx_export_spec(self) -> Dict[str, Any]: """Get specification for ONNX export""" spec = self.conditioning_handler.get_conditioning_spec() - + # Add export-specific details - spec.update({ - 'input_names': ['sample', 'timestep', 'encoder_hidden_states'], - 'output_names': ['noise_pred'], - 'dynamic_axes': { - 'sample': {0: 'batch_size'}, - 'timestep': {0: 'batch_size'}, - 'encoder_hidden_states': {0: 'batch_size'}, - 'noise_pred': {0: 'batch_size'} + spec.update( + { + "input_names": ["sample", 
"timestep", "encoder_hidden_states"], + "output_names": ["noise_pred"], + "dynamic_axes": { + "sample": {0: "batch_size"}, + "timestep": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size"}, + "noise_pred": {0: "batch_size"}, + }, } - }) - + ) + # Add SDXL-specific inputs if supported - if self.is_sdxl and self.supported_calls['added_cond_kwargs']: - if spec['pooled_embeds']: - spec['input_names'].append('text_embeds') - spec['dynamic_axes']['text_embeds'] = {0: 'batch_size'} - - if spec['time_ids']: - spec['input_names'].append('time_ids') - spec['dynamic_axes']['time_ids'] = {0: 'batch_size'} - - return spec + if self.is_sdxl and self.supported_calls["added_cond_kwargs"]: + if spec["pooled_embeds"]: + spec["input_names"].append("text_embeds") + spec["dynamic_axes"]["text_embeds"] = {0: "batch_size"} + + if spec["time_ids"]: + spec["input_names"].append("time_ids") + spec["dynamic_axes"]["time_ids"] = {0: "batch_size"} + return spec def get_sdxl_tensorrt_config(model_path: str, unet: UNet2DConditionModel) -> Dict[str, Any]: """Get complete TensorRT configuration for SDXL model""" # Use the new detection function detection_result = detect_model(unet) - + # Create a config dict compatible with SDXLConditioningHandler config = { - 'is_sdxl': detection_result['is_sdxl'], - 'has_time_cond': detection_result['architecture_details']['has_time_conditioning'], - 'has_addition_embed': detection_result['architecture_details']['has_addition_embeds'], - 'model_type': detection_result['model_type'], - 'is_turbo': detection_result['is_turbo'], - 'is_sd3': detection_result['is_sd3'], - 'confidence': detection_result['confidence'], - 'architecture_details': detection_result['architecture_details'], - 'compatibility_info': detection_result['compatibility_info'] + "is_sdxl": detection_result["is_sdxl"], + "has_time_cond": detection_result["architecture_details"]["has_time_conditioning"], + "has_addition_embed": detection_result["architecture_details"]["has_addition_embeds"], + "model_type": detection_result["model_type"], + "is_turbo": detection_result["is_turbo"], + "is_sd3": detection_result["is_sd3"], + "confidence": detection_result["confidence"], + "architecture_details": detection_result["architecture_details"], + "compatibility_info": detection_result["compatibility_info"], } - + # Add conditioning specification conditioning_handler = SDXLConditioningHandler(config) - config['conditioning_spec'] = conditioning_handler.get_conditioning_spec() - - return config \ No newline at end of file + config["conditioning_spec"] = conditioning_handler.get_conditioning_spec() + + return config diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py index 1c87efbf..67f28832 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py @@ -1,23 +1,28 @@ +from typing import List, Optional + import torch from diffusers import UNet2DConditionModel -from typing import Optional, List + +from ..models.utils import convert_list_to_structure from .unet_controlnet_export import create_controlnet_wrapper from .unet_ipadapter_export import create_ipadapter_wrapper -from ..models.utils import convert_list_to_structure + class UnifiedExportWrapper(torch.nn.Module): """ - Unified wrapper that composes wrappers for conditioning modules. 
+ Unified wrapper that composes wrappers for conditioning modules. """ - - def __init__(self, - unet: UNet2DConditionModel, - use_controlnet: bool = False, - use_ipadapter: bool = False, - control_input_names: Optional[List[str]] = None, - num_tokens: int = 4, - kvo_cache_structure: List[int] = [], - **kwargs): + + def __init__( + self, + unet: UNet2DConditionModel, + use_controlnet: bool = False, + use_ipadapter: bool = False, + control_input_names: Optional[List[str]] = None, + num_tokens: int = 4, + kvo_cache_structure: List[int] = [], + **kwargs, + ): super().__init__() self.use_controlnet = use_controlnet self.use_ipadapter = use_ipadapter @@ -25,23 +30,24 @@ def __init__(self, self.ipadapter_wrapper = None self.unet = unet self.kvo_cache_structure = kvo_cache_structure - + # Apply IPAdapter first (installs processors into UNet) if use_ipadapter: - ipadapter_kwargs = {k: v for k, v in kwargs.items() if k in ['install_processors']} - if 'install_processors' not in ipadapter_kwargs: - ipadapter_kwargs['install_processors'] = True - + ipadapter_kwargs = {k: v for k, v in kwargs.items() if k in ["install_processors"]} + if "install_processors" not in ipadapter_kwargs: + ipadapter_kwargs["install_processors"] = True self.ipadapter_wrapper = create_ipadapter_wrapper(unet, num_tokens=num_tokens, **ipadapter_kwargs) self.unet = self.ipadapter_wrapper.unet - + # Apply ControlNet second (wraps whatever UNet we have) if use_controlnet and control_input_names: - controlnet_kwargs = {k: v for k, v in kwargs.items() if k in ['num_controlnets', 'conditioning_scales']} + controlnet_kwargs = {k: v for k, v in kwargs.items() if k in ["num_controlnets", "conditioning_scales"]} + + self.controlnet_wrapper = create_controlnet_wrapper( + self.unet, control_input_names, kvo_cache_structure, **controlnet_kwargs + ) - self.controlnet_wrapper = create_controlnet_wrapper(self.unet, control_input_names, kvo_cache_structure, **controlnet_kwargs) - def _basic_unet_forward(self, sample, timestep, encoder_hidden_states, *kvo_cache, **kwargs): """Basic UNet forward that passes through all parameters to handle any model type""" formatted_kvo_cache = [] @@ -49,52 +55,57 @@ def _basic_unet_forward(self, sample, timestep, encoder_hidden_states, *kvo_cach formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure) # Auto-generate SDXL conditioning if missing and UNet requires it - if 'added_cond_kwargs' not in kwargs or kwargs.get('added_cond_kwargs') is None: + if "added_cond_kwargs" not in kwargs or kwargs.get("added_cond_kwargs") is None: base_unet = self.unet - if (hasattr(base_unet, 'config') and - getattr(base_unet.config, 'addition_embed_type', None) == 'text_time'): + if hasattr(base_unet, "config") and getattr(base_unet.config, "addition_embed_type", None) == "text_time": batch_size = sample.shape[0] - kwargs['added_cond_kwargs'] = { - 'text_embeds': torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), + kwargs["added_cond_kwargs"] = { + "text_embeds": torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), } unet_kwargs = { - 'sample': sample, - 'timestep': timestep, - 'encoder_hidden_states': encoder_hidden_states, - 'return_dict': False, - 'kvo_cache': formatted_kvo_cache, - **kwargs # Pass through all additional parameters (SDXL, future model types, etc.) 
+ "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "return_dict": False, + "kvo_cache": formatted_kvo_cache, + **kwargs, # Pass through all additional parameters (SDXL, future model types, etc.) } res = self.unet(**unet_kwargs) if len(kvo_cache) > 0: return res else: return res[0] - - def forward(self, - sample: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - *args, - **kwargs) -> torch.Tensor: + + def forward( + self, sample: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: """Forward pass that handles any UNet parameters via **kwargs passthrough""" # Handle IP-Adapter runtime scale vector as a positional argument placed before control tensors if self.use_ipadapter and self.ipadapter_wrapper is not None: # ipadapter_scale is appended as the first extra positional input after the 3 base inputs if len(args) == 0: import logging - logging.getLogger(__name__).error("UnifiedExportWrapper: ipadapter_scale missing; required when use_ipadapter=True") + + logging.getLogger(__name__).error( + "UnifiedExportWrapper: ipadapter_scale missing; required when use_ipadapter=True" + ) raise RuntimeError("UnifiedExportWrapper: ipadapter_scale tensor is required when use_ipadapter=True") ipadapter_scale = args[0] if not isinstance(ipadapter_scale, torch.Tensor): import logging - logging.getLogger(__name__).error(f"UnifiedExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}") + + logging.getLogger(__name__).error( + f"UnifiedExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}" + ) raise TypeError("ipadapter_scale must be a torch.Tensor") try: import logging - logging.getLogger(__name__).debug(f"UnifiedExportWrapper: ipadapter_scale shape={tuple(ipadapter_scale.shape)}, dtype={ipadapter_scale.dtype}") + + logging.getLogger(__name__).debug( + f"UnifiedExportWrapper: ipadapter_scale shape={tuple(ipadapter_scale.shape)}, dtype={ipadapter_scale.dtype}" + ) except Exception: pass # assign per-layer scale tensors into processors @@ -107,4 +118,4 @@ def forward(self, return self.controlnet_wrapper(sample, timestep, encoder_hidden_states, *args, **kwargs) else: # Basic UNet call with all parameters passed through - return self._basic_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs) \ No newline at end of file + return self._basic_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs) diff --git a/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py index 66e5a899..762c638a 100644 --- a/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py +++ b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py @@ -1,470 +1,435 @@ """ FP8 Quantization for StreamDiffusion TensorRT UNet engine. -Uses nvidia-modelopt for ONNX-level FP8 quantization via Q/DQ node insertion. -The quantized ONNX is then compiled to TRT with STRONGLY_TYPED + FP8 builder flags. +ONNX-level approach: export a plain FP16 ONNX first, then inject native +FLOAT8E4M3FN Q/DQ nodes via modelopt.onnx.quantization.quantize with +real activation tensors captured from the diffusers pipeline. + +Why this is better than the previous PyTorch nn.Module path: +- modelopt's torch path defaults trt_high_precision_dtype="Float" (FP32), + which inserts Cast(FP16→FP32) before every Q node and stores all weight + initializers as FP32 → 9 GB ONNX on SDXL UNet. 
+- The nn.Module path required generate_fp8_scales to rewrite FP8(4,3)→INT8(8) + because torch.onnx.export's ScaledE4M3Function symbolic corrupts the graph + for attention/embedding quantizers → INT8 kernels, not FP8 GEMMs. +- The ONNX-level path keeps weights in FP16 (high_precision_dtype="fp16") and + emits native FLOAT8E4M3FN Q/DQ → ~2.5 GB ONNX, true FP8 tensor-core kernels. Requirements: - nvidia-modelopt[onnx] >= 0.35.0 - TensorRT >= 10.0 (FP8 support) - RTX 4090+ (Ada Lovelace, compute 8.9, FP8 E4M3 hardware support) - -This module is called from builder.py when fp8=True is passed to EngineBuilder.build(). + nvidia-modelopt[onnx] >= 0.19.0 + onnxruntime-gpu >= 1.17 (ORT CUDA EP for calibration) + TensorRT >= 10.0 (FP8 E4M3 hardware support, STRONGLY_TYPED build flag) + RTX 4090+ (Ada Lovelace, compute capability 8.9) """ import logging import os +from pathlib import Path from typing import Dict, List, Optional import numpy as np + logger = logging.getLogger(__name__) +_BUNDLED_PROMPTS_PATH = Path(__file__).parent / "calibration_prompts_sdxl.txt" + + +def _load_calibration_prompts(user_path: Optional[str] = None) -> List[str]: + """Load calibration prompts from user path (if given) or bundled default.""" + path = Path(user_path) if user_path else _BUNDLED_PROMPTS_PATH + if not path.exists(): + logger.warning(f"[FP8] Calibration prompts not found: {path}. Using 3-prompt fallback.") + return [ + "a portrait of a person in soft studio lighting", + "abstract colorful geometric pattern", + "landscape photography at golden hour", + ] + with open(path, "r", encoding="utf-8") as f: + prompts = [line.strip() for line in f if line.strip() and not line.startswith("#")] + logger.info(f"[FP8] Loaded {len(prompts)} calibration prompts from {path.name}") + return prompts + + +def capture_calibration_data( + pipe, + prompts: List[str], + num_inference_steps: int = 20, + save_path: Optional[str] = None, + batch_size: int = 1, + guidance_scale: float = 7.5, + onnx_path: Optional[str] = None, + use_cached_attn: bool = False, + use_controlnet: bool = False, + num_ip_layers: int = 0, +) -> Dict[str, np.ndarray]: + """ + Capture UNet input activations from a real diffusers pipeline run. + + Registers a forward pre-hook on pipe.unet that records inputs across all + denoising timesteps and all calibration prompts. Returns a calibration_data + dict compatible with modelopt.onnx.quantization.quantize(calibration_data=...). -def _restore_dynamic_axes(onnx_fp8_path: str, model_data) -> None: - """Restore dynamic dim_param symbols in FP8 ONNX after ModelOpt quantization. + If LoRAs are active they are baked into the captured activations, which is + correct — quantization should see the same distribution as inference. - ModelOpt's override_shapes replaces dim_param with static dim_value for - calibration. TRT requires dynamic dims (dim_param) on inputs/outputs to - accept optimization profiles (min/opt/max ranges). This reads the original - dynamic_axes from model_data and restores them in the FP8 ONNX. + Args: + pipe: StableDiffusionPipeline or StableDiffusionXLPipeline. + prompts: Calibration texts (32–128 recommended). + num_inference_steps: Denoising steps per prompt. 20 for SDXL, 4 for Turbo. + save_path: Optional path to write calib_data.npz for caching between builds. + batch_size: Prompts per pipe() call. + guidance_scale: CFG scale (affects conditional/unconditional stacking). 
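+        onnx_path: Optional exported-ONNX path. When set together with the
+            feature flags below, inputs the bare-pipe hook cannot observe
+            (kvo_cache_in_*, input_control_*, middle_control_*, ipadapter_scale)
+            are synthesized to match the graph's declared shapes.
+        use_cached_attn / use_controlnet / num_ip_layers: Feature flags matching
+            the exported graph; they drive the synthetic-input generation.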
- Uses load_external_data=False so only the small protobuf is loaded/modified, - leaving the ~23GB external weight file untouched. + Returns: + Dict mapping UNet input names to np.ndarray arrays of shape [N, ...]. """ - import onnx + import torch + + _KEY_MAP = {0: "sample", 1: "timestep", 2: "encoder_hidden_states"} + _SDXL_COND_KEYS = ["text_embeds", "time_ids"] + + # builder.py moves pipe.unet to CPU after ONNX export to free GPU during + # optimize. Move it back to CUDA for calibration; restore on exit so the + # next build stage starts from the same VRAM state. + _unet_orig_device = next(pipe.unet.parameters()).device + if _unet_orig_device.type != "cuda": + pipe.unet.to("cuda") + + captured: Dict[str, list] = {} + + def _to_npy(t): + # Keep dtype as-is (FP16 model → FP16 captures). modelopt's max-abs + # calibration does not need FP32; FP32 upcast doubles transfer bandwidth. + # atleast_1d: SDXL passes timestep as a 0-dim scalar tensor in single- + # prompt calls; np.concatenate(axis=0) requires at least 1 axis. + return np.atleast_1d(t.detach().cpu().numpy()) + + def _hook(module, args, kwargs): + # SDXL pipeline calls unet(sample, t, encoder_hidden_states=..., added_cond_kwargs=...) + # — encoder_hidden_states arrives as a kwarg, not positional. Fall through to kwargs. + for idx, key in _KEY_MAP.items(): + val = args[idx] if idx < len(args) else kwargs.get(key) + if val is not None: + captured.setdefault(key, []).append(_to_npy(val)) + added = kwargs.get("added_cond_kwargs") or {} + if not added and len(args) > 3 and isinstance(args[3], dict): + added = args[3] + for key in _SDXL_COND_KEYS: + if key in added and added[key] is not None: + captured.setdefault(key, []).append(_to_npy(added[key])) + + handle = pipe.unet.register_forward_pre_hook(_hook, with_kwargs=True) + try: + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): + batches = [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)] + for i, batch in enumerate(batches): + logger.info(f"[FP8] Capture batch {i + 1}/{len(batches)}: {batch[0][:60]}") + try: + pipe( + prompt=batch if len(batch) > 1 else batch[0], + num_inference_steps=num_inference_steps, + output_type="latent", + guidance_scale=guidance_scale, + ).images + except Exception as e: + logger.warning(f"[FP8] Capture batch {i + 1} failed ({type(e).__name__}): {e}. Skipping.") + finally: + handle.remove() + if _unet_orig_device.type != "cuda": + pipe.unet.to(_unet_orig_device) + torch.cuda.empty_cache() + + if not captured: + raise RuntimeError("[FP8] No UNet activations captured — check pipe.unet forward signature.") + + calib_data: Dict[str, np.ndarray] = {} + for key, arrays in captured.items(): + stacked = np.concatenate(arrays, axis=0) + calib_data[key] = stacked + logger.info(f"[FP8] Captured '{key}': shape={stacked.shape}, dtype={stacked.dtype}") + + # Synthesize zero/one tensors for feature-specific ONNX inputs not captured by the + # bare-pipe hook (kvo_cache_in_*, input_control_*, middle_control_*, ipadapter_scale). + # These inputs feed Q/DQ-excluded layers (see _FEATURE_EXCLUDE_PATTERNS), so the + # synthetic values only need to be shape-compatible — they never drive scale computation. + if onnx_path and (use_cached_attn or use_controlnet or num_ip_layers): + # modelopt's CalibrationDataProvider computes n_itr = first_input.shape[0] / + # first_input_onnx_dim_0, then np.array_split(arr, n_itr, axis=0) for EVERY + # input. Each chunk must satisfy the ONNX-declared dim 0 (fixed or dynamic). 
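Reviewer note: the chunking rule this comment describes, in standalone form. modelopt derives the iteration count from the first input, then splits every array by it. The shapes below are illustrative, not taken from the diff.

```python
import numpy as np

first = np.zeros((8, 4, 64, 64), dtype=np.float16)  # captured 'sample', 8 rows kept
first_onnx_dim0 = 2                                  # ONNX-declared dim 0 (2*B)
n_itr = first.shape[0] // first_onnx_dim0            # -> 4 calibration iterations

# Synthesized input: total leading dim = n_itr * per-chunk dim 0, so every
# split chunk carries exactly the ONNX-declared leading dim.
kvo = np.zeros((n_itr * 2, 1, 2, 77, 640), dtype=np.float16)
chunks = np.array_split(kvo, n_itr, axis=0)
assert all(c.shape[0] == 2 for c in chunks)
```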
+ # Bound captured leading dims so n_itr stays small — synthesized KV-cache memory + # scales as n_itr × per-layer-size × n_layers, which blows up at large n_itr. + _MAX_CALIB_ROWS = 8 + for _k in list(calib_data.keys()): + if calib_data[_k].shape[0] > _MAX_CALIB_ROWS: + calib_data[_k] = calib_data[_k][:_MAX_CALIB_ROWS] + + try: + _specs = _read_onnx_input_specs(onnx_path) + + # Reconcile captured tensors with ONNX-declared static dims. The bare-pipe + # forward hook captures the diffusers UNet *before* feature wrappers run, + # so dims that wrappers reshape (e.g. IPA's UnifiedExportWrapper concatenates + # 4 image tokens onto encoder_hidden_states → seq_len 77 → 81) won't match + # the exported ONNX. Pad with zeros if undersized, trim if oversized, on any + # static (non-dynamic, non-leading) axis. Padding zeros is benign for max-abs + # calibration — the zero-region contributes no signal to scale computation. + for _name, (_, _expected_dims) in _specs.items(): + if _name not in calib_data: + continue + _arr = calib_data[_name] + _resized = False + for _axis, _expected in enumerate(_expected_dims): + if _expected is None or _axis == 0 or _axis >= _arr.ndim: + continue + if _arr.shape[_axis] == _expected: + continue + if _arr.shape[_axis] < _expected: + _pw = [(0, 0)] * _arr.ndim + _pw[_axis] = (0, _expected - _arr.shape[_axis]) + _arr = np.pad(_arr, _pw, mode="constant") + else: + _slc = [slice(None)] * _arr.ndim + _slc[_axis] = slice(0, _expected) + _arr = _arr[tuple(_slc)] + _resized = True + if _resized: + calib_data[_name] = _arr + logger.info(f"[FP8] Reshaped captured '{_name}' to ONNX dims: shape={_arr.shape}") + + # Mirror CalibrationDataProvider's n_itr derivation (calib_utils.py:90): + # first ONNX-declared input that appears in calib_data drives the count. + _present = [n for n in _specs if n in calib_data] + if _present: + _first = _present[0] + _first_d0 = max(1, (_specs[_first][1][0] or 1)) + _n_itr = max(1, calib_data[_first].shape[0] // _first_d0) + else: + _n_itr = 1 + + for name, (dtype, dims) in _specs.items(): + if name in calib_data: + continue + # Resolve symbolic dims to 1. dim 0 is the per-chunk shape ORT sees. + resolved = [d if d is not None else 1 for d in dims] + # ipadapter_scale: ONNX dim 0 is dynamic ("L_ip") but the exported + # graph has hardcoded Gather(scale_vec, idx=k) for k=0..num_ip_layers-1 + # at every IPA layer. A length-1 chunk would OOB at idx≥1 during + # modelopt's _exclude_matmuls_by_inference ORT probe. Force per-chunk + # length to num_ip_layers so every Gather sees the expected vector. + per_step_d0 = num_ip_layers if name == "ipadapter_scale" and num_ip_layers > 0 else resolved[0] + # Total leading dim = n_itr × per-chunk dim 0 so every split chunk + # has exactly the ONNX-declared leading dim (fixed Q+K=2 for kvo_cache, + # fixed 2 for control inputs, num_ip_layers for ipadapter_scale). + arr_shape = [_n_itr * per_step_d0] + list(resolved[1:]) + arr = ( + np.ones(arr_shape, dtype=dtype) if name == "ipadapter_scale" else np.zeros(arr_shape, dtype=dtype) + ) + calib_data[name] = arr + logger.info( + f"[FP8] Synthesized '{name}': shape={arr.shape}, dtype={arr.dtype} " + f"(n_itr={_n_itr}, per-step-dim0={per_step_d0})" + ) + except Exception as e: + logger.warning( + f"[FP8] Synthetic input generation failed: {e}. Missing inputs will be caught during quantization." + ) + + if save_path: + # Uncompressed: zlib barely compresses random-ish FP16 activations and is + # single-threaded — savings are <5 % on multi-GB calibration sets. 
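The static-axis reconciliation above, isolated: zero-pad when the captured tensor is smaller than the ONNX-declared dim, trim when larger. The 77 -> 81 seq_len numbers follow the IPA example in the comment; the rest is illustrative.

```python
import numpy as np

arr = np.zeros((8, 77, 2048), dtype=np.float16)  # captured encoder_hidden_states
expected = 81                                    # ONNX static seq_len after 4 IPA tokens

if arr.shape[1] < expected:                      # undersized -> zero-pad
    pad = [(0, 0)] * arr.ndim
    pad[1] = (0, expected - arr.shape[1])
    arr = np.pad(arr, pad, mode="constant")      # zeros add no signal to max-abs scales
else:                                            # oversized -> trim
    arr = arr[:, :expected]

assert arr.shape == (8, 81, 2048)
```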
+ # Atomic write: if the build crashes mid-save, no partial file is left + # behind for a future run to load and corrupt calibration. + # Save via a file handle so np.savez does not auto-append ".npz" to tmp_path. + tmp_path = save_path + ".tmp" + with open(tmp_path, "wb") as _f: + np.savez(_f, **calib_data) + os.replace(tmp_path, save_path) + logger.info(f"[FP8] Saved calibration data: {save_path} ({os.path.getsize(save_path) / 1e6:.1f} MB)") + return calib_data + + +def load_calibration_data(npz_path: str) -> Optional[Dict[str, np.ndarray]]: + """ + Load previously-saved calibration data from a .npz file. + Returns None (and deletes the file) if loading fails. + """ + if not os.path.exists(npz_path): + return None try: - dynamic_axes = model_data.get_dynamic_axes() + data = dict(np.load(npz_path)) + logger.info(f"[FP8] Loaded calibration data from {npz_path} ({len(data)} tensors)") + return data except Exception as e: - logger.warning(f"[FP8] Could not get dynamic_axes from model_data: {e}. Skipping restore.") - return - - if not dynamic_axes: - logger.warning("[FP8] dynamic_axes is empty — skipping dynamic dim restore.") - return - - model = onnx.load(onnx_fp8_path, load_external_data=False) - - restored_count = 0 - for graph_input in model.graph.input: - name = graph_input.name - if name not in dynamic_axes: - continue - axes = dynamic_axes[name] - dims = graph_input.type.tensor_type.shape.dim - for dim_idx, symbolic_name in axes.items(): - if dim_idx < len(dims): - dim = dims[dim_idx] - dim.ClearField("dim_value") - dim.dim_param = symbolic_name - restored_count += 1 - - for graph_output in model.graph.output: - name = graph_output.name - if name not in dynamic_axes: - continue - axes = dynamic_axes[name] - dims = graph_output.type.tensor_type.shape.dim - for dim_idx, symbolic_name in axes.items(): - if dim_idx < len(dims): - dim = dims[dim_idx] - dim.ClearField("dim_value") - dim.dim_param = symbolic_name - restored_count += 1 - - if restored_count == 0: - logger.warning("[FP8] No dynamic dimensions restored — graph inputs may already be dynamic.") - return - - # Save only the protobuf (weight data stays in existing external file). - # load_external_data=False keeps tensor data_location=EXTERNAL references intact, - # so onnx.save() writes a small protobuf that still points to the existing _data file. - onnx.save(model, onnx_fp8_path) - logger.info( - f"[FP8] Restored {restored_count} dynamic dimensions in {os.path.basename(onnx_fp8_path)}" - ) + logger.warning(f"[FP8] Cannot load calibration data from {npz_path}: {e}. Will recapture.") + try: + os.remove(npz_path) + except OSError: + pass + return None -def generate_unet_calibration_data( - model_data, - opt_batch_size: int, - opt_image_height: int, - opt_image_width: int, - num_batches: int = 8, -) -> List[Dict[str, np.ndarray]]: - """ - Generate calibration data for SDXL-Turbo UNet FP8 quantization. +# modelopt's expand_node_names_from_patterns feeds these straight into re.match, +# so they're regex (not glob) — leading `*` would raise "nothing to repeat". +# `.*time_emb.*` already covers `time_embedding` since `time_emb` is a substring. +_DEFAULT_EXCLUDE_PATTERNS = [r".*time_emb.*", r".*add_emb.*"] - Returns a list of input dicts matching the ONNX model's input names, - with values as numpy arrays shaped to the TRT optimization profile's opt shapes. +# Feature-specific Q/DQ exclusions applied only when the corresponding feature flag +# is active — keeps plain-UNet Q/DQ counts unaffected. 
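Because these exclusion patterns (and `_FEATURE_EXCLUDE_PATTERNS` just below) feed `re.match` directly, a glob-style spelling fails at compile time. The node name here is hypothetical:

```python
import re

# Regex form: matches any node whose name contains "time_emb".
assert re.match(r".*time_emb.*", "/down_blocks.0/time_emb_proj/MatMul")

# Glob form: a leading '*' has nothing to repeat.
try:
    re.compile(r"*time_emb*")
except re.error as e:
    print(e)  # nothing to repeat at position 0
```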
+_FEATURE_EXCLUDE_PATTERNS = { + "cached_attn": [r".*kvo_cache.*"], + "controlnet": [r".*down_block_additional_residuals.*", r".*mid_block_additional_residual.*"], + "ipadapter": [r".*to_k_ip.*", r".*to_v_ip.*", r".*to_out_ip.*"], +} - Args: - model_data: UNet BaseModel instance (provides input names, kvo_cache_shapes, - text_maxlen, embedding_dim, cache_maxframes). - opt_batch_size: Optimal batch size from TRT profile (typically 1 for - frame_buffer_size=1). The UNet input dim is 2*opt_batch_size - because cond + uncond are batched together. - opt_image_height: Optimal image height in pixels (e.g. 512). - opt_image_width: Optimal image width in pixels (e.g. 512). - num_batches: Number of calibration batches. Capped at 8 for SDXL-scale - models: each batch contains 70 KVO cache tensors (~2.2 GB), - so 128 batches would require ~281 GB RAM. FP8 is less - sensitive to calibration size than INT8 (wider dynamic range). - Returns: - List of dicts: [{input_name: np.ndarray}, ...] — one dict per batch. - """ - latent_h = opt_image_height // 8 - latent_w = opt_image_width // 8 - # UNet always receives 2× the batch (cond + uncond paired) - effective_batch = 2 * opt_batch_size - - input_names = model_data.get_input_names() - - # Fixed seed for reproducible calibration - rng = np.random.default_rng(seed=42) - - # Pre-read model_data properties once to avoid repeated attribute access - text_maxlen = getattr(model_data, "text_maxlen", 77) - embedding_dim = getattr(model_data, "embedding_dim", 2048) - cache_maxframes = getattr(model_data, "cache_maxframes", 4) - kvo_cache_shapes = getattr(model_data, "kvo_cache_shapes", []) - num_ip_layers = getattr(model_data, "num_ip_layers", 1) - control_inputs = getattr(model_data, "control_inputs", {}) - - calibration_dataset = [] - - for i in range(num_batches): - batch_data = {} - - for name in input_names: - if name == "sample": - # Noisy latents in float32 (UNet ingests fp32 sample before internal autocast) - # VAE latent scale: 0.18215 for SDXL - data = (rng.standard_normal((effective_batch, 4, latent_h, latent_w)) * 0.18215) - batch_data[name] = data.astype(np.float32) - - elif name == "timestep": - # Timesteps: float32, shape (effective_batch,) - # Sample broadly across [0, 999] to cover full activation range. - t = rng.integers(0, 1000, size=(effective_batch,)) - batch_data[name] = t.astype(np.float32) - - elif name == "encoder_hidden_states": - # CLIP/OpenCLIP text embeddings: float16 for fp16 SDXL models - # Scale 0.01 approximates typical normalized text embedding magnitude. - data = (rng.standard_normal((effective_batch, text_maxlen, embedding_dim)) * 0.01) - batch_data[name] = data.astype(np.float16) - - elif name == "ipadapter_scale": - # IP-Adapter per-layer scale: float32, shape (num_ip_layers,) - batch_data[name] = np.ones((num_ip_layers,), dtype=np.float32) - - elif name.startswith("input_control_"): - # ControlNet residual tensors: float16 - if name in control_inputs: - spec = control_inputs[name] - data = rng.standard_normal( - (effective_batch, spec["channels"], spec["height"], spec["width"]) - ) - batch_data[name] = data.astype(np.float16) - - elif name.startswith("kvo_cache_in_"): - # KVO cached attention inputs: float16 - # shape = (2, cache_maxframes, kvo_calib_batch, seq_len, hidden_dim) - # dim[0]=2: K/V pair (must match ONNX trace, which always uses 2). - # dim[2]: Must equal sample's batch dimension (effective_batch = 2 * opt_batch_size) - # because both share the ONNX dynamic axis "2B". 
Using a different value - # causes Concat dimension mismatches in attention layers during calibration. - # Zeros = cold cache. Conservative but avoids over-fitting calibration - # ranges to cached-attention activation patterns. - idx = int(name.rsplit("_", 1)[-1]) - if idx < len(kvo_cache_shapes): - seq_len, hidden_dim = kvo_cache_shapes[idx] - kvo_calib_batch = effective_batch # Must match sample batch (ONNX axis "2B") - batch_data[name] = np.zeros( - (2, cache_maxframes, kvo_calib_batch, seq_len, hidden_dim), - dtype=np.float16, - ) - - calibration_dataset.append(batch_data) +def _read_onnx_input_specs(onnx_path: str) -> Dict[str, tuple]: + """Return {name: (np_dtype, shape)} from ONNX graph inputs. Shape dims are int or None.""" + import onnx as _onnx + from onnx.helper import tensor_dtype_to_np_dtype as _onnx_to_np - logger.info( - f"[FP8] Generated {num_batches} calibration batches " - f"(effective_batch={effective_batch}, latent={latent_h}x{latent_w}, " - f"inputs={len(input_names)}, kvo_count={len(kvo_cache_shapes)})" - ) - return calibration_dataset + m = _onnx.load(onnx_path, load_external_data=False) + result: Dict[str, tuple] = {} + for inp in m.graph.input: + tt = inp.type.tensor_type + dtype = _onnx_to_np(tt.elem_type) + dims = [] + if tt.HasField("shape"): + for d in tt.shape.dim: + dims.append(d.dim_value if d.HasField("dim_value") and d.dim_value > 0 else None) + result[inp.name] = (dtype, dims) + return result def quantize_onnx_fp8( - onnx_opt_path: str, - onnx_fp8_path: str, - calibration_data: Optional[List[Dict[str, np.ndarray]]] = None, - quantize_mha: bool = False, - percentile: float = 1.0, - alpha: float = 0.8, - model_data=None, - opt_batch_size: int = 1, - opt_image_height: int = 512, - opt_image_width: int = 512, + onnx_path: str, + output_path: str, + calibration_data: Dict[str, np.ndarray], + nodes_to_exclude: Optional[List[str]] = None, + disable_mha_qdq: bool = True, + use_cached_attn: bool = False, + use_controlnet: bool = False, + num_ip_layers: int = 0, ) -> None: """ - Insert FP8 Q/DQ nodes into an optimized ONNX model via nvidia-modelopt. + Inject native FLOAT8E4M3FN Q/DQ nodes into a FP16 ONNX model via ORT. - Takes the FP16-optimized ONNX (*.opt.onnx), runs calibration to collect - activation ranges, and writes a new ONNX with QuantizeLinear/DequantizeLinear - nodes annotated for FP8 E4M3 precision. TRT compiles this with - STRONGLY_TYPED + FP8 builder flags. + The output ONNX feeds directly into Engine._build_fp8 (STRONGLY_TYPED path). Args: - onnx_opt_path: Input FP16 optimized ONNX path (*.opt.onnx). - onnx_fp8_path: Output FP8 quantized ONNX path (*.fp8.onnx). - calibration_data: Unused. Kept for backward compatibility. - quantize_mha: Enable FP8 quantization of multi-head attention ops. - Kept False — MHA analysis via ORT inference adds ~3 hours to build. - Non-MHA ops (Conv, Gemm, MatMul outside MHA) are still FP8. - percentile: Unused. Kept for backward compatibility (entropy calibration - does not use percentile clipping). - alpha: SmoothQuant alpha — balances quantization difficulty between - activations (alpha→0) and weights (alpha→1). 0.8 is optimal - for transformer attention layers. - model_data: UNet BaseModel instance for building calibration_shapes. - If None, RandomDataProvider defaults all dynamic dims to 1. - opt_batch_size: Optimal batch size from TRT profile. - opt_image_height: Optimal image height in pixels. - opt_image_width: Optimal image width in pixels. + onnx_path: Input FP16 ONNX (may use external data format). 
+ output_path: Output path for the FP8-quantized ONNX. + calibration_data: Dict[str, np.ndarray] from capture_calibration_data(). + nodes_to_exclude: ONNX node name patterns to skip quantization on. + Defaults to time/add embedding layers. + disable_mha_qdq: Skip MHA-specific Q/DQ injection (default True for Ada). + General FP8 calibration still inserts Q/DQ on all attention + MatMuls; TRT Myelin fuses them into _gemm_mha_v2 FP8 kernels. """ try: from modelopt.onnx.quantization import quantize as modelopt_quantize except ImportError as e: raise ImportError( - "nvidia-modelopt is required for FP8 quantization. " - "Install with: pip install 'nvidia-modelopt[onnx]'" + "nvidia-modelopt[onnx] is required for ONNX-level FP8 quantization.\n" + "Install with: pip install 'nvidia-modelopt[onnx]>=0.19.0'\n" + "Also ensure onnxruntime-gpu >= 1.17 is installed." ) from e - # Enable verbose ORT logging so Memcpy node details are visible before the - # summary warning. Severity 1 = INFO (shows per-node placement decisions). + # ORT CUDA EP requires cuDNN DLLs — PyTorch ships cuDNN under torch/lib on Windows. + # Best-effort: failing here just lets ORT surface its own loader error downstream. try: - import onnxruntime as _ort - _ort.set_default_logger_severity(1) - logger.info("[FP8] ORT log_severity_level set to 1 (INFO) for Memcpy diagnostics") - except Exception: - pass - - input_size_mb = os.path.getsize(onnx_opt_path) / (1024 * 1024) - logger.info(f"[FP8] Starting ONNX FP8 quantization") - logger.info(f"[FP8] Input: {onnx_opt_path} ({input_size_mb:.0f} MB)") - logger.info(f"[FP8] Output: {onnx_fp8_path}") - logger.info(f"[FP8] Config: quantize_mha={quantize_mha}, calibration=entropy, alpha={alpha}") - logger.info(f"[FP8] Calibration: RandomDataProvider with calibration_shapes (model_data={'provided' if model_data is not None else 'none'})") - - # Patch ByteSize() for >2GB ONNX models: modelopt calls onnx_model.ByteSize() - # to auto-detect external data format, but protobuf cannot serialize >2GB protos. - # Return a large value on failure so modelopt correctly uses external data format. - import onnx as _onnx - from google.protobuf.message import EncodeError as _EncodeError + import torch as _torch - _orig_byte_size = _onnx.ModelProto.ByteSize + _torch_lib = os.path.join(os.path.dirname(_torch.__file__), "lib") + if os.path.isdir(_torch_lib) and _torch_lib not in os.environ.get("PATH", ""): + os.environ["PATH"] = _torch_lib + os.pathsep + os.environ.get("PATH", "") + except Exception as e: + logger.debug(f"[FP8] cuDNN PATH setup skipped: {e}") - def _safe_byte_size(self): - try: - return _orig_byte_size(self) - except _EncodeError: - return 3 * (1024**3) # >2GB → triggers external data format - - _onnx.ModelProto.ByteSize = _safe_byte_size - - # Ensure NVIDIA DLLs (cuDNN, cuBLAS, CUDA runtime) are on PATH so modelopt's - # ORT sessions can use CUDA/TensorRT EPs instead of CPU EP (which is stricter - # about mixed-precision Cast nodes and fails on FP16 models). 
- _nvidia_pkg_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname( - os.path.dirname(os.path.abspath(__file__))))), os.pardir, "venv", "Lib", - "site-packages", "nvidia") - _nvidia_pkg_dir = os.path.normpath(_nvidia_pkg_dir) - if not os.path.isdir(_nvidia_pkg_dir): - # Fallback: find via importlib - try: - import nvidia.cudnn - _nvidia_pkg_dir = os.path.dirname(os.path.dirname(nvidia.cudnn.__file__)) - except ImportError: - _nvidia_pkg_dir = None - - if _nvidia_pkg_dir and os.path.isdir(_nvidia_pkg_dir): - _bin_dirs = [] - for _subpkg in ("cudnn", "cublas", "cuda_runtime", "cufft", "curand"): - _bdir = os.path.join(_nvidia_pkg_dir, _subpkg, "bin") - if os.path.isdir(_bdir) and _bdir not in os.environ.get("PATH", ""): - _bin_dirs.append(_bdir) - if _bin_dirs: - os.environ["PATH"] = os.pathsep.join(_bin_dirs) + os.pathsep + os.environ.get("PATH", "") - logger.info(f"[FP8] Added {len(_bin_dirs)} NVIDIA DLL dirs to PATH") - - # Build calibration_shapes string for modelopt's RandomDataProvider. - # RandomDataProvider calls _get_tensor_shape() which sets ALL dynamic dims to 1. - # For a 512x512 UNet, sample becomes (1,4,1,1) instead of (2,4,64,64), causing - # spatial dimension mismatches at UNet skip-connection Concat nodes (up_blocks). - # calibration_shapes overrides _get_tensor_shape() per input — only specified - # inputs bypass the default-to-1 fallback. - # - # Format: "input0:d0xd1x...,input1:d0xd1x..." (modelopt parse_shapes_spec format) - calibration_shapes_str: Optional[str] = None - if model_data is not None: - latent_h = opt_image_height // 8 - latent_w = opt_image_width // 8 - effective_batch = 2 * opt_batch_size - text_maxlen = getattr(model_data, "text_maxlen", 77) - embedding_dim = getattr(model_data, "embedding_dim", 2048) - # Use cache_maxframes=1 for calibration. The attention processor does: - # kvo_cache[0] → (cache_maxframes, batch, S, H) - # .transpose(0,1).flatten(1,2) → (batch, cache_maxframes*S, H) - # With cache_maxframes=4, ONNX shape-computation nodes create Concat ops - # that mix dim=4 (cache_maxframes) with dim=2 (batch), causing Concat axis - # mismatch errors in ORT. cache_maxframes=1 is valid (within TRT profile - # min range) and avoids the conflict. FP8 only needs valid activation ranges. 
- calib_cache_maxframes = 1 - kvo_cache_shapes = getattr(model_data, "kvo_cache_shapes", []) - num_ip_layers = getattr(model_data, "num_ip_layers", 1) - control_inputs = getattr(model_data, "control_inputs", {}) - kvo_calib_batch = effective_batch # Must match sample batch (ONNX axis "2B") - - shape_parts = [] - try: - input_names = model_data.get_input_names() - except Exception: - input_names = [] - - for name in input_names: - if name == "sample": - shape_parts.append(f"{name}:{effective_batch}x4x{latent_h}x{latent_w}") - elif name == "timestep": - shape_parts.append(f"{name}:{effective_batch}") - elif name == "encoder_hidden_states": - shape_parts.append(f"{name}:{effective_batch}x{text_maxlen}x{embedding_dim}") - elif name == "ipadapter_scale": - shape_parts.append(f"{name}:{num_ip_layers}") - elif name.startswith("input_control_") and name in control_inputs: - spec = control_inputs[name] - shape_parts.append( - f"{name}:{effective_batch}x{spec['channels']}x{spec['height']}x{spec['width']}" - ) - elif name.startswith("kvo_cache_in_"): - idx = int(name.rsplit("_", 1)[-1]) - if idx < len(kvo_cache_shapes): - seq_len, hidden_dim = kvo_cache_shapes[idx] - shape_parts.append( - f"{name}:2x{calib_cache_maxframes}x{kvo_calib_batch}x{seq_len}x{hidden_dim}" - ) - - if shape_parts: - calibration_shapes_str = ",".join(shape_parts) - logger.info( - f"[FP8] calibration_shapes: {len(shape_parts)} inputs " - f"(sample={effective_batch}x4x{latent_h}x{latent_w}, " - f"kvo={len([p for p in shape_parts if 'kvo_cache_in' in p])} caches " - f"calib_frames={calib_cache_maxframes})" - ) - else: - logger.warning( - "[FP8] model_data not provided — RandomDataProvider will default all " - "dynamic dims to 1. UNet Concat nodes may fail for non-trivial models." - ) + # Flush pending GPU work before ORT CUDA EP claims VRAM. A failure here usually + # signals a wedged CUDA context — surface at debug so it's not invisible. + try: + import torch as _t - quantize_kwargs = { + if _t.cuda.is_available(): + _t.cuda.synchronize() + _t.cuda.empty_cache() + import gc as _gc + + _gc.collect() + except Exception as e: + logger.debug(f"[FP8] pre-quantize CUDA flush skipped: {e}") + + if nodes_to_exclude is None: + nodes_to_exclude = list(_DEFAULT_EXCLUDE_PATTERNS) + if use_cached_attn: + nodes_to_exclude.extend(_FEATURE_EXCLUDE_PATTERNS["cached_attn"]) + if use_controlnet: + nodes_to_exclude.extend(_FEATURE_EXCLUDE_PATTERNS["controlnet"]) + if num_ip_layers > 0: + nodes_to_exclude.extend(_FEATURE_EXCLUDE_PATTERNS["ipadapter"]) + + # The optimized ONNX may expose fewer inputs than capture_calibration_data + # records (e.g. SDXL UnifiedExportWrapper hides text_embeds/time_ids inside + # the graph) and may declare different dtypes than the captured tensors — + # e.g. SDXL exports `sample` as FP32 even though the unet runs FP16. + # modelopt's CalibrationDataProvider asserts strict count match and ORT's + # inference probe rejects dtype mismatches, so filter+cast accordingly. 
+ _onnx_inputs = {k: v[0] for k, v in _read_onnx_input_specs(onnx_path).items()} + _dropped = set(calibration_data.keys()) - set(_onnx_inputs) + if _dropped: + logger.info(f"[FP8] Dropping calibration keys not exposed by ONNX: {sorted(_dropped)}") + calibration_data = {k: v for k, v in calibration_data.items() if k in _onnx_inputs} + _missing = set(_onnx_inputs) - set(calibration_data.keys()) + if _missing: + raise RuntimeError(f"[FP8] Calibration data missing required ONNX inputs: {sorted(_missing)}") + for _k, _expected in _onnx_inputs.items(): + if calibration_data[_k].dtype != _expected: + logger.info(f"[FP8] Casting calibration '{_k}': {calibration_data[_k].dtype} → {_expected}") + calibration_data[_k] = calibration_data[_k].astype(_expected) + + import inspect as _inspect + + _params = set(_inspect.signature(modelopt_quantize).parameters.keys()) + + kwargs = { + "onnx_path": onnx_path, "quantize_mode": "fp8", - "output_path": onnx_fp8_path, - # entropy: minimizes KL divergence to find optimal clipping point for each tensor. - # Better than percentile=1.0 (no clipping) which allows outliers to stretch the - # quantization range, reducing precision for the bulk of activations. - "calibration_method": "entropy", - "alpha": alpha, + "output_path": output_path, + "calibration_method": "max", + "calibration_eps": ["cuda:0"], + "calibration_data": calibration_data, + "high_precision_dtype": "fp16", "use_external_data_format": True, - # override_shapes replaces dynamic dims in the ONNX model itself with static - # values BEFORE any ORT sessions (MHA analysis or calibration) are created. - # Without this, ORT's internal shape inference with dynamic dims causes - # Concat failures (e.g. KVO cache dims vs sample batch dims). - # calibration_shapes additionally tells RandomDataProvider what shapes to - # generate for the calibration data. - "override_shapes": calibration_shapes_str, - "calibration_shapes": calibration_shapes_str, - # Use default EPs ["cpu","cuda:0","trt"] — CPU-only would fail on this FP16 SDXL - # model because ORT's mandatory CastFloat16Transformer inserts Cast nodes that - # conflict with existing Cast nodes in the upsampler conv. - # disable_mha_qdq=True: skip MHA pattern analysis (avoids 3-hour ORT inference - # pass over the full model graph). Non-MHA ops (Conv, Gemm, MatMul outside MHA) - # still get FP8 Q/DQ nodes via the normal KGEN/CASK path. - "disable_mha_qdq": not quantize_mha, - # calibrate_per_node: calibrate one node at a time to reduce peak VRAM during - # calibration. Essential for large UNets (83 inputs, 7993 nodes) to avoid OOM. - "calibrate_per_node": True, + "calibrate_per_node": False, + "disable_mha_qdq": disable_mha_qdq, + "nodes_to_exclude": nodes_to_exclude, } + # enable_gemv_detection_for_trt was removed in modelopt >= 0.42 + if "enable_gemv_detection_for_trt" in _params: + kwargs["enable_gemv_detection_for_trt"] = False - try: - modelopt_quantize(onnx_opt_path, **quantize_kwargs) - except TypeError as e: - # Older nvidia-modelopt versions may not support newer kwargs. - # Strip down to base parameters and retry. - logger.warning(f"[FP8] Retrying with reduced kwargs (TypeError: {e})") - for _k in ("alpha", "disable_mha_qdq", "calibrate_per_node"): - quantize_kwargs.pop(_k, None) - modelopt_quantize(onnx_opt_path, **quantize_kwargs) - except Exception as e: - # MHA analysis (disable_mha_qdq=False) requires an ORT inference run that - # fails with KVO cached attention models. 
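For reviewers, the intended call flow of the new API, pieced together from this file: load a cached .npz if present, otherwise capture and save, then quantize. Builder integration, the paths, and the flag values are assumptions, not part of the diff.

```python
calib = load_calibration_data("calib_data.npz")
if calib is None:
    calib = capture_calibration_data(
        pipe,                               # an already-loaded SDXL pipeline
        _load_calibration_prompts(),
        num_inference_steps=4,              # Turbo-style schedule per the docstring
        save_path="calib_data.npz",
        onnx_path="unet.opt.onnx",
        use_cached_attn=True,
    )
quantize_onnx_fp8(
    "unet.opt.onnx",
    "unet.fp8.onnx",
    calibration_data=calib,
    use_cached_attn=True,
)
```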
Retry with disable_mha_qdq=True - # to skip the ORT session entirely — MHA layers use FP16, rest uses FP8. - if not quantize_kwargs.get("disable_mha_qdq", True): - # Delete intermediate files written during the failed attempt to free - # disk space before the retry (each set is ~23GB for SDXL-scale models). - _base = os.path.splitext(onnx_opt_path)[0] # strip .onnx - for _suffix in ( - "_static.onnx", "_static.onnx_data", # from override_shapes - "_named.onnx", "_named.onnx_data", - "_named_extended.onnx", "_named_extended.onnx_data", - "_ir10.onnx", "_ir10.onnx_data", - "_static_named.onnx", "_static_named.onnx_data", - "_static_ir10.onnx", "_static_ir10.onnx_data", - ): - _f = _base + _suffix - if os.path.exists(_f): - os.remove(_f) - logger.info(f"[FP8] Cleaned up intermediate: {os.path.basename(_f)}") - logger.warning( - f"[FP8] MHA analysis failed ({type(e).__name__}: {e}). " - "Retrying with disable_mha_qdq=True (MHA layers will use FP16 precision)." - ) - quantize_kwargs["disable_mha_qdq"] = True - modelopt_quantize(onnx_opt_path, **quantize_kwargs) - else: - raise - finally: - _onnx.ModelProto.ByteSize = _orig_byte_size # Restore original method - try: - import onnxruntime as _ort - _ort.set_default_logger_severity(2) # Restore to WARNING - except Exception: - pass + logger.info( + f"[FP8] ONNX-level FP8 quantization: {os.path.basename(onnx_path)}" + f" → {os.path.basename(output_path)}" + f" ({next(iter(calibration_data.values())).shape[0]} calibration samples," + f" disable_mha_qdq={disable_mha_qdq})" + ) + modelopt_quantize(**kwargs) - if not os.path.exists(onnx_fp8_path): - raise RuntimeError( - f"[FP8] Quantization completed but output file not found: {onnx_fp8_path}" - ) + if not os.path.exists(output_path): + raise RuntimeError(f"[FP8] modelopt_quantize completed but output not found: {output_path}") - # --- Restore dynamic axes --- - # ModelOpt's override_shapes baked static dim_value into graph inputs for calibration. - # TRT needs dynamic dim_param on inputs/outputs to accept optimization profiles. - if model_data is not None: - try: - _restore_dynamic_axes(onnx_fp8_path, model_data) - except Exception as restore_err: - logger.warning( - f"[FP8] Failed to restore dynamic axes: {restore_err}. " - "TRT engine build may fail with static shape profile mismatch." - ) + size_mb = os.path.getsize(output_path) / (1024**2) + logger.info(f"[FP8] FP8 ONNX written: {output_path} ({size_mb:.1f} MB)") + if size_mb > 5000: + logger.warning( + f"[FP8] FP8 ONNX is unexpectedly large ({size_mb:.0f} MB > 5000 MB). " + "FP32 Cast bloat may be active — check high_precision_dtype='fp16' is honored." + ) - output_size_mb = os.path.getsize(onnx_fp8_path) / (1024 * 1024) - ratio = output_size_mb / input_size_mb if input_size_mb > 0 else 0 - logger.info( - f"[FP8] Quantization complete: {input_size_mb:.0f} MB → {output_size_mb:.0f} MB " - f"(ratio: {ratio:.2f}x)" - ) + # Sentinel marker — only written after modelopt_quantize returns. The builder's + # cache check looks for this file, so a crash mid-write leaves no false-positive. 
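The sentinel written at the end of this function pairs with a check on the builder side. builder.py is outside this diff, so the following is only a plausible shape of that check (names illustrative):

```python
import os

def fp8_onnx_cache_hit(output_path: str) -> bool:
    # Both the ONNX and its ".ok" marker must exist; a crash between the two
    # writes can therefore never produce a false cache hit.
    return os.path.exists(output_path) and os.path.exists(output_path + ".ok")
```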
+ with open(output_path + ".ok", "w") as _f: + _f.write("ok") diff --git a/src/streamdiffusion/acceleration/tensorrt/models/__init__.py b/src/streamdiffusion/acceleration/tensorrt/models/__init__.py index f0f2c4b9..b6a1bd62 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/__init__.py @@ -1,13 +1,14 @@ -from .models import Optimizer, BaseModel, CLIP, UNet, VAE, VAEEncoder -from .controlnet_models import ControlNetTRT, ControlNetSDXLTRT +from .controlnet_models import ControlNetSDXLTRT, ControlNetTRT +from .models import CLIP, VAE, BaseModel, Optimizer, UNet, VAEEncoder + __all__ = [ "Optimizer", - "BaseModel", + "BaseModel", "CLIP", "UNet", "VAE", "VAEEncoder", "ControlNetTRT", "ControlNetSDXLTRT", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py index 6179ffc9..be41d502 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py @@ -2,10 +2,10 @@ import torch import torch.nn.functional as F - from diffusers.models.attention_processor import Attention from diffusers.utils import USE_PEFT_BACKEND + class CachedSTAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -28,9 +28,9 @@ def __init__(self): # clone/contiguous path. Set to True by wrapper.py after engine build. self._curr_key_buf: Optional[torch.Tensor] = None self._curr_value_buf: Optional[torch.Tensor] = None - self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key + self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key self._cached_value_tr_buf: Optional[torch.Tensor] = None # transposed cache value - self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H) + self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H) self._use_prealloc: bool = False def _ensure_buffers( diff --git a/src/streamdiffusion/acceleration/tensorrt/models/models.py b/src/streamdiffusion/acceleration/tensorrt/models/models.py index 99f306dc..a85b2415 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/models.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/models.py @@ -17,10 +17,15 @@ # limitations under the License. # +import logging + import onnx_graphsurgeon as gs import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from onnx import shape_inference + + +logger = logging.getLogger(__name__) from polygraphy.backend.onnx.loader import fold_constants @@ -68,6 +73,71 @@ def infer_shapes(self, return_onnx=False): if return_onnx: return onnx_graph + def fix_layernorm_dtypes(self, return_onnx=False): + """ + Fix LN dtype mismatch in FP8-quantized UNet without breaking Q/DQ adjacency. + + nvidia-modelopt DequantizeLinear outputs FP32; LN scale/bias stay FP16 + from the original weights. STRONGLY_TYPED rejects the mismatch + (TRT 10.x: "INormalizationLayer 'input' and 'scale' must have identical types"). + + Fix: promote scale/bias FP16→FP32 to match the FP32 input, and promote + the LN output dtype to FP32 so consumers see consistent types. + + Earlier versions also inserted a Cast(FP32→FP16) on each LN output. + That Cast pollutes Q/DQ adjacency — TRT's quantization fusion expects + the LN→Q edge to be direct. 
The Cast made the engine *build* (Q/DQ + count looked healthy, ~3082) but the DQ scale was applied to the + post-Cast tensor instead of the actually-quantized tensor → numerically + broken → pure noise at inference. Per SDXL UNet structure every LN + output feeds only QuantizeLinear (qkv/ff projections) which accepts + FP32 directly, so the Cast was unnecessary. + """ + import numpy as np + + promoted = 0 + out_promoted = 0 + non_q_consumers_seen = 0 + + for node in self.graph.nodes: + if node.op != "LayerNormalization": + continue + if not node.inputs: + continue + + for param in node.inputs[1:]: # scale, then optional bias + if param is None or not hasattr(param, "values") or param.values is None: + continue + if param.values.dtype == np.float16: + param.values = param.values.astype(np.float32) + promoted += 1 + + out_var = node.outputs[0] + if hasattr(out_var, "dtype") and out_var.dtype == np.float16: + out_var.dtype = np.float32 + out_promoted += 1 + + # Sanity: warn if any LN feeds something other than QuantizeLinear, + # since FP32 promotion of the output edge could then introduce a + # downstream type mismatch the original Cast was masking. + for consumer in self.graph.nodes: + if out_var in consumer.inputs and consumer.op != "QuantizeLinear": + non_q_consumers_seen += 1 + + if promoted or out_promoted: + logger.info( + f"[Optimizer] fix_layernorm_dtypes: promoted {promoted} initializer(s) " + f"and {out_promoted} LN output dtype(s) FP16→FP32 (no Cast insertion)" + ) + if non_q_consumers_seen: + logger.warning( + f"[Optimizer] fix_layernorm_dtypes: {non_q_consumers_seen} non-QuantizeLinear " + f"LN consumer(s) detected — FP32 LN output may need a downstream Cast for " + f"STRONGLY_TYPED build to succeed. Standard SDXL UNet should report 0." + ) + if return_onnx: + return gs.export_onnx(self.graph) + class BaseModel: def __init__( @@ -125,6 +195,9 @@ def optimize(self, onnx_graph): opt.info(self.name + ": fold constants") opt.infer_shapes() opt.info(self.name + ": shape inference") + if any(n.op in ("QuantizeLinear", "DequantizeLinear") for n in opt.graph.nodes): + opt.fix_layernorm_dtypes() + opt.info(self.name + ": fp8 LN dtype fix") onnx_opt_graph = opt.cleanup(return_onnx=True) opt.info(self.name + ": finished") return onnx_opt_graph @@ -142,7 +215,9 @@ def check_dims(self, batch_size, image_height, image_width): assert batch_size >= effective_min_batch and batch_size <= effective_max_batch, ( f"Batch size {batch_size} not in range [{effective_min_batch}, {effective_max_batch}]" ) - assert image_height % 8 == 0 or image_width % 8 == 0 + assert image_height % 8 == 0 and image_width % 8 == 0, ( + f"image_height ({image_height}) and image_width ({image_width}) must both be divisible by 8" + ) latent_height = image_height // 8 latent_width = image_width // 8 assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape @@ -733,14 +808,20 @@ def get_sample_input(self, batch_size, image_height, image_width): export_batch_size = min(batch_size, 1) # Use batch size 1 for ONNX export to save memory base_inputs = [ + # sample dtype matches self.fp16 so the ONNX `sample` input is FP16 when + # the unet runs FP16 — eliminates an FP32→FP16 Cast at conv_in, and + # avoids a dtype mismatch when modelopt's ORT inference probe (used in + # FP8 calibration) feeds FP16 captures into the graph. 
torch.randn( 2 * export_batch_size, self.unet_dim, latent_height, latent_width, - dtype=torch.float32, + dtype=dtype, device=self.device, ), + # timestep stays FP32 — diffusers' sinusoidal time_proj needs FP32 for + # numerical stability; this is also what the FP16 unet expects upstream. torch.ones((2 * export_batch_size,), dtype=torch.float32, device=self.device), torch.randn(2 * export_batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), ] diff --git a/src/streamdiffusion/acceleration/tensorrt/models/utils.py b/src/streamdiffusion/acceleration/tensorrt/models/utils.py index eac854aa..7e308e82 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/utils.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/utils.py @@ -1,16 +1,19 @@ +from typing import Dict, List, Tuple + import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + def get_kvo_cache_info(unet: UNet2DConditionModel, height=512, width=512): latent_height = height // 8 latent_width = width // 8 - + kvo_cache_shapes = [] kvo_cache_structure = [] current_h, current_w = latent_height, latent_width - + for _, block in enumerate(unet.down_blocks): - if hasattr(block, 'attentions') and block.attentions is not None: + if hasattr(block, "attentions") and block.attentions is not None: block_structure = [] for attn_block in block.attentions: attn_count = 0 @@ -22,12 +25,12 @@ def get_kvo_cache_info(unet: UNet2DConditionModel, height=512, width=512): attn_count += 1 block_structure.append(attn_count) kvo_cache_structure.append(block_structure) - - if hasattr(block, 'downsamplers') and block.downsamplers is not None: + + if hasattr(block, "downsamplers") and block.downsamplers is not None: current_h //= 2 current_w //= 2 - - if hasattr(unet.mid_block, 'attentions') and unet.mid_block.attentions is not None: + + if hasattr(unet.mid_block, "attentions") and unet.mid_block.attentions is not None: block_structure = [] for attn_block in unet.mid_block.attentions: attn_count = 0 @@ -39,9 +42,9 @@ def get_kvo_cache_info(unet: UNet2DConditionModel, height=512, width=512): attn_count += 1 block_structure.append(attn_count) kvo_cache_structure.append(block_structure) - + for _, block in enumerate(unet.up_blocks): - if hasattr(block, 'attentions') and block.attentions is not None: + if hasattr(block, "attentions") and block.attentions is not None: block_structure = [] for attn_block in block.attentions: attn_count = 0 @@ -53,13 +56,13 @@ def get_kvo_cache_info(unet: UNet2DConditionModel, height=512, width=512): attn_count += 1 block_structure.append(attn_count) kvo_cache_structure.append(block_structure) - - if hasattr(block, 'upsamplers') and block.upsamplers is not None: + + if hasattr(block, "upsamplers") and block.upsamplers is not None: current_h *= 2 current_w *= 2 kvo_cache_count = sum(sum(block) for block in kvo_cache_structure) - + return kvo_cache_shapes, kvo_cache_structure, kvo_cache_count @@ -89,16 +92,33 @@ def convert_structure_to_list(structured_list): return flat_list -def create_kvo_cache(unet: UNet2DConditionModel, batch_size, cache_maxframes, height=512, width=512, - device='cuda', dtype=torch.float16): +def create_kvo_cache( + unet: UNet2DConditionModel, batch_size, cache_maxframes, height=512, width=512, device="cuda", dtype=torch.float16 +): kvo_cache_shapes, kvo_cache_structure, _ = get_kvo_cache_info(unet, height, width) - - kvo_cache = [] - for seq_length, hidden_dim in kvo_cache_shapes: - cache_tensor = torch.zeros( - 2, cache_maxframes, batch_size, 
seq_length, hidden_dim, - dtype=dtype, device=device - ) - kvo_cache.append(cache_tensor) - - return kvo_cache, kvo_cache_structure \ No newline at end of file + + bucket_keys: List[Tuple[int, int]] = [] + key_to_idx: Dict[Tuple[int, int], int] = {} + layer_to_bucket: List[Tuple[int, int]] = [] + outputs_by_bucket: List[List[int]] = [] + for layer_idx, (s, h) in enumerate(kvo_cache_shapes): + b = key_to_idx.get((s, h)) + if b is None: + b = len(bucket_keys) + key_to_idx[(s, h)] = b + bucket_keys.append((s, h)) + outputs_by_bucket.append([]) + slot = len(outputs_by_bucket[b]) + layer_to_bucket.append((b, slot)) + outputs_by_bucket[b].append(layer_idx) + + # layers_in_bucket is the OUTERMOST dim so bucket[layer_slot] is stride-identical + # to a standalone (2, maxframes, B, S, H) tensor — TRT's contiguous-input + # requirement is satisfied without an extra .contiguous() call. + buckets = [ + torch.zeros(len(outputs_by_bucket[b]), 2, cache_maxframes, batch_size, s, h, dtype=dtype, device=device) + for b, (s, h) in enumerate(bucket_keys) + ] + per_layer_views = [buckets[b][slot] for (b, slot) in layer_to_bucket] + + return per_layer_views, kvo_cache_structure, buckets, outputs_by_bucket diff --git a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py index 165c261e..7fa98f85 100644 --- a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py @@ -1,12 +1,13 @@ """Runtime TensorRT engine wrappers.""" -from .unet_engine import UNet2DConditionModelEngine, AutoencoderKLEngine -from .controlnet_engine import ControlNetModelEngine from ..engine_manager import EngineManager +from .controlnet_engine import ControlNetModelEngine +from .unet_engine import AutoencoderKLEngine, UNet2DConditionModelEngine + __all__ = [ "UNet2DConditionModelEngine", - "AutoencoderKLEngine", + "AutoencoderKLEngine", "ControlNetModelEngine", "EngineManager", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py index 17768fd0..5da5686d 100644 --- a/src/streamdiffusion/acceleration/tensorrt/utilities.py +++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py @@ -19,11 +19,9 @@ # import gc - -# Set up logger for this module import logging import os -from collections import OrderedDict +from collections import OrderedDict, deque from dataclasses import dataclass from typing import Optional, Union @@ -33,6 +31,8 @@ import tensorrt as trt import torch +from streamdiffusion.tools.gpu_profiler import profiler as _gpu_profiler + # cuda-python 13.x renamed 'cudart' to 'cuda.bindings.runtime' try: @@ -135,8 +135,8 @@ def detect_gpu_profile(device: int = 0) -> GPUBuildProfile: # opt_level=4 for all tiers: always compiles dynamic kernels (better kernel # selection than level-3 heuristics, even for static shapes). Level 5 avoided — # causes OOM during tactic profiling (160 GiB requests observed). - # NOTE: tactic 0x3e9 "Assertion g.nodes.size() == 0" errors in TRT 10.12 are - # a known TRT bug — benign, the tactic is skipped and build succeeds. + # NOTE: tactic 0x3e9 "Assertion g.nodes.size() == 0" errors observed in TRT 10.12–10.16 — + # benign (TRT skips the tactic and picks another, build completes normally). 
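Looping back to `create_kvo_cache` above: the claim that the outermost-layer bucket layout makes each per-layer view stride-identical to a standalone tensor is easy to verify. Tiny dims below are illustrative:

```python
import torch

bucket = torch.zeros(3, 2, 2, 1, 8, 16)  # (layers_in_bucket, 2, maxframes, B, S, H)
view = bucket[1]                          # one layer's slot
standalone = torch.zeros(2, 2, 1, 8, 16)

assert view.is_contiguous()               # no extra .contiguous() copy needed for TRT
assert view.stride() == standalone.stride()
```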
if cc >= (12, 0): tier = "blackwell" opt_level = 4 @@ -170,7 +170,7 @@ def detect_gpu_profile(device: int = 0) -> GPUBuildProfile: tiling_optimization_level=tiling, l2_limit_for_tiling=l2, # use full L2 as tiling budget (static builds only) max_aux_streams=0, # 0 = let TRT decide (avoids "[MS] disabled" spam) - sparse_weights=True, # always examine; no downside if not sparse + sparse_weights=False, # dense SD/SDXL weights; inspection adds build overhead, no runtime benefit enable_runtime_activation_resize=True, max_workspace_cap_bytes=max_ws_cap, ) @@ -273,8 +273,10 @@ def _apply_gpu_profile_to_config( # any model where TRT can't assign that many streams (e.g. VAE decoder which is # too sequential). TRT's heuristic silently chooses the right value per model. - # SPARSE_WEIGHTS: let TRT examine weight tensors for structured 2:4 sparsity - # and use Sparse Tensor Core kernels if suitable. Zero downside for dense weights. + # SPARSE_WEIGHTS: included for future 2:4-sparse pruned UNet variants. Stock + # SD/SDXL weights are dense, so TRT's sparsity inspection runs during build but + # finds no sparse kernels to select — small build-time cost, no runtime benefit. + # Controlled via gpu_profile.sparse_weights so it can be disabled per deployment. if gpu_profile.sparse_weights: try: config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) @@ -294,16 +296,47 @@ def _apply_gpu_profile_to_config( logger.debug("[TRT Config] RUNTIME_ACTIVATION_RESIZE_10_10 not supported — skipping") # avg_timing_iterations: number of timing runs averaged per tactic candidate. - # Default 1 produces noisy measurements — occasional slow GPU clocks or cache - # miss can unfairly disqualify the best kernel. Value of 4 gives stable rankings - # with minimal extra build time (4× timing overhead, which is tiny vs. compilation). - # TRT 10.12 confirmed to support this property. + # Default 1 produces noisy measurements. Blackwell (SM_120+) requires 8 passes — + # WDDM kernel-launch latency jitter is higher and needs more averaging to stably + # rank tactics. Ada/Ampere use 4 (sufficient; lower variance). try: - config.avg_timing_iterations = 4 - logger.info("[TRT Config] avg_timing_iterations=4") + timing_iters = 8 if gpu_profile.compute_capability >= (12, 0) else 4 + config.avg_timing_iterations = timing_iters + logger.info(f"[TRT Config] avg_timing_iterations={timing_iters}") except AttributeError: logger.debug("[TRT Config] avg_timing_iterations not supported — skipping") + # Tactic sources — SM_120+ (Blackwell) only: + # cuDNN conv/norm tactics don't exist in the consumer-Blackwell codegen path. + # Leaving CUDNN in the default set wastes profiling time and can steer Myelin + # toward a suboptimal fallback. Scope to CUBLAS + CUBLAS_LT + JIT_CONVOLUTIONS + # + EDGE_MASK_CONVOLUTIONS — the sources that produce valid SM_120 kernels. + # TRT 10.16 exposes TacticSource as an int enum (not IntFlag), so the bitmask + # is built via (1 << int(source)). No-op on Ada/Ampere. 
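Why the explicit shifts: with a plain int enum the members are ordinals, not bit values, so OR-ing them directly would collide. A self-contained sketch with hypothetical stand-in enums (not the real trt.TacticSource values):

from enum import IntEnum, IntFlag

class SourceOrdinal(IntEnum):   # stand-in for an int-enum TacticSource
    CUBLAS = 0
    CUBLAS_LT = 1

class SourceFlag(IntFlag):      # what an IntFlag would provide for free
    CUBLAS = 1 << 0
    CUBLAS_LT = 1 << 1

mask = (1 << int(SourceOrdinal.CUBLAS)) | (1 << int(SourceOrdinal.CUBLAS_LT))
assert mask == int(SourceFlag.CUBLAS | SourceFlag.CUBLAS_LT) == 0b11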
+ if gpu_profile.compute_capability >= (12, 0): + try: + sources = ( + (1 << int(trt.TacticSource.CUBLAS)) + | (1 << int(trt.TacticSource.CUBLAS_LT)) + | (1 << int(trt.TacticSource.JIT_CONVOLUTIONS)) + | (1 << int(trt.TacticSource.EDGE_MASK_CONVOLUTIONS)) + ) + config.set_tactic_sources(sources) + logger.info( + "[TRT Config] tactic sources = CUBLAS|CUBLAS_LT|JIT_CONV|EDGE_MASK (CUDNN excluded for SM_120+)" + ) + except (AttributeError, TypeError) as e: + logger.debug(f"[TRT Config] set_tactic_sources not available: {e}") + + # max_num_tactics: cap profiling candidates per layer to reduce build time. + # Available since TRT 10.x; -1 (default) lets TRT decide heuristically. 64 is a + # reasonable cap that matches FLUX's config. Gracefully ignored on older TRT. + try: + config.max_num_tactics = 64 + logger.info("[TRT Config] max_num_tactics=64") + except AttributeError: + logger.debug("[TRT Config] max_num_tactics not supported — skipping") + # Map of numpy dtype -> torch dtype numpy_to_torch_dtype_dict = { @@ -368,7 +401,7 @@ class TRTProfiler(trt.IProfiler): def __init__(self, name: str = ""): super().__init__() self.name = name - self._runs: list = [] # list of lists: [[( layer_name, ms ), ...], ...] + self._runs: deque = deque(maxlen=500) # rolling window; prevents unbounded growth at 30 fps self._current: list = [] # accumulator for the in-progress inference def report_layer_time(self, layer_name: str, ms: float) -> None: # noqa: N802 @@ -431,6 +464,9 @@ def __init__( self._last_device = None # Cached set of input tensor names — immutable after engine build self._allowed_inputs = None + # Cached ExternalStream wrapping the engine's polygraphy stream; allocated on + # first infer() call so we avoid constructing a new Python wrapper every frame. + self._engine_ext_stream = None def __del__(self): # Check if AttributeError: 'Engine' object has no attribute 'buffers' @@ -802,7 +838,8 @@ def activate(self, reuse_device_memory=None): # NOTE: profiler presence disables CUDA graph replay in infer() — IProfiler # cannot report per-layer times through a captured graph. self.profiler: Optional[TRTProfiler] = None - if os.environ.get("STREAMDIFFUSION_PROFILE_TRT"): + _profile_trt = os.environ.get("STREAMDIFFUSION_PROFILE_TRT", "").strip().lower() + if _profile_trt in ("1", "true", "yes", "on"): self.profiler = TRTProfiler(name=os.path.basename(self.engine_path)) self.context.profiler = self.profiler logger.info(f"[TRTProfiler] Attached to {os.path.basename(self.engine_path)} (CUDA graphs disabled)") @@ -940,35 +977,59 @@ def infer(self, feed_dict, stream, use_cuda_graph=False): if self.profiler is not None: self.profiler.start_run() - for name, buf in feed_dict.items(): - self.tensors[name].copy_(buf) + # Run input copies on the engine stream so they share ordering with the + # graph launch — copy_() on PyTorch's default stream would race the engine. 
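A self-contained illustration of the ExternalStream wrapping implemented just below; for runnability the raw pointer comes from a torch Stream here, not the engine's polygraphy stream:

import torch

if torch.cuda.is_available():
    engine_like = torch.cuda.Stream()                         # stands in for the engine's stream
    ext = torch.cuda.ExternalStream(engine_like.cuda_stream)  # wrap the raw cudaStream_t
    src = torch.randn(4, device="cuda")
    dst = torch.empty_like(src)
    with torch.cuda.stream(ext):
        dst.copy_(src)   # enqueued on the wrapped stream, ordered before later work on it
    ext.synchronize()
    assert torch.equal(dst, src)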
+ if self._engine_ext_stream is None: + self._engine_ext_stream = torch.cuda.ExternalStream(stream.ptr) + pt_stream = torch.cuda.current_stream().cuda_stream + if pt_stream != stream.ptr: + logger.debug( + "[TRT] PyTorch default stream (0x%x) differs from engine stream (0x%x) " + "— copy_() executes on engine stream to guarantee ordering.", + pt_stream, + stream.ptr, + ) + with torch.cuda.stream(self._engine_ext_stream): + for name, buf in feed_dict.items(): + self.tensors[name].copy_(buf) for name, tensor in self.tensors.items(): if not self.context.set_tensor_address(name, tensor.data_ptr()): raise RuntimeError(f"TensorRT: set_tensor_address failed for '{name}'") - if use_cuda_graph: - if self.cuda_graph_instance is not None: - CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr)) - # No cudaStreamSynchronize — graph replay is async; stream ordering ensures - # downstream GPU ops (copy_, attention) wait for graph completion. - # CPU sync happens only via end.synchronize() in pipeline.__call__. + with _gpu_profiler.region("trt_infer"): + if use_cuda_graph: + if self.cuda_graph_instance is not None: + CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream.ptr)) + # No cudaStreamSynchronize — graph replay is async; stream ordering ensures + # downstream GPU ops (copy_, attention) wait for graph completion. + # CPU sync happens only via end.synchronize() in pipeline.__call__. + else: + # Warmup passes before graph capture: TRT lazily JIT-compiles tactic + # variants on the first few forward calls. Three passes ensure all + # kernel variants are compiled before capture so the captured graph + # contains no JIT-init overhead. + for _ in range(3): + noerror = self.context.execute_async_v3(stream.ptr) + if not noerror: + raise ValueError("ERROR: inference failed.") + stream.synchronize() + # ThreadLocal mode: only captures ops on this thread's stream. + # Global mode would also capture any GPU work submitted from other + # threads (e.g. the TouchDesigner render thread), producing a + # corrupted graph with unintended nodes. 
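The same warmup-then-capture discipline expressed with PyTorch's graph API, as an analogy to the raw cudart calls below (placeholder module and shapes; assumes a CUDA device):

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(1, 8, device="cuda")
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):             # warmup: lazy kernel init happens outside capture
            y = model(x)
    torch.cuda.current_stream().wait_stream(s)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):          # capture one clean iteration
        y = model(x)
    x.copy_(torch.randn_like(x))       # refresh the static input in place...
    g.replay()                         # ...then replay; y is updated in place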
+ CUASSERT( + cudart.cudaStreamBeginCapture( + stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + ) + self.context.execute_async_v3(stream.ptr) + self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr)) + self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0)) else: - # do inference before CUDA graph capture noerror = self.context.execute_async_v3(stream.ptr) if not noerror: raise ValueError("ERROR: inference failed.") - # capture cuda graph - CUASSERT( - cudart.cudaStreamBeginCapture(stream.ptr, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) - ) - self.context.execute_async_v3(stream.ptr) - self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr)) - self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0)) - else: - noerror = self.context.execute_async_v3(stream.ptr) - if not noerror: - raise ValueError("ERROR: inference failed.") if self.profiler is not None: # Synchronize to ensure all IProfiler.report_layer_time() callbacks have fired @@ -1073,9 +1134,13 @@ def build_engine( build_all_tactics: bool = False, build_enable_refit: bool = False, fp8: bool = False, + builder_optimization_level: Optional[int] = None, ): # --- Step 0: Detect GPU and select hardware-optimal build parameters --- gpu_profile = detect_gpu_profile(device=torch.cuda.current_device()) + if builder_optimization_level is not None: + gpu_profile.builder_optimization_level = builder_optimization_level + logger.info(f"[TRT Build] builder_optimization_level overridden to {builder_optimization_level} (from config)") # --- Workspace sizing: leave 2 GiB for activations, cap per GPU tier --- _, free_mem, _ = cudart.cudaMemGetInfo() @@ -1107,12 +1172,14 @@ def build_engine( static_batch=build_static_batch, static_shape=not build_dynamic_shape, ) + # Note: build_all_tactics is accepted by build_engine() for API compat but + # Engine.build() does not forward it — tactic selection is now driven by + # set_tactic_sources (SM_120+) and max_num_tactics in _apply_gpu_profile_to_config.
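The precedence the new builder_optimization_level argument implements, reduced to a sketch (names simplified from the build_engine code above):

from typing import Optional

def resolve_opt_level(detected: int, configured: Optional[int]) -> int:
    # an explicitly configured value overrides the detected GPU profile
    return detected if configured is None else configured

assert resolve_opt_level(4, None) == 4   # key absent: GPU-profile default wins
assert resolve_opt_level(4, 3) == 3      # explicit config value wins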
engine.build( onnx_opt_path, fp16=True, input_profile=input_profile, enable_refit=build_enable_refit, - enable_all_tactics=build_all_tactics, timing_cache=timing_cache_path, workspace_size=max_workspace_size, fp8=fp8, @@ -1193,9 +1260,10 @@ def export_onnx( # Determine if we need external data format for large models (like SDXL) is_large_model = is_sdxl or (hasattr(model, "config") and getattr(model.config, "sample_size", 32) >= 64) - # Export ONNX normally first + export_model = wrapped_model + torch.onnx.export( - wrapped_model, + export_model, inputs, onnx_path, export_params=True, @@ -1286,7 +1354,9 @@ def optimize_onnx( else: # Standard optimization for smaller models - onnx_opt_graph = model_data.optimize(onnx.load(onnx_path)) + onnx_model = onnx.load(onnx_path) + onnx_opt_graph = model_data.optimize(onnx_model) + onnx.save(onnx_opt_graph, onnx_opt_path) del onnx_opt_graph diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 878d0914..33abd794 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -1,91 +1,96 @@ -import os -import sys -import yaml import json -from typing import Dict, List, Optional, Union, Any, Tuple from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import yaml + def load_config(config_path: Union[str, Path]) -> Dict[str, Any]: """Load StreamDiffusion configuration from YAML or JSON file""" config_path = Path(config_path) - + if not config_path.exists(): raise FileNotFoundError(f"load_config: Configuration file not found: {config_path}") - with open(config_path, 'r', encoding='utf-8') as f: - if config_path.suffix.lower() in ['.yaml', '.yml']: + with open(config_path, "r", encoding="utf-8") as f: + if config_path.suffix.lower() in [".yaml", ".yml"]: config_data = yaml.safe_load(f) - elif config_path.suffix.lower() == '.json': + elif config_path.suffix.lower() == ".json": config_data = json.load(f) else: raise ValueError(f"load_config: Unsupported configuration file format: {config_path.suffix}") - + _validate_config(config_data) - + return config_data def save_config(config: Dict[str, Any], config_path: Union[str, Path]) -> None: """Save StreamDiffusion configuration to YAML or JSON file""" config_path = Path(config_path) - + _validate_config(config) config_path.parent.mkdir(parents=True, exist_ok=True) - with open(config_path, 'w', encoding='utf-8') as f: - if config_path.suffix.lower() in ['.yaml', '.yml']: + with open(config_path, "w", encoding="utf-8") as f: + if config_path.suffix.lower() in [".yaml", ".yml"]: yaml.dump(config, f, default_flow_style=False, indent=2) - elif config_path.suffix.lower() == '.json': + elif config_path.suffix.lower() == ".json": json.dump(config, f, indent=2) else: raise ValueError(f"save_config: Unsupported configuration file format: {config_path.suffix}") + def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any: """Create StreamDiffusionWrapper from configuration dictionary - + Prompt Interface: - Legacy: Use 'prompt' field for single prompt - New: Use 'prompt_blending' with 'prompt_list' for multiple weighted prompts - If both are provided, 'prompt_blending' takes precedence and 'prompt' is ignored - negative_prompt: Currently a single string (not list) for all prompt types """ + from streamdiffusion import StreamDiffusionWrapper - import torch final_config = {**config, **overrides} wrapper_params = _extract_wrapper_params(final_config) wrapper = StreamDiffusionWrapper(**wrapper_params) - + prepare_params = 
_extract_prepare_params(final_config) # Handle prompt configuration with clear precedence - if 'prompt_blending' in final_config: + if "prompt_blending" in final_config: # Use prompt blending (new interface) - ignore legacy 'prompt' field - blend_config = final_config['prompt_blending'] - + blend_config = final_config["prompt_blending"] + # Prepare with prompt blending directly using unified interface - prepare_params_with_blending = {k: v for k, v in prepare_params.items() - if k not in ['prompt_blending', 'seed_blending']} - prepare_params_with_blending['prompt'] = blend_config.get('prompt_list', []) - prepare_params_with_blending['prompt_interpolation_method'] = blend_config.get('interpolation_method', 'slerp') - + prepare_params_with_blending = { + k: v for k, v in prepare_params.items() if k not in ["prompt_blending", "seed_blending"] + } + prepare_params_with_blending["prompt"] = blend_config.get("prompt_list", []) + prepare_params_with_blending["prompt_interpolation_method"] = blend_config.get("interpolation_method", "slerp") + # Add seed blending if configured - if 'seed_blending' in final_config: - seed_blend_config = final_config['seed_blending'] - prepare_params_with_blending['seed_list'] = seed_blend_config.get('seed_list', []) - prepare_params_with_blending['seed_interpolation_method'] = seed_blend_config.get('interpolation_method', 'linear') - + if "seed_blending" in final_config: + seed_blend_config = final_config["seed_blending"] + prepare_params_with_blending["seed_list"] = seed_blend_config.get("seed_list", []) + prepare_params_with_blending["seed_interpolation_method"] = seed_blend_config.get( + "interpolation_method", "linear" + ) + wrapper.prepare(**prepare_params_with_blending) - elif prepare_params.get('prompt'): + elif prepare_params.get("prompt"): # Use legacy single prompt interface - clean_prepare_params = {k: v for k, v in prepare_params.items() - if k not in ['prompt_blending', 'seed_blending']} + clean_prepare_params = { + k: v for k, v in prepare_params.items() if k not in ["prompt_blending", "seed_blending"] + } wrapper.prepare(**clean_prepare_params) # Apply seed blending if configured and not already handled in prepare - if 'seed_blending' in final_config and 'prompt_blending' not in final_config: - seed_blend_config = final_config['seed_blending'] + if "seed_blending" in final_config and "prompt_blending" not in final_config: + seed_blend_config = final_config["seed_blending"] wrapper.update_stream_params( - seed_list=seed_blend_config.get('seed_list', []), - interpolation_method=seed_blend_config.get('interpolation_method', 'linear') + seed_list=seed_blend_config.get("seed_list", []), + interpolation_method=seed_blend_config.get("interpolation_method", "linear"), ) return wrapper @@ -93,251 +98,264 @@ def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any: def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: """Extract parameters for StreamDiffusionWrapper.__init__() from config""" - import torch param_map = { - 'model_id_or_path': config.get('model_id', 'stabilityai/sd-turbo'), - 't_index_list': config.get('t_index_list', [0, 16, 32, 45]), - 'lora_dict': config.get('lora_dict'), - 'mode': config.get('mode', 'img2img'), - 'output_type': config.get('output_type', 'pil'), - 'vae_id': config.get('vae_id'), - 'device': config.get('device', 'cuda'), - 'dtype': _parse_dtype(config.get('dtype', 'float16')), - 'frame_buffer_size': config.get('frame_buffer_size', 1), - 'width': config.get('width', 512), - 'height': 
config.get('height', 512), - 'warmup': config.get('warmup', 10), - 'acceleration': config.get('acceleration', 'tensorrt'), - 'do_add_noise': config.get('do_add_noise', True), - 'device_ids': config.get('device_ids'), - 'use_lcm_lora': config.get('use_lcm_lora'), # Backwards compatibility - 'use_tiny_vae': config.get('use_tiny_vae', True), - 'enable_similar_image_filter': config.get('enable_similar_image_filter', False), - 'similar_image_filter_threshold': config.get('similar_image_filter_threshold', 0.98), - 'similar_image_filter_max_skip_frame': config.get('similar_image_filter_max_skip_frame', 10), - 'similar_filter_sleep_fraction': config.get('similar_filter_sleep_fraction', 0.025), - 'use_denoising_batch': config.get('use_denoising_batch', True), - 'cfg_type': config.get('cfg_type', 'self'), - 'seed': config.get('seed', 2), - 'use_safety_checker': config.get('use_safety_checker', False), - 'skip_diffusion': config.get('skip_diffusion', False), - 'engine_dir': config.get('engine_dir', 'engines'), - 'normalize_prompt_weights': config.get('normalize_prompt_weights', True), - 'normalize_seed_weights': config.get('normalize_seed_weights', True), - 'scheduler': config.get('scheduler', 'lcm'), - 'sampler': config.get('sampler', 'normal'), - 'compile_engines_only': config.get('compile_engines_only', False), + "model_id_or_path": config.get("model_id", "stabilityai/sd-turbo"), + "t_index_list": config.get("t_index_list", [0, 16, 32, 45]), + "lora_dict": config.get("lora_dict"), + "mode": config.get("mode", "img2img"), + "output_type": config.get("output_type", "pil"), + "vae_id": config.get("vae_id"), + "device": config.get("device", "cuda"), + "dtype": _parse_dtype(config.get("dtype", "float16")), + "frame_buffer_size": config.get("frame_buffer_size", 1), + "width": config.get("width", 512), + "height": config.get("height", 512), + "warmup": config.get("warmup", 10), + "acceleration": config.get("acceleration", "tensorrt"), + "do_add_noise": config.get("do_add_noise", True), + "device_ids": config.get("device_ids"), + "use_lcm_lora": config.get("use_lcm_lora"), # Backwards compatibility + "use_tiny_vae": config.get("use_tiny_vae", True), + "enable_similar_image_filter": config.get("enable_similar_image_filter", False), + "similar_image_filter_threshold": config.get("similar_image_filter_threshold", 0.98), + "similar_image_filter_max_skip_frame": config.get("similar_image_filter_max_skip_frame", 10), + "similar_filter_sleep_fraction": config.get("similar_filter_sleep_fraction", 0.025), + "use_denoising_batch": config.get("use_denoising_batch", True), + "cfg_type": config.get("cfg_type", "self"), + "seed": config.get("seed", 2), + "use_safety_checker": config.get("use_safety_checker", False), + "skip_diffusion": config.get("skip_diffusion", False), + "engine_dir": config.get("engine_dir", "engines"), + "normalize_prompt_weights": config.get("normalize_prompt_weights", True), + "normalize_seed_weights": config.get("normalize_seed_weights", True), + "scheduler": config.get("scheduler", "lcm"), + "sampler": config.get("sampler", "normal"), + "compile_engines_only": config.get("compile_engines_only", False), + "build_engines_if_missing": config.get("build_engines_if_missing", True), + "fp8": config.get("fp8", False), + "static_shapes": config.get("static_shapes", False), + "fp8_allow_fp16_fallback": config.get("fp8_allow_fp16_fallback", False), + "builder_optimization_level": config.get("builder_optimization_level"), } - if 'controlnets' in config and config['controlnets']: - 
param_map['use_controlnet'] = True - param_map['controlnet_config'] = _prepare_controlnet_configs(config) + if "controlnets" in config and config["controlnets"]: + param_map["use_controlnet"] = True + param_map["controlnet_config"] = _prepare_controlnet_configs(config) else: - param_map['use_controlnet'] = config.get('use_controlnet', False) - param_map['controlnet_config'] = config.get('controlnet_config') - + param_map["use_controlnet"] = config.get("use_controlnet", False) + param_map["controlnet_config"] = config.get("controlnet_config") + # Set IPAdapter usage if IPAdapters are configured - if 'ipadapters' in config and config['ipadapters']: - param_map['use_ipadapter'] = True - param_map['ipadapter_config'] = _prepare_ipadapter_configs(config) + if "ipadapters" in config and config["ipadapters"]: + param_map["use_ipadapter"] = True + param_map["ipadapter_config"] = _prepare_ipadapter_configs(config) else: - param_map['use_ipadapter'] = config.get('use_ipadapter', False) - param_map['ipadapter_config'] = config.get('ipadapter_config') - - param_map['use_cached_attn'] = config.get('use_cached_attn', False) - - param_map['cache_maxframes'] = config.get('cache_maxframes', 1) - param_map['cache_interval'] = config.get('cache_interval', 1) - + param_map["use_ipadapter"] = config.get("use_ipadapter", False) + param_map["ipadapter_config"] = config.get("ipadapter_config") + + param_map["use_cached_attn"] = config.get("use_cached_attn", False) + + param_map["cache_maxframes"] = config.get("cache_maxframes", 1) + param_map["cache_interval"] = config.get("cache_interval", 1) + # Pipeline hook configurations (Phase 4: Configuration Integration) hook_configs = _prepare_pipeline_hook_configs(config) param_map.update(hook_configs) - + return {k: v for k, v in param_map.items() if v is not None} def _extract_prepare_params(config: Dict[str, Any]) -> Dict[str, Any]: """Extract parameters for wrapper.prepare() from config""" prepare_params = { - 'prompt': config.get('prompt', ''), - 'negative_prompt': config.get('negative_prompt', ''), - 'num_inference_steps': config.get('num_inference_steps', 50), - 'guidance_scale': config.get('guidance_scale', 1.2), - 'delta': config.get('delta', 1.0), + "prompt": config.get("prompt", ""), + "negative_prompt": config.get("negative_prompt", ""), + "num_inference_steps": config.get("num_inference_steps", 50), + "guidance_scale": config.get("guidance_scale", 1.2), + "delta": config.get("delta", 1.0), } - + # Handle prompt blending configuration - if 'prompt_blending' in config: - blend_config = config['prompt_blending'] - prepare_params['prompt_blending'] = { - 'prompt_list': blend_config.get('prompt_list', []), - 'interpolation_method': blend_config.get('interpolation_method', 'slerp'), - 'enable_caching': blend_config.get('enable_caching', True) + if "prompt_blending" in config: + blend_config = config["prompt_blending"] + prepare_params["prompt_blending"] = { + "prompt_list": blend_config.get("prompt_list", []), + "interpolation_method": blend_config.get("interpolation_method", "slerp"), + "enable_caching": blend_config.get("enable_caching", True), } - + # Handle seed blending configuration - if 'seed_blending' in config: - seed_blend_config = config['seed_blending'] - prepare_params['seed_blending'] = { - 'seed_list': seed_blend_config.get('seed_list', []), - 'interpolation_method': seed_blend_config.get('interpolation_method', 'linear'), - 'enable_caching': seed_blend_config.get('enable_caching', True) + if "seed_blending" in config: + seed_blend_config = 
config["seed_blending"] + prepare_params["seed_blending"] = { + "seed_list": seed_blend_config.get("seed_list", []), + "interpolation_method": seed_blend_config.get("interpolation_method", "linear"), + "enable_caching": seed_blend_config.get("enable_caching", True), } - + return prepare_params + def _prepare_controlnet_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: """Prepare ControlNet configurations for wrapper""" controlnet_configs = [] - pipeline_type = config.get('pipeline_type', 'sd1.5') - for cn_config in config['controlnets']: + pipeline_type = config.get("pipeline_type", "sd1.5") + for cn_config in config["controlnets"]: controlnet_config = { - 'model_id': cn_config['model_id'], - 'preprocessor': cn_config.get('preprocessor', 'passthrough'), - 'conditioning_scale': cn_config.get('conditioning_scale', 1.0), - 'enabled': cn_config.get('enabled', True), - 'preprocessor_params': cn_config.get('preprocessor_params'), - 'conditioning_channels': cn_config.get('conditioning_channels'), - 'pipeline_type': pipeline_type, - 'control_guidance_start': cn_config.get('control_guidance_start', 0.0), - 'control_guidance_end': cn_config.get('control_guidance_end', 1.0), + "model_id": cn_config["model_id"], + "preprocessor": cn_config.get("preprocessor", "passthrough"), + "conditioning_scale": cn_config.get("conditioning_scale", 1.0), + "enabled": cn_config.get("enabled", True), + "preprocessor_params": cn_config.get("preprocessor_params"), + "conditioning_channels": cn_config.get("conditioning_channels"), + "pipeline_type": pipeline_type, + "control_guidance_start": cn_config.get("control_guidance_start", 0.0), + "control_guidance_end": cn_config.get("control_guidance_end", 1.0), } controlnet_configs.append(controlnet_config) - + return controlnet_configs def _prepare_ipadapter_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: """Prepare IPAdapter configurations for wrapper""" ipadapter_configs = [] - - for ip_config in config['ipadapters']: + + for ip_config in config["ipadapters"]: ipadapter_config = { - 'ipadapter_model_path': ip_config['ipadapter_model_path'], - 'image_encoder_path': ip_config['image_encoder_path'], - 'style_image': ip_config.get('style_image'), - 'scale': ip_config.get('scale', 1.0), - 'enabled': ip_config.get('enabled', True), + "ipadapter_model_path": ip_config["ipadapter_model_path"], + "image_encoder_path": ip_config["image_encoder_path"], + "style_image": ip_config.get("style_image"), + "scale": ip_config.get("scale", 1.0), + "enabled": ip_config.get("enabled", True), # Preserve FaceID options from config for downstream wrapper/module handling - 'type': ip_config.get('type', 'regular'), - 'insightface_model_name': ip_config.get('insightface_model_name'), + "type": ip_config.get("type", "regular"), + "insightface_model_name": ip_config.get("insightface_model_name"), } ipadapter_configs.append(ipadapter_config) - + return ipadapter_configs def _prepare_pipeline_hook_configs(config: Dict[str, Any]) -> Dict[str, Any]: """Prepare pipeline hook configurations for wrapper following ControlNet/IPAdapter pattern""" hook_configs = {} - + # Image preprocessing hooks - if 'image_preprocessing' in config and config['image_preprocessing']: - if config['image_preprocessing'].get('enabled', True): - hook_configs['image_preprocessing_config'] = _prepare_single_hook_config( - config['image_preprocessing'], 'image_preprocessing' + if "image_preprocessing" in config and config["image_preprocessing"]: + if config["image_preprocessing"].get("enabled", True): + 
hook_configs["image_preprocessing_config"] = _prepare_single_hook_config( + config["image_preprocessing"], "image_preprocessing" ) - - # Image postprocessing hooks - if 'image_postprocessing' in config and config['image_postprocessing']: - if config['image_postprocessing'].get('enabled', True): - hook_configs['image_postprocessing_config'] = _prepare_single_hook_config( - config['image_postprocessing'], 'image_postprocessing' + + # Image postprocessing hooks + if "image_postprocessing" in config and config["image_postprocessing"]: + if config["image_postprocessing"].get("enabled", True): + hook_configs["image_postprocessing_config"] = _prepare_single_hook_config( + config["image_postprocessing"], "image_postprocessing" ) - + # Latent preprocessing hooks - if 'latent_preprocessing' in config and config['latent_preprocessing']: - if config['latent_preprocessing'].get('enabled', True): - hook_configs['latent_preprocessing_config'] = _prepare_single_hook_config( - config['latent_preprocessing'], 'latent_preprocessing' + if "latent_preprocessing" in config and config["latent_preprocessing"]: + if config["latent_preprocessing"].get("enabled", True): + hook_configs["latent_preprocessing_config"] = _prepare_single_hook_config( + config["latent_preprocessing"], "latent_preprocessing" ) - + # Latent postprocessing hooks - if 'latent_postprocessing' in config and config['latent_postprocessing']: - if config['latent_postprocessing'].get('enabled', True): - hook_configs['latent_postprocessing_config'] = _prepare_single_hook_config( - config['latent_postprocessing'], 'latent_postprocessing' + if "latent_postprocessing" in config and config["latent_postprocessing"]: + if config["latent_postprocessing"].get("enabled", True): + hook_configs["latent_postprocessing_config"] = _prepare_single_hook_config( + config["latent_postprocessing"], "latent_postprocessing" ) - + return hook_configs def _prepare_single_hook_config(hook_config: Dict[str, Any], hook_type: str) -> Dict[str, Any]: """Prepare configuration for a single hook type""" return { - 'enabled': hook_config.get('enabled', True), - 'processors': hook_config.get('processors', []), - 'hook_type': hook_type, + "enabled": hook_config.get("enabled", True), + "processors": hook_config.get("processors", []), + "hook_type": hook_type, } def _validate_pipeline_hook_configs(config: Dict[str, Any]) -> None: """Validate pipeline hook configurations following ControlNet/IPAdapter validation pattern""" - hook_types = ['image_preprocessing', 'image_postprocessing', 'latent_preprocessing', 'latent_postprocessing'] - + hook_types = ["image_preprocessing", "image_postprocessing", "latent_preprocessing", "latent_postprocessing"] + for hook_type in hook_types: if hook_type in config: hook_config = config[hook_type] if not isinstance(hook_config, dict): raise ValueError(f"_validate_config: '{hook_type}' must be a dictionary") - + # Validate enabled field - if 'enabled' in hook_config: - enabled = hook_config['enabled'] + if "enabled" in hook_config: + enabled = hook_config["enabled"] if not isinstance(enabled, bool): raise ValueError(f"_validate_config: '{hook_type}.enabled' must be a boolean") - + # Validate processors field - if 'processors' in hook_config: - processors = hook_config['processors'] + if "processors" in hook_config: + processors = hook_config["processors"] if not isinstance(processors, list): raise ValueError(f"_validate_config: '{hook_type}.processors' must be a list") - + for i, processor in enumerate(processors): if not isinstance(processor, dict): 
raise ValueError(f"_validate_config: '{hook_type}.processors[{i}]' must be a dictionary") - + # Validate processor type (required) - if 'type' not in processor: - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}]' missing required 'type' field") - - if not isinstance(processor['type'], str): + if "type" not in processor: + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}]' missing required 'type' field" + ) + + if not isinstance(processor["type"], str): raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].type' must be a string") - + # Validate enabled field (optional, defaults to True) - if 'enabled' in processor: - enabled = processor['enabled'] + if "enabled" in processor: + enabled = processor["enabled"] if not isinstance(enabled, bool): - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].enabled' must be a boolean") - + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}].enabled' must be a boolean" + ) + # Validate order field (optional) - if 'order' in processor: - order = processor['order'] + if "order" in processor: + order = processor["order"] if not isinstance(order, int): - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].order' must be an integer") - + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}].order' must be an integer" + ) + # Validate params field (optional, coerce None to empty dict) - if 'params' in processor: - if processor['params'] is None: - processor['params'] = {} - elif not isinstance(processor['params'], dict): - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].params' must be a dictionary") + if "params" in processor: + if processor["params"] is None: + processor["params"] = {} + elif not isinstance(processor["params"], dict): + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}].params' must be a dictionary" + ) def create_prompt_blending_config( base_config: Dict[str, Any], prompt_list: List[Tuple[str, float]], prompt_interpolation_method: str = "slerp", - enable_caching: bool = True + enable_caching: bool = True, ) -> Dict[str, Any]: """Create a configuration with prompt blending settings""" config = base_config.copy() - - config['prompt_blending'] = { - 'prompt_list': prompt_list, - 'interpolation_method': prompt_interpolation_method, - 'enable_caching': enable_caching + + config["prompt_blending"] = { + "prompt_list": prompt_list, + "interpolation_method": prompt_interpolation_method, + "enable_caching": enable_caching, } - + return config @@ -345,150 +363,152 @@ def create_seed_blending_config( base_config: Dict[str, Any], seed_list: List[Tuple[int, float]], interpolation_method: str = "linear", - enable_caching: bool = True + enable_caching: bool = True, ) -> Dict[str, Any]: """Create a configuration with seed blending settings""" config = base_config.copy() - - config['seed_blending'] = { - 'seed_list': seed_list, - 'interpolation_method': interpolation_method, - 'enable_caching': enable_caching + + config["seed_blending"] = { + "seed_list": seed_list, + "interpolation_method": interpolation_method, + "enable_caching": enable_caching, } - + return config def set_normalize_weights_config( - base_config: Dict[str, Any], - normalize_prompt_weights: bool = True, - normalize_seed_weights: bool = True + base_config: Dict[str, Any], normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True ) -> Dict[str, Any]: """Create a configuration with separate normalize weight settings""" config = 
base_config.copy() - - config['normalize_prompt_weights'] = normalize_prompt_weights - config['normalize_seed_weights'] = normalize_seed_weights - + + config["normalize_prompt_weights"] = normalize_prompt_weights + config["normalize_seed_weights"] = normalize_seed_weights + return config + def _parse_dtype(dtype_str: str) -> Any: """Parse dtype string to torch dtype""" import torch - + dtype_map = { - 'float16': torch.float16, - 'float32': torch.float32, - 'half': torch.float16, - 'float': torch.float32, + "float16": torch.float16, + "float32": torch.float32, + "half": torch.float16, + "float": torch.float32, } - + if isinstance(dtype_str, str): return dtype_map.get(dtype_str.lower(), torch.float16) return dtype_str # Assume it's already a torch dtype + + def _validate_config(config: Dict[str, Any]) -> None: """Basic validation of configuration dictionary""" if not isinstance(config, dict): raise ValueError("_validate_config: Configuration must be a dictionary") - - if 'model_id' not in config: + + if "model_id" not in config: raise ValueError("_validate_config: Missing required field: model_id") - - if 'controlnets' in config: - if not isinstance(config['controlnets'], list): + + if "controlnets" in config: + if not isinstance(config["controlnets"], list): raise ValueError("_validate_config: 'controlnets' must be a list") - - for i, controlnet in enumerate(config['controlnets']): + + for i, controlnet in enumerate(config["controlnets"]): if not isinstance(controlnet, dict): raise ValueError(f"_validate_config: ControlNet {i} must be a dictionary") - - if 'model_id' not in controlnet: + + if "model_id" not in controlnet: raise ValueError(f"_validate_config: ControlNet {i} missing required 'model_id'") - + # Validate conditioning_channels if present - if 'conditioning_channels' in controlnet: - channels = controlnet['conditioning_channels'] + if "conditioning_channels" in controlnet: + channels = controlnet["conditioning_channels"] if not isinstance(channels, int) or channels <= 0: - raise ValueError(f"_validate_config: ControlNet {i} 'conditioning_channels' must be a positive integer, got {channels}") - + raise ValueError( + f"_validate_config: ControlNet {i} 'conditioning_channels' must be a positive integer, got {channels}" + ) + # Validate ipadapters if present - if 'ipadapters' in config: - if not isinstance(config['ipadapters'], list): + if "ipadapters" in config: + if not isinstance(config["ipadapters"], list): raise ValueError("_validate_config: 'ipadapters' must be a list") - - for i, ipadapter in enumerate(config['ipadapters']): + + for i, ipadapter in enumerate(config["ipadapters"]): if not isinstance(ipadapter, dict): raise ValueError(f"_validate_config: IPAdapter {i} must be a dictionary") - - if 'ipadapter_model_path' not in ipadapter: + + if "ipadapter_model_path" not in ipadapter: raise ValueError(f"_validate_config: IPAdapter {i} missing required 'ipadapter_model_path'") - - if 'image_encoder_path' not in ipadapter: + + if "image_encoder_path" not in ipadapter: raise ValueError(f"_validate_config: IPAdapter {i} missing required 'image_encoder_path'") # Validate prompt blending configuration if present - if 'prompt_blending' in config: - blend_config = config['prompt_blending'] + if "prompt_blending" in config: + blend_config = config["prompt_blending"] if not isinstance(blend_config, dict): raise ValueError("_validate_config: 'prompt_blending' must be a dictionary") - - if 'prompt_list' in blend_config: - prompt_list = blend_config['prompt_list'] + + if "prompt_list" in 
blend_config: + prompt_list = blend_config["prompt_list"] if not isinstance(prompt_list, list): raise ValueError("_validate_config: 'prompt_list' must be a list") - + for i, prompt_item in enumerate(prompt_list): if not isinstance(prompt_item, (list, tuple)) or len(prompt_item) != 2: raise ValueError(f"_validate_config: Prompt item {i} must be [text, weight] pair") - + text, weight = prompt_item if not isinstance(text, str): raise ValueError(f"_validate_config: Prompt text {i} must be a string") - + if not isinstance(weight, (int, float)) or weight < 0: raise ValueError(f"_validate_config: Prompt weight {i} must be a non-negative number") - - interpolation_method = blend_config.get('interpolation_method', 'slerp') - if interpolation_method not in ['linear', 'slerp']: + + interpolation_method = blend_config.get("interpolation_method", "slerp") + if interpolation_method not in ["linear", "slerp"]: raise ValueError("_validate_config: interpolation_method must be 'linear' or 'slerp'") # Validate seed blending configuration if present - if 'seed_blending' in config: - seed_blend_config = config['seed_blending'] + if "seed_blending" in config: + seed_blend_config = config["seed_blending"] if not isinstance(seed_blend_config, dict): raise ValueError("_validate_config: 'seed_blending' must be a dictionary") - - if 'seed_list' in seed_blend_config: - seed_list = seed_blend_config['seed_list'] + + if "seed_list" in seed_blend_config: + seed_list = seed_blend_config["seed_list"] if not isinstance(seed_list, list): raise ValueError("_validate_config: 'seed_list' must be a list") - + for i, seed_item in enumerate(seed_list): if not isinstance(seed_item, (list, tuple)) or len(seed_item) != 2: raise ValueError(f"_validate_config: Seed item {i} must be [seed, weight] pair") - + seed_value, weight = seed_item if not isinstance(seed_value, int) or seed_value < 0: raise ValueError(f"_validate_config: Seed value {i} must be a non-negative integer") - + if not isinstance(weight, (int, float)) or weight < 0: raise ValueError(f"_validate_config: Seed weight {i} must be a non-negative number") - - interpolation_method = seed_blend_config.get('interpolation_method', 'linear') - if interpolation_method not in ['linear', 'slerp']: + + interpolation_method = seed_blend_config.get("interpolation_method", "linear") + if interpolation_method not in ["linear", "slerp"]: raise ValueError("_validate_config: seed blending interpolation_method must be 'linear' or 'slerp'") # Validate pipeline hook configurations if present (Phase 4: Configuration Integration) _validate_pipeline_hook_configs(config) # Validate separate normalize settings if present - if 'normalize_prompt_weights' in config: - normalize_prompt_weights = config['normalize_prompt_weights'] + if "normalize_prompt_weights" in config: + normalize_prompt_weights = config["normalize_prompt_weights"] if not isinstance(normalize_prompt_weights, bool): raise ValueError("_validate_config: 'normalize_prompt_weights' must be a boolean value") - - if 'normalize_seed_weights' in config: - normalize_seed_weights = config['normalize_seed_weights'] + + if "normalize_seed_weights" in config: + normalize_seed_weights = config["normalize_seed_weights"] if not isinstance(normalize_seed_weights, bool): raise ValueError("_validate_config: 'normalize_seed_weights' must be a boolean value") - diff --git a/src/streamdiffusion/hooks.py b/src/streamdiffusion/hooks.py index ec5db415..02f10270 100644 --- a/src/streamdiffusion/hooks.py +++ b/src/streamdiffusion/hooks.py @@ -2,6 +2,7 @@ from 
dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional + import torch @@ -13,6 +14,7 @@ class EmbedsCtx: - prompt_embeds: [batch, seq_len, dim] - negative_prompt_embeds: optional [batch, seq_len, dim] """ + prompt_embeds: torch.Tensor negative_prompt_embeds: Optional[torch.Tensor] = None @@ -28,6 +30,7 @@ class StepCtx: - guidance_mode: one of {"none","full","self","initialize"} - sdxl_cond: optional dict with SDXL micro-cond tensors """ + x_t_latent: torch.Tensor t_list: torch.Tensor step_index: Optional[int] @@ -38,6 +41,7 @@ class StepCtx: @dataclass class UnetKwargsDelta: """Delta produced by UNet hooks to augment UNet call kwargs.""" + down_block_additional_residuals: Optional[List[torch.Tensor]] = None mid_block_additional_residual: Optional[torch.Tensor] = None added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None @@ -48,37 +52,37 @@ class UnetKwargsDelta: @dataclass class ImageCtx: """Context passed to image processing hooks. - + Fields: - image: [B, C, H, W] tensor in image space - width: image width - - height: image height + - height: image height - step_index: optional step index for multi-step processing """ + image: torch.Tensor width: int height: int step_index: Optional[int] = None -@dataclass +@dataclass class LatentCtx: """Context passed to latent processing hooks. - + Fields: - latent: [B, C, H/8, W/8] tensor in latent space - timestep: optional timestep tensor for diffusion context - step_index: optional step index for multi-step processing """ + latent: torch.Tensor timestep: Optional[torch.Tensor] = None step_index: Optional[int] = None - # Type aliases for clarity EmbeddingHook = Callable[[EmbedsCtx], EmbedsCtx] UnetHook = Callable[[StepCtx], UnetKwargsDelta] ImageHook = Callable[[ImageCtx], ImageCtx] LatentHook = Callable[[LatentCtx], LatentCtx] - diff --git a/src/streamdiffusion/image_filter.py b/src/streamdiffusion/image_filter.py index 5523c886..e975567a 100644 --- a/src/streamdiffusion/image_filter.py +++ b/src/streamdiffusion/image_filter.py @@ -1,5 +1,5 @@ -from typing import Optional import random +from typing import Optional import torch import torch.nn.functional as F diff --git a/src/streamdiffusion/image_utils.py b/src/streamdiffusion/image_utils.py index 200295b3..77d7275c 100644 --- a/src/streamdiffusion/image_utils.py +++ b/src/streamdiffusion/image_utils.py @@ -30,9 +30,7 @@ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images - pil_images = [ - PIL.Image.fromarray(image.squeeze(), mode="L") for image in images - ] + pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [PIL.Image.fromarray(image) for image in images] @@ -56,12 +54,7 @@ def postprocess_image( if do_denormalize is None: do_denormalize = [do_normalize_flg] * image.shape[0] - image = torch.stack( - [ - denormalize(image[i]) if do_denormalize[i] else image[i] - for i in range(image.shape[0]) - ] - ) + image = torch.stack([denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]) if output_type == "pt": return image @@ -91,8 +84,6 @@ def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor: img, _ = process_image(image_pil) imgs.append(img) imgs = torch.vstack(imgs) - images = torch.nn.functional.interpolate( - imgs, size=(height, width), mode="bilinear" - ) + images = torch.nn.functional.interpolate(imgs, size=(height, width), mode="bilinear") 
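# NOTE: this cast is unconditional: pil2tensor returns float16 even when the
# configured pipeline dtype is float32, so fp32 callers must re-cast.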
image_tensors = images.to(torch.float16) return image_tensors diff --git a/src/streamdiffusion/model_detection.py b/src/streamdiffusion/model_detection.py index e9eef252..fbd28933 100644 --- a/src/streamdiffusion/model_detection.py +++ b/src/streamdiffusion/model_detection.py @@ -1,13 +1,15 @@ """Comprehensive model detection for TensorRT and pipeline support""" -from typing import Dict, Tuple, Optional, Any, List +from typing import Any, Dict, Optional import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + # Gracefully import the SD3 model class; it might not exist in older diffusers versions. try: from diffusers.models.transformers.mm_dit import MMDiTTransformer2DModel + HAS_MMDIT = True except ImportError: # Create a dummy class if the import fails to prevent runtime errors. @@ -15,6 +17,8 @@ HAS_MMDIT = False import logging + + logger = logging.getLogger(__name__) @@ -23,7 +27,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str Comprehensive and robust model detection using definitive architectural features. This function replaces heuristic-based analysis with a deterministic, - rule-based approach by first inspecting the model's class and then its key + rule-based approach by first inspecting the model's class and then its key configuration parameters that define the architecture. Args: @@ -50,9 +54,9 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str confidence = 1.0 # Differentiating SD3 vs. SD3-Turbo from the MMDiT config alone is currently # speculative. A check on the pipeline's scheduler is a reasonable proxy. - if pipe and hasattr(pipe, 'scheduler'): - scheduler_name = getattr(pipe.scheduler.config, '_class_name', '').lower() - if 'lcm' in scheduler_name or 'turbo' in scheduler_name: + if pipe and hasattr(pipe, "scheduler"): + scheduler_name = getattr(pipe.scheduler.config, "_class_name", "").lower() + if "lcm" in scheduler_name or "turbo" in scheduler_name: is_turbo = True model_type = "SD3-Turbo" else: @@ -62,7 +66,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str # 2. UNet-based Model Detection (SDXL, SD2.1, SD1.5) elif isinstance(model, UNet2DConditionModel): config = model.config - + # 2a. SDXL vs. non-SDXL # The `addition_embed_type` is the clearest indicator for the SDXL architecture. if config.get("addition_embed_type") is not None: @@ -73,7 +77,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str # Base SDXL has `time_cond_proj_dim` (e.g., 256), while Turbo has it set to `None`. if config.get("time_cond_proj_dim") is None: is_turbo = True - + # 2b. SD2.1 vs. SD1.5 (if not SDXL) # Differentiate based on the text encoder's projection dimension. else: @@ -90,10 +94,10 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str confidence = 0.7 # 3. 
ControlNet Model Detection (detect underlying architecture) - elif hasattr(model, 'config') and hasattr(model.config, 'cross_attention_dim'): + elif hasattr(model, "config") and hasattr(model.config, "cross_attention_dim"): # ControlNet models have UNet-like configs, detect their base architecture config = model.config - + # Apply same detection logic as UNet models if config.get("addition_embed_type") is not None: model_type = "SDXL" @@ -107,12 +111,12 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str model_type = "SD2.1" confidence = 0.95 elif cross_attention_dim == 768: - model_type = "SD1.5" + model_type = "SD1.5" confidence = 0.95 else: model_type = "SD-finetune" confidence = 0.7 - + else: # The model is not a known UNet or MMDiT class. confidence = 0.0 @@ -120,44 +124,46 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str # Populate architecture and compatibility details (can be expanded as needed) architecture_details = { - 'model_class': model.__class__.__name__, - 'in_channels': getattr(model.config, 'in_channels', 'N/A'), - 'cross_attention_dim': getattr(model.config, 'cross_attention_dim', 'N/A'), - 'block_out_channels': getattr(model.config, 'block_out_channels', 'N/A'), + "model_class": model.__class__.__name__, + "in_channels": getattr(model.config, "in_channels", "N/A"), + "cross_attention_dim": getattr(model.config, "cross_attention_dim", "N/A"), + "block_out_channels": getattr(model.config, "block_out_channels", "N/A"), } - + # For UNet models, add detailed characteristics that SDXL code expects if isinstance(model, UNet2DConditionModel): unet_chars = detect_unet_characteristics(model) - architecture_details.update({ - 'has_time_conditioning': unet_chars['has_time_cond'], - 'has_addition_embeds': unet_chars['has_addition_embed'], - }) - + architecture_details.update( + { + "has_time_conditioning": unet_chars["has_time_cond"], + "has_addition_embeds": unet_chars["has_addition_embed"], + } + ) + # For ControlNet models, add similar characteristics - elif hasattr(model, 'config') and hasattr(model.config, 'cross_attention_dim'): + elif hasattr(model, "config") and hasattr(model.config, "cross_attention_dim"): # ControlNet models have similar config structure to UNet config = model.config has_addition_embed = config.get("addition_embed_type") is not None - has_time_cond = hasattr(config, 'time_cond_proj_dim') and config.time_cond_proj_dim is not None - - architecture_details.update({ - 'has_time_conditioning': has_time_cond, - 'has_addition_embeds': has_addition_embed, - }) - - compatibility_info = { - 'notes': f"Detected as {model_type} with {confidence:.2f} confidence based on architecture." 
- } + has_time_cond = hasattr(config, "time_cond_proj_dim") and config.time_cond_proj_dim is not None + + architecture_details.update( + { + "has_time_conditioning": has_time_cond, + "has_addition_embeds": has_addition_embed, + } + ) + + compatibility_info = {"notes": f"Detected as {model_type} with {confidence:.2f} confidence based on architecture."} result = { - 'model_type': model_type, - 'is_turbo': is_turbo, - 'is_sdxl': is_sdxl, - 'is_sd3': is_sd3, - 'confidence': confidence, - 'architecture_details': architecture_details, - 'compatibility_info': compatibility_info, + "model_type": model_type, + "is_turbo": is_turbo, + "is_sdxl": is_sdxl, + "is_sd3": is_sd3, + "confidence": confidence, + "architecture_details": architecture_details, + "compatibility_info": compatibility_info, } return result @@ -166,13 +172,13 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]: """Detect detailed UNet characteristics including SDXL-specific features""" config = unet.config - + # Get cross attention dimensions to detect model type - cross_attention_dim = getattr(config, 'cross_attention_dim', None) - + cross_attention_dim = getattr(config, "cross_attention_dim", None) + # Detect SDXL by multiple indicators is_sdxl = False - + # Check cross attention dimension if isinstance(cross_attention_dim, (list, tuple)): # SDXL typically has [1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280] @@ -180,73 +186,74 @@ def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]: elif isinstance(cross_attention_dim, int): # Single value - SDXL uses 2048 for concatenated embeddings, or 1280+ for individual encoders is_sdxl = cross_attention_dim >= 1280 - + # Check addition_embed_type for SDXL detection (strong indicator) - addition_embed_type = getattr(config, 'addition_embed_type', None) + addition_embed_type = getattr(config, "addition_embed_type", None) has_addition_embed = addition_embed_type is not None - - if addition_embed_type in ['text_time', 'text_time_guidance']: + + if addition_embed_type in ["text_time", "text_time_guidance"]: is_sdxl = True # This is a definitive SDXL indicator - + # Check if model has time conditioning projection (SDXL feature) - has_time_cond = hasattr(config, 'time_cond_proj_dim') and config.time_cond_proj_dim is not None - + has_time_cond = hasattr(config, "time_cond_proj_dim") and config.time_cond_proj_dim is not None + # Additional SDXL detection checks - if hasattr(config, 'num_class_embeds') and config.num_class_embeds is not None: + if hasattr(config, "num_class_embeds") and config.num_class_embeds is not None: is_sdxl = True # SDXL often has class embeddings - + # Check sample size (SDXL typically uses 128 vs 64 for SD1.5) - sample_size = getattr(config, 'sample_size', 64) + sample_size = getattr(config, "sample_size", 64) if sample_size >= 128: is_sdxl = True - + return { - 'is_sdxl': is_sdxl, - 'has_time_cond': has_time_cond, - 'has_addition_embed': has_addition_embed, - 'cross_attention_dim': cross_attention_dim, - 'addition_embed_type': addition_embed_type, - 'in_channels': getattr(config, 'in_channels', 4), - 'sample_size': getattr(config, 'sample_size', 64 if not is_sdxl else 128), - 'block_out_channels': tuple(getattr(config, 'block_out_channels', [])), - 'attention_head_dim': getattr(config, 'attention_head_dim', None) + "is_sdxl": is_sdxl, + "has_time_cond": has_time_cond, + "has_addition_embed": has_addition_embed, + 
"cross_attention_dim": cross_attention_dim, + "addition_embed_type": addition_embed_type, + "in_channels": getattr(config, "in_channels", 4), + "sample_size": getattr(config, "sample_size", 64 if not is_sdxl else 128), + "block_out_channels": tuple(getattr(config, "block_out_channels", [])), + "attention_head_dim": getattr(config, "attention_head_dim", None), } + # This is used for controlnet/ipadapter model detection - can be deprecated (along with detect_unet_characteristics) def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str: """Detect model type from diffusers UNet configuration""" characteristics = detect_unet_characteristics(unet) - - in_channels = characteristics['in_channels'] - block_out_channels = characteristics['block_out_channels'] - cross_attention_dim = characteristics['cross_attention_dim'] - is_sdxl = characteristics['is_sdxl'] - + + in_channels = characteristics["in_channels"] + block_out_channels = characteristics["block_out_channels"] + cross_attention_dim = characteristics["cross_attention_dim"] + is_sdxl = characteristics["is_sdxl"] + # Use enhanced SDXL detection if is_sdxl: return "SDXL" - + # Original detection logic for other models - if (cross_attention_dim == 768 and - block_out_channels == (320, 640, 1280, 1280) and - in_channels == 4): + if cross_attention_dim == 768 and block_out_channels == (320, 640, 1280, 1280) and in_channels == 4: return "SD15" - - elif (cross_attention_dim == 1024 and - block_out_channels == (320, 640, 1280, 1280) and - in_channels == 4): + + elif cross_attention_dim == 1024 and block_out_channels == (320, 640, 1280, 1280) and in_channels == 4: return "SD21" - + elif cross_attention_dim == 768 and in_channels == 4: return "SD15" elif cross_attention_dim == 1024 and in_channels == 4: return "SD21" - + if cross_attention_dim == 768: - print(f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15") + print( + f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15" + ) return "SD15" elif cross_attention_dim == 1024: - print(f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21") + print( + f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21" + ) return "SD21" else: raise ValueError( @@ -260,58 +267,58 @@ def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str: def extract_unet_architecture(unet: UNet2DConditionModel) -> Dict[str, Any]: """ Extract UNet architecture details needed for TensorRT engine building. - + This function provides the essential architecture information needed for TensorRT engine compilation in a clean, structured format. 
- + Args: unet: The UNet model to analyze - + Returns: Dict with architecture parameters for TensorRT engine building """ config = unet.config - + # Basic model parameters model_channels = config.block_out_channels[0] if config.block_out_channels else 320 block_out_channels = tuple(config.block_out_channels) channel_mult = tuple(ch // model_channels for ch in block_out_channels) - + # Resolution blocks - if hasattr(config, 'layers_per_block'): + if hasattr(config, "layers_per_block"): if isinstance(config.layers_per_block, (list, tuple)): num_res_blocks = tuple(config.layers_per_block) else: num_res_blocks = tuple([config.layers_per_block] * len(block_out_channels)) else: num_res_blocks = tuple([2] * len(block_out_channels)) - + # Attention and context dimensions context_dim = config.cross_attention_dim in_channels = config.in_channels - + # Attention head configuration - attention_head_dim = getattr(config, 'attention_head_dim', 8) + attention_head_dim = getattr(config, "attention_head_dim", 8) if isinstance(attention_head_dim, (list, tuple)): attention_head_dim = attention_head_dim[0] - + # Transformer depth - transformer_depth = getattr(config, 'transformer_layers_per_block', 1) + transformer_depth = getattr(config, "transformer_layers_per_block", 1) if isinstance(transformer_depth, (list, tuple)): transformer_depth = tuple(transformer_depth) else: transformer_depth = tuple([transformer_depth] * len(block_out_channels)) - + # Time embedding - time_embed_dim = getattr(config, 'time_embedding_dim', None) + time_embed_dim = getattr(config, "time_embedding_dim", None) if time_embed_dim is None: time_embed_dim = model_channels * 4 - + # Build architecture dictionary architecture_dict = { "model_channels": model_channels, "in_channels": in_channels, - "out_channels": getattr(config, 'out_channels', in_channels), + "out_channels": getattr(config, "out_channels", in_channels), "num_res_blocks": num_res_blocks, "channel_mult": channel_mult, "context_dim": context_dim, @@ -319,48 +326,50 @@ def extract_unet_architecture(unet: UNet2DConditionModel) -> Dict[str, Any]: "transformer_depth": transformer_depth, "time_embed_dim": time_embed_dim, "block_out_channels": block_out_channels, - # Additional configuration - "use_linear_in_transformer": getattr(config, 'use_linear_in_transformer', False), - "conv_in_kernel": getattr(config, 'conv_in_kernel', 3), - "conv_out_kernel": getattr(config, 'conv_out_kernel', 3), - "resnet_time_scale_shift": getattr(config, 'resnet_time_scale_shift', 'default'), - "class_embed_type": getattr(config, 'class_embed_type', None), - "num_class_embeds": getattr(config, 'num_class_embeds', None), - + "use_linear_in_transformer": getattr(config, "use_linear_in_transformer", False), + "conv_in_kernel": getattr(config, "conv_in_kernel", 3), + "conv_out_kernel": getattr(config, "conv_out_kernel", 3), + "resnet_time_scale_shift": getattr(config, "resnet_time_scale_shift", "default"), + "class_embed_type": getattr(config, "class_embed_type", None), + "num_class_embeds": getattr(config, "num_class_embeds", None), # Block types - "down_block_types": getattr(config, 'down_block_types', []), - "up_block_types": getattr(config, 'up_block_types', []), + "down_block_types": getattr(config, "down_block_types", []), + "up_block_types": getattr(config, "up_block_types", []), } - + return architecture_dict def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[str, Any]: """ Validate and fix architecture dictionary using model type presets. 
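For concreteness, here is the derivation in extract_unet_architecture applied to stock SD1.5 values (an illustration with the reference numbers quoted elsewhere in this file, not new behavior):

# SD1.5 reference config: block_out_channels (320, 640, 1280, 1280), layers_per_block 2.
block_out_channels = (320, 640, 1280, 1280)
layers_per_block = 2

model_channels = block_out_channels[0]                                    # 320
channel_mult = tuple(ch // model_channels for ch in block_out_channels)  # (1, 2, 4, 4)
num_res_blocks = (layers_per_block,) * len(block_out_channels)           # (2, 2, 2, 2)
time_embed_dim = model_channels * 4                                      # 1280 when the config omits it

assert channel_mult == (1, 2, 4, 4) and time_embed_dim == 1280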
- + Ensures that all required architecture parameters are present and have reasonable values for the specified model type. - + Args: arch_dict: Architecture dictionary to validate model_type: Expected model type for validation - + Returns: Validated and corrected architecture dictionary """ - + # Check for required keys required_keys = [ - "model_channels", "channel_mult", "num_res_blocks", - "context_dim", "in_channels", "block_out_channels" + "model_channels", + "channel_mult", + "num_res_blocks", + "context_dim", + "in_channels", + "block_out_channels", ] - + for key in required_keys: if key not in arch_dict: raise ValueError(f"Missing required architecture parameter: {key}") - + # Ensure tuple format for sequence parameters for key in ["channel_mult", "num_res_blocks", "transformer_depth", "block_out_channels"]: if key in arch_dict and not isinstance(arch_dict[key], tuple): @@ -371,12 +380,11 @@ def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[st arch_dict[key] = tuple(arch_dict[key]) else: arch_dict[key] = preset[key] - + # Validate sequence lengths match expected_levels = len(arch_dict["channel_mult"]) for key in ["num_res_blocks", "transformer_depth"]: if key in arch_dict and len(arch_dict[key]) != expected_levels: arch_dict[key] = preset[key] - - return arch_dict + return arch_dict diff --git a/src/streamdiffusion/modules/__init__.py b/src/streamdiffusion/modules/__init__.py index 54954961..f3242ca5 100644 --- a/src/streamdiffusion/modules/__init__.py +++ b/src/streamdiffusion/modules/__init__.py @@ -1,22 +1,21 @@ # StreamDiffusion Modules Package from .controlnet_module import ControlNetModule +from .image_processing_module import ImagePostprocessingModule, ImagePreprocessingModule, ImageProcessingModule from .ipadapter_module import IPAdapterModule -from .image_processing_module import ImageProcessingModule, ImagePreprocessingModule, ImagePostprocessingModule -from .latent_processing_module import LatentProcessingModule, LatentPreprocessingModule, LatentPostprocessingModule +from .latent_processing_module import LatentPostprocessingModule, LatentPreprocessingModule, LatentProcessingModule + __all__ = [ # Existing modules - 'ControlNetModule', - 'IPAdapterModule', - + "ControlNetModule", + "IPAdapterModule", # Pipeline processing base classes - 'ImageProcessingModule', - 'LatentProcessingModule', - + "ImageProcessingModule", + "LatentProcessingModule", # Pipeline processing timing-specific modules - 'ImagePreprocessingModule', - 'ImagePostprocessingModule', - 'LatentPreprocessingModule', - 'LatentPostprocessingModule', + "ImagePreprocessingModule", + "ImagePostprocessingModule", + "LatentPreprocessingModule", + "LatentPostprocessingModule", ] diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py index 9e7818b1..e0a57f3b 100644 --- a/src/streamdiffusion/modules/controlnet_module.py +++ b/src/streamdiffusion/modules/controlnet_module.py @@ -1,18 +1,18 @@ from __future__ import annotations +import logging import threading from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch from diffusers.models import ControlNetModel -import logging -from streamdiffusion.hooks import StepCtx, UnetKwargsDelta, UnetHook +from streamdiffusion.hooks import StepCtx, UnetHook, UnetKwargsDelta +from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser from 
streamdiffusion.preprocessing.preprocessing_orchestrator import ( PreprocessingOrchestrator, ) -from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser @dataclass @@ -55,17 +55,17 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) -> self._prepared_dtype: Optional[torch.dtype] = None self._prepared_batch: Optional[int] = None self._images_version: int = 0 - + # Cache expensive lookups to avoid repeated hasattr/getattr calls self._engines_by_id: Dict[str, Any] = {} self._engines_cache_valid: bool = False self._is_sdxl: Optional[bool] = None self._expected_text_len: int = 77 - + # SDXL-specific caching for performance optimization self._sdxl_conditioning_cache: Optional[Dict[str, torch.Tensor]] = None self._sdxl_conditioning_valid: bool = False - + # Cache engine type detection to avoid repeated hasattr calls self._engine_type_cache: Dict[str, bool] = {} @@ -78,9 +78,9 @@ def install(self, stream) -> None: # Register UNet hook stream.unet_hooks.append(self.build_unet_hook()) # Expose controlnet collections so existing updater can find them - setattr(stream, 'controlnets', self.controlnets) - setattr(stream, 'controlnet_scales', self.controlnet_scales) - setattr(stream, 'preprocessors', self.preprocessors) + setattr(stream, "controlnets", self.controlnets) + setattr(stream, "controlnet_scales", self.controlnet_scales) + setattr(stream, "preprocessors", self.preprocessors) # Reset prepared tensors on install self._prepared_tensors = [] self._prepared_device = None @@ -92,18 +92,26 @@ def install(self, stream) -> None: self._sdxl_conditioning_valid = False self._engine_type_cache.clear() - def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None: + def add_controlnet( + self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None + ) -> None: model = self._load_pytorch_controlnet_model(cfg.model_id, cfg.conditioning_channels) preproc = None if cfg.preprocessor: from streamdiffusion.preprocessing.processors import get_preprocessor - preproc = get_preprocessor(cfg.preprocessor, pipeline_ref=self._stream, normalization_context='controlnet', params=cfg.preprocessor_params) + + preproc = get_preprocessor( + cfg.preprocessor, + pipeline_ref=self._stream, + normalization_context="controlnet", + params=cfg.preprocessor_params, + ) # Apply provided parameters to the preprocessor instance if cfg.preprocessor_params: params = cfg.preprocessor_params or {} # If the preprocessor exposes a 'params' dict, update it - if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict): + if hasattr(preproc, "params") and isinstance(getattr(preproc, "params"), dict): preproc.params.update(params) # Also set attributes directly when they exist for name, value in params.items(): @@ -113,16 +121,15 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st except Exception: pass - # Align preprocessor target size with stream resolution once (avoid double-resize later) try: - if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict): - preproc.params['image_width'] = int(self._stream.width) - preproc.params['image_height'] = int(self._stream.height) - if hasattr(preproc, 'image_width'): - setattr(preproc, 'image_width', int(self._stream.width)) - if hasattr(preproc, 'image_height'): - setattr(preproc, 'image_height', int(self._stream.height)) + if hasattr(preproc, "params") and isinstance(getattr(preproc, "params"), 
dict): + preproc.params["image_width"] = int(self._stream.width) + preproc.params["image_height"] = int(self._stream.height) + if hasattr(preproc, "image_width"): + setattr(preproc, "image_width", int(self._stream.width)) + if hasattr(preproc, "image_height"): + setattr(preproc, "image_height", int(self._stream.height)) except Exception: pass @@ -142,7 +149,9 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st # Invalidate SDXL conditioning cache when ControlNet configuration changes self._sdxl_conditioning_valid = False - def update_control_image_efficient(self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None) -> None: + def update_control_image_efficient( + self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None + ) -> None: if self._preprocessing_orchestrator is None: return with self._collections_lock: @@ -150,23 +159,15 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te return total = len(self.controlnets) # Build active scales, respecting enabled_list if present - scales = [ - (self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0) - for i in range(total) - ] - if hasattr(self, 'enabled_list') and self.enabled_list and len(self.enabled_list) == total: + scales = [(self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0) for i in range(total)] + if hasattr(self, "enabled_list") and self.enabled_list and len(self.enabled_list) == total: scales = [sc if bool(self.enabled_list[i]) else 0.0 for i, sc in enumerate(scales)] preprocessors = [self.preprocessors[i] if i < len(self.preprocessors) else None for i in range(total)] # Single-index fast path if index is not None: results = self._preprocessing_orchestrator.process_sync( - control_image, - preprocessors, - scales, - self._stream.width, - self._stream.height, - index + control_image, preprocessors, scales, self._stream.width, self._stream.height, index ) processed = results[index] if results and len(results) > index else None with self._collections_lock: @@ -182,11 +183,7 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te # Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync) processed_images = self._preprocessing_orchestrator.process_pipelined( - control_image, - preprocessors, - scales, - self._stream.width, - self._stream.height + control_image, preprocessors, scales, self._stream.width, self._stream.height ) # If orchestrator returns empty list, it indicates no update needed for this frame @@ -243,7 +240,7 @@ def reorder_controlnets_by_model_ids(self, desired_model_ids: List[str]) -> None # Build current mapping from model_id to index current_ids: List[str] = [] for i, cn in enumerate(self.controlnets): - model_id = getattr(cn, 'model_id', f'controlnet_{i}') + model_id = getattr(cn, "model_id", f"controlnet_{i}") current_ids.append(model_id) # Compute new index order @@ -275,23 +272,29 @@ def get_current_config(self) -> List[Dict[str, Any]]: cfg: List[Dict[str, Any]] = [] with self._collections_lock: for i, cn in enumerate(self.controlnets): - model_id = getattr(cn, 'model_id', f'controlnet_{i}') + model_id = getattr(cn, "model_id", f"controlnet_{i}") scale = self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0 - preproc_params = getattr(self.preprocessors[i], 'params', {}) if i < len(self.preprocessors) and self.preprocessors[i] else {} - cfg.append({ - 'model_id': model_id, - 'conditioning_scale': 
scale, - 'preprocessor_params': preproc_params, - 'enabled': (self.enabled_list[i] if i < len(self.enabled_list) else True), - }) + preproc_params = ( + getattr(self.preprocessors[i], "params", {}) + if i < len(self.preprocessors) and self.preprocessors[i] + else {} + ) + cfg.append( + { + "model_id": model_id, + "conditioning_scale": scale, + "preprocessor_params": preproc_params, + "enabled": (self.enabled_list[i] if i < len(self.enabled_list) else True), + } + ) return cfg def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_size: int) -> None: """Prepare control image tensors for the current frame. - + This method is called once per frame to prepare all control images with the correct device, dtype, and batch size. This avoids redundant operations during each denoising step. - + Args: device: Target device for tensors dtype: Target dtype for tensors @@ -300,22 +303,22 @@ def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_ with self._collections_lock: # Check if we need to re-prepare tensors cache_valid = ( - self._prepared_device == device and - self._prepared_dtype == dtype and - self._prepared_batch == batch_size and - len(self._prepared_tensors) == len(self.controlnet_images) + self._prepared_device == device + and self._prepared_dtype == dtype + and self._prepared_batch == batch_size + and len(self._prepared_tensors) == len(self.controlnet_images) ) - + if cache_valid: return - + # Prepare tensors for current frame self._prepared_tensors = [] for img in self.controlnet_images: if img is None: self._prepared_tensors.append(None) continue - + # Prepare tensor with correct batch size prepared = img if prepared.dim() == 4 and prepared.shape[0] != batch_size: @@ -324,63 +327,62 @@ def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_ else: repeat_factor = max(1, batch_size // prepared.shape[0]) prepared = prepared.repeat(repeat_factor, 1, 1, 1)[:batch_size] - + # Move to correct device and dtype prepared = prepared.to(device=device, dtype=dtype) self._prepared_tensors.append(prepared) - + # Update cache state self._prepared_device = device self._prepared_dtype = dtype self._prepared_batch = batch_size - def _get_cached_sdxl_conditioning(self, ctx: 'StepCtx') -> Optional[Dict[str, torch.Tensor]]: + def _get_cached_sdxl_conditioning(self, ctx: "StepCtx") -> Optional[Dict[str, torch.Tensor]]: """Get cached SDXL conditioning to avoid repeated preparation""" if not self._is_sdxl or ctx.sdxl_cond is None: return None - + # Check if cache is valid if self._sdxl_conditioning_valid and self._sdxl_conditioning_cache is not None: cached = self._sdxl_conditioning_cache # Verify batch size matches current context - if ('text_embeds' in cached and - cached['text_embeds'].shape[0] == ctx.x_t_latent.shape[0]): + if "text_embeds" in cached and cached["text_embeds"].shape[0] == ctx.x_t_latent.shape[0]: return cached - + # Cache miss or invalid - prepare new conditioning try: conditioning = {} - if 'text_embeds' in ctx.sdxl_cond: - text_embeds = ctx.sdxl_cond['text_embeds'] + if "text_embeds" in ctx.sdxl_cond: + text_embeds = ctx.sdxl_cond["text_embeds"] batch_size = ctx.x_t_latent.shape[0] - + # Optimize batch expansion for SDXL text embeddings if text_embeds.shape[0] != batch_size: if text_embeds.shape[0] == 1: - conditioning['text_embeds'] = text_embeds.repeat(batch_size, 1) + conditioning["text_embeds"] = text_embeds.repeat(batch_size, 1) else: - conditioning['text_embeds'] = text_embeds[:batch_size] + 
conditioning["text_embeds"] = text_embeds[:batch_size] else: - conditioning['text_embeds'] = text_embeds - - if 'time_ids' in ctx.sdxl_cond: - time_ids = ctx.sdxl_cond['time_ids'] + conditioning["text_embeds"] = text_embeds + + if "time_ids" in ctx.sdxl_cond: + time_ids = ctx.sdxl_cond["time_ids"] batch_size = ctx.x_t_latent.shape[0] - + # Optimize batch expansion for SDXL time IDs if time_ids.shape[0] != batch_size: if time_ids.shape[0] == 1: - conditioning['time_ids'] = time_ids.repeat(batch_size, 1) + conditioning["time_ids"] = time_ids.repeat(batch_size, 1) else: - conditioning['time_ids'] = time_ids[:batch_size] + conditioning["time_ids"] = time_ids[:batch_size] else: - conditioning['time_ids'] = time_ids - + conditioning["time_ids"] = time_ids + # Cache the prepared conditioning self._sdxl_conditioning_cache = conditioning self._sdxl_conditioning_valid = True return conditioning - + except Exception: # Fallback to original conditioning on any error return ctx.sdxl_cond @@ -399,8 +401,10 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # Single pass to collect active ControlNet data active_data = [] enabled_flags = self.enabled_list if len(self.enabled_list) == len(self.controlnets) else None - - for i, (cn, img, scale) in enumerate(zip(self.controlnets, self.controlnet_images, self.controlnet_scales)): + + for i, (cn, img, scale) in enumerate( + zip(self.controlnets, self.controlnet_images, self.controlnet_scales) + ): if cn is not None and img is not None and scale > 0: enabled = enabled_flags[i] if enabled_flags else True if enabled: @@ -413,9 +417,11 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: if not self._engines_cache_valid: self._engines_by_id.clear() try: - if hasattr(self._stream, 'controlnet_engines') and isinstance(self._stream.controlnet_engines, list): + if hasattr(self._stream, "controlnet_engines") and isinstance( + self._stream.controlnet_engines, list + ): for eng in self._stream.controlnet_engines: - mid = getattr(eng, 'model_id', None) + mid = getattr(eng, "model_id", None) if mid: self._engines_by_id[mid] = eng self._engines_cache_valid = True @@ -425,17 +431,17 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # Cache SDXL detection to avoid repeated hasattr calls if self._is_sdxl is None: try: - self._is_sdxl = getattr(self._stream, 'is_sdxl', False) + self._is_sdxl = getattr(self._stream, "is_sdxl", False) except Exception: self._is_sdxl = False - encoder_hidden_states = self._stream.prompt_embeds[:, :self._expected_text_len, :] + encoder_hidden_states = self._stream.prompt_embeds[:, : self._expected_text_len, :] base_kwargs: Dict[str, Any] = { - 'sample': x_t, - 'timestep': t_list, - 'encoder_hidden_states': encoder_hidden_states, - 'return_dict': False, + "sample": x_t, + "timestep": t_list, + "encoder_hidden_states": encoder_hidden_states, + "return_dict": False, } down_samples_list: List[List[torch.Tensor]] = [] @@ -443,20 +449,22 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # Ensure tensors are prepared for this frame # This should have been called earlier, but we call it here as a safety net - if (self._prepared_device != x_t.device or - self._prepared_dtype != x_t.dtype or - self._prepared_batch != x_t.shape[0]): + if ( + self._prepared_device != x_t.device + or self._prepared_dtype != x_t.dtype + or self._prepared_batch != x_t.shape[0] + ): self.prepare_frame_tensors(x_t.device, x_t.dtype, x_t.shape[0]) - + # Use pre-prepared tensors prepared_images = self._prepared_tensors for cn, img, scale, idx_i in active_data: # Swap to TRT 
engine if available for this model_id (use cached lookup) - model_id = getattr(cn, 'model_id', None) + model_id = getattr(cn, "model_id", None) if model_id and model_id in self._engines_by_id: cn = self._engines_by_id[model_id] - + # Use pre-prepared tensor current_img = prepared_images[idx_i] if idx_i < len(prepared_images) else img if current_img is None: @@ -467,12 +475,12 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: if cache_key in self._engine_type_cache: is_trt_engine = self._engine_type_cache[cache_key] else: - is_trt_engine = hasattr(cn, 'engine') and hasattr(cn, 'stream') + is_trt_engine = hasattr(cn, "engine") and hasattr(cn, "stream") self._engine_type_cache[cache_key] = is_trt_engine - + # Get optimized SDXL conditioning (uses caching to avoid repeated tensor operations) added_cond_kwargs = self._get_cached_sdxl_conditioning(ctx) - + try: if is_trt_engine: # TensorRT engine path @@ -483,7 +491,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: encoder_hidden_states=encoder_hidden_states, controlnet_cond=current_img, conditioning_scale=float(scale), - **added_cond_kwargs + **added_cond_kwargs, ) else: down_samples, mid_sample = cn( @@ -491,7 +499,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: timestep=t_list, encoder_hidden_states=encoder_hidden_states, controlnet_cond=current_img, - conditioning_scale=float(scale) + conditioning_scale=float(scale), ) else: # PyTorch ControlNet path @@ -503,7 +511,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: controlnet_cond=current_img, conditioning_scale=float(scale), return_dict=False, - added_cond_kwargs=added_cond_kwargs + added_cond_kwargs=added_cond_kwargs, ) else: down_samples, mid_sample = cn( @@ -512,21 +520,30 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: encoder_hidden_states=encoder_hidden_states, controlnet_cond=current_img, conditioning_scale=float(scale), - return_dict=False + return_dict=False, ) except Exception as e: import traceback - __import__('logging').getLogger(__name__).error("ControlNetModule: controlnet forward failed: %s", e) + + __import__("logging").getLogger(__name__).error( + "ControlNetModule: controlnet forward failed: %s", e + ) try: - __import__('logging').getLogger(__name__).error("ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s", - (tuple(encoder_hidden_states.shape) if isinstance(encoder_hidden_states, torch.Tensor) else None), - (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None), - scale, - self._is_sdxl, - is_trt_engine) + __import__("logging").getLogger(__name__).error( + "ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s", + ( + tuple(encoder_hidden_states.shape) + if isinstance(encoder_hidden_states, torch.Tensor) + else None + ), + (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None), + scale, + self._is_sdxl, + is_trt_engine, + ) except Exception: pass - __import__('logging').getLogger(__name__).error(traceback.format_exc()) + __import__("logging").getLogger(__name__).error(traceback.format_exc()) continue down_samples_list.append(down_samples) mid_samples_list.append(mid_sample) @@ -555,48 +572,53 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: return _unet_hook - def _prepare_control_image(self, control_image: Union[str, Any, torch.Tensor], preprocessor: Optional[Any]) -> torch.Tensor: + def _prepare_control_image( + self, control_image: Union[str, Any, torch.Tensor], preprocessor: Optional[Any] + ) -> torch.Tensor: if 
self._preprocessing_orchestrator is None: raise RuntimeError("ControlNetModule: preprocessing orchestrator is not initialized") # Reuse orchestrator API used by BaseControlNetPipeline images = self._preprocessing_orchestrator.process_sync( - control_image, - [preprocessor], - [1.0], - self._stream.width, - self._stream.height, - 0 + control_image, [preprocessor], [1.0], self._stream.width, self._stream.height, 0 ) # API returns a list; pick first if present return images[0] if images else None - #FIXME: more robust model management is needed in general. - def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: Optional[int] = None) -> ControlNetModel: - from pathlib import Path - import logging + # FIXME: more robust model management is needed in general. + def _load_pytorch_controlnet_model( + self, model_id: str, conditioning_channels: Optional[int] = None + ) -> ControlNetModel: import os + from pathlib import Path + logger = logging.getLogger(__name__) - + try: # Prepare loading kwargs load_kwargs = {"torch_dtype": self.dtype} if conditioning_channels is not None: load_kwargs["conditioning_channels"] = conditioning_channels - + # Check if offline mode is enabled via environment variables - is_offline = os.environ.get("HF_HUB_OFFLINE", "0") == "1" or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1" - + is_offline = ( + os.environ.get("HF_HUB_OFFLINE", "0") == "1" or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1" + ) + if Path(model_id).exists(): model_path = Path(model_id) - + # Check if it's a direct file path to a safetensors/ckpt file - if model_path.is_file() and model_path.suffix in ['.safetensors', '.ckpt', '.bin']: - logger.info(f"ControlNetModule._load_pytorch_controlnet_model: Loading ControlNet from single file: {model_path} (channels={conditioning_channels})") + if model_path.is_file() and model_path.suffix in [".safetensors", ".ckpt", ".bin"]: + logger.info( + f"ControlNetModule._load_pytorch_controlnet_model: Loading ControlNet from single file: {model_path} (channels={conditioning_channels})" + ) # Try loading from single file (works for most ControlNet models) try: controlnet = ControlNetModel.from_single_file(str(model_path), **load_kwargs) except Exception as e: - logger.warning(f"ControlNetModule._load_pytorch_controlnet_model: Single file loading failed: {e}") + logger.warning( + f"ControlNetModule._load_pytorch_controlnet_model: Single file loading failed: {e}" + ) # Fallback: try pretrained loading in case it's in a proper directory structure load_kwargs["local_files_only"] = True controlnet = ControlNetModel.from_pretrained(str(model_path.parent), **load_kwargs) @@ -608,29 +630,27 @@ def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: O # Loading from HuggingFace Hub - respect offline mode if is_offline: load_kwargs["local_files_only"] = True - logger.info(f"ControlNetModule._load_pytorch_controlnet_model: Offline mode enabled, loading '{model_id}' from cache only") - + logger.info( + f"ControlNetModule._load_pytorch_controlnet_model: Offline mode enabled, loading '{model_id}' from cache only" + ) + if "/" in model_id and model_id.count("/") > 1: parts = model_id.split("/") repo_id = "/".join(parts[:2]) subfolder = "/".join(parts[2:]) - controlnet = ControlNetModel.from_pretrained( - repo_id, subfolder=subfolder, **load_kwargs - ) + controlnet = ControlNetModel.from_pretrained(repo_id, subfolder=subfolder, **load_kwargs) else: controlnet = ControlNetModel.from_pretrained(model_id, **load_kwargs) controlnet = 
controlnet.to(device=self.device, dtype=self.dtype) # Track model_id for updater diffing try: - setattr(controlnet, 'model_id', model_id) + setattr(controlnet, "model_id", model_id) except Exception: pass return controlnet except Exception as e: import traceback + logger.error(f"ControlNetModule: failed to load model '{model_id}': {e}") logger.error(traceback.format_exc()) raise - - - diff --git a/src/streamdiffusion/modules/image_processing_module.py b/src/streamdiffusion/modules/image_processing_module.py index ffea6e5f..b96f0c0b 100644 --- a/src/streamdiffusion/modules/image_processing_module.py +++ b/src/streamdiffusion/modules/image_processing_module.py @@ -1,55 +1,57 @@ -from typing import List, Optional, Any, Dict +from typing import Any, Dict, List + import torch -from ..preprocessing.orchestrator_user import OrchestratorUser -from ..preprocessing.pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator from ..hooks import ImageCtx, ImageHook +from ..preprocessing.orchestrator_user import OrchestratorUser class ImageProcessingModule(OrchestratorUser): """ Shared base class for image domain processing modules. - + Handles sequential chain execution for both preprocessing and postprocessing timing variants. Processing domain is always image tensors. """ - + def __init__(self): """Initialize image processing module.""" self.processors = [] - + def _process_image_chain(self, input_image: torch.Tensor) -> torch.Tensor: """Execute sequential chain of processors in image domain. - + Uses the shared orchestrator's sequential chain processing. """ if not self.processors: return input_image - + ordered_processors = self._get_ordered_processors() return self._preprocessing_orchestrator.execute_pipeline_chain( input_image, ordered_processors, processing_domain="image" ) - + def add_processor(self, proc_config: Dict[str, Any]) -> None: """Add a processor using the existing registry, following ControlNet pattern.""" from streamdiffusion.preprocessing.processors import get_preprocessor - - processor_type = proc_config.get('type') + + processor_type = proc_config.get("type") if not processor_type: raise ValueError("Processor config missing 'type' field") - + # Check if processor is enabled (default to True, same as ControlNet) - enabled = proc_config.get('enabled', True) - + enabled = proc_config.get("enabled", True) + # Create processor using existing registry (same as ControlNet) # ImageProcessingModule uses 'pipeline' normalization context - processor = get_preprocessor(processor_type, pipeline_ref=getattr(self, '_stream', None), normalization_context='pipeline') - + processor = get_preprocessor( + processor_type, pipeline_ref=getattr(self, "_stream", None), normalization_context="pipeline" + ) + # Apply parameters (same pattern as ControlNet) - processor_params = proc_config.get('params', {}) + processor_params = proc_config.get("params", {}) if processor_params: - if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict): + if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict): processor.params.update(processor_params) for name, value in processor_params.items(): try: @@ -57,109 +59,109 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None: setattr(processor, name, value) except Exception: pass - + # Set order for sequential execution - order = proc_config.get('order', len(self.processors)) - setattr(processor, 'order', order) - + order = proc_config.get("order", len(self.processors)) + setattr(processor, 
"order", order) + # Set enabled state - setattr(processor, 'enabled', enabled) - + setattr(processor, "enabled", enabled) + # Align preprocessor target size with stream resolution (same as ControlNet) - if hasattr(self, '_stream'): + if hasattr(self, "_stream"): try: - if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict): - processor.params['image_width'] = int(self._stream.width) - processor.params['image_height'] = int(self._stream.height) - if hasattr(processor, 'image_width'): - setattr(processor, 'image_width', int(self._stream.width)) - if hasattr(processor, 'image_height'): - setattr(processor, 'image_height', int(self._stream.height)) + if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict): + processor.params["image_width"] = int(self._stream.width) + processor.params["image_height"] = int(self._stream.height) + if hasattr(processor, "image_width"): + setattr(processor, "image_width", int(self._stream.width)) + if hasattr(processor, "image_height"): + setattr(processor, "image_height", int(self._stream.height)) except Exception: pass - + self.processors.append(processor) - + def _get_ordered_processors(self) -> List[Any]: """Return enabled processors in execution order based on their order attribute.""" # Filter for enabled processors first, then sort by order - enabled_processors = [p for p in self.processors if getattr(p, 'enabled', True)] - return sorted(enabled_processors, key=lambda p: getattr(p, 'order', 0)) + enabled_processors = [p for p in self.processors if getattr(p, "enabled", True)] + return sorted(enabled_processors, key=lambda p: getattr(p, "order", 0)) class ImagePreprocessingModule(ImageProcessingModule): """ Image domain preprocessing module - executes before VAE encoding. - + Timing: After image_processor.preprocess(), before similar_image_filter Uses pipelined processing for performance optimization. """ - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrators.""" self._stream = stream # Store stream reference for dimension access self.attach_orchestrator(stream) # For sequential chain processing (fallback) self.attach_pipeline_preprocessing_orchestrator(stream) # For pipelined processing stream.image_preprocessing_hooks.append(self.build_image_hook()) - + def build_image_hook(self) -> ImageHook: """Build hook function that processes image context with pipelined processing.""" + def hook(ctx: ImageCtx) -> ImageCtx: ctx.image = self._process_image_pipelined(ctx.image) return ctx + return hook - + def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor: """Execute pipelined processing of preprocessors for performance. - + Uses PipelinePreprocessingOrchestrator for Frame N-1 results while starting Frame N processing. Falls back to synchronous processing when needed. """ if not self.processors: return input_image - + ordered_processors = self._get_ordered_processors() - + # Use pipelined pipeline preprocessing orchestrator for performance - return self._pipeline_preprocessing_orchestrator.process_pipelined( - input_image, ordered_processors - ) + return self._pipeline_preprocessing_orchestrator.process_pipelined(input_image, ordered_processors) class ImagePostprocessingModule(ImageProcessingModule): """ Image domain postprocessing module - executes after VAE decoding. - + Timing: After decode_image(), before returning final output Uses pipelined processing for performance optimization. 
""" - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrators.""" self._stream = stream # Store stream reference for dimension access self.attach_preprocessing_orchestrator(stream) # For sequential chain processing (fallback) self.attach_postprocessing_orchestrator(stream) # For pipelined processing stream.image_postprocessing_hooks.append(self.build_image_hook()) - + def build_image_hook(self) -> ImageHook: """Build hook function that processes image context with pipelined processing.""" + def hook(ctx: ImageCtx) -> ImageCtx: ctx.image = self._process_image_pipelined(ctx.image) return ctx + return hook - + def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor: """Execute pipelined processing of postprocessors for performance. - + Uses PostprocessingOrchestrator for Frame N-1 results while starting Frame N processing. Falls back to synchronous processing when needed. """ if not self.processors: return input_image - + ordered_processors = self._get_ordered_processors() - + # Use pipelined postprocessing orchestrator for performance - return self._postprocessing_orchestrator.process_pipelined( - input_image, ordered_processors - ) + return self._postprocessing_orchestrator.process_pipelined(input_image, ordered_processors) diff --git a/src/streamdiffusion/modules/ipadapter_module.py b/src/streamdiffusion/modules/ipadapter_module.py index b283799f..4e3b95ea 100644 --- a/src/streamdiffusion/modules/ipadapter_module.py +++ b/src/streamdiffusion/modules/ipadapter_module.py @@ -1,16 +1,18 @@ from __future__ import annotations +import logging +import os from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Any from enum import Enum +from typing import Any, Dict, Optional, Tuple + import torch -from streamdiffusion.hooks import EmbedsCtx, EmbeddingHook, StepCtx, UnetKwargsDelta, UnetHook -import os +from streamdiffusion.hooks import EmbeddingHook, EmbedsCtx, StepCtx, UnetHook, UnetKwargsDelta from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser -import logging from streamdiffusion.utils.reporting import report_error + logger = logging.getLogger(__name__) @@ -27,6 +29,7 @@ class IPAdapterConfig: This module focuses only on embedding composition (step 2 of migration). Runtime installation and wrapper wiring will come in later steps. """ + style_image_key: Optional[str] = None num_image_tokens: int = 4 # e.g., 4 for standard, 16 for plus ipadapter_model_path: Optional[str] = None @@ -59,7 +62,7 @@ class IPAdapterConfig: "image_encoder_path": "h94/IP-Adapter/models/image_encoder", }, ("SD2.1", IPAdapterType.REGULAR): None, # not available from h94 (ip-adapter_sd21.bin was never released) - ("SD2.1", IPAdapterType.PLUS): None, # not available from h94 + ("SD2.1", IPAdapterType.PLUS): None, # not available from h94 ("SD2.1", IPAdapterType.FACEID): None, # not available from h94 ("SDXL", IPAdapterType.REGULAR): { "model_path": "h94/IP-Adapter/sdxl_models/ip-adapter_sdxl.bin", @@ -78,15 +81,15 @@ class IPAdapterConfig: # Set of all known HF model paths — used to distinguish known vs custom paths. # Custom/local paths are never overridden. 
_KNOWN_IPADAPTER_PATHS: frozenset = frozenset( - entry["model_path"] - for entry in IPADAPTER_MODEL_MAP.values() - if entry is not None + entry["model_path"] for entry in IPADAPTER_MODEL_MAP.values() if entry is not None ) -_KNOWN_ENCODER_PATHS: frozenset = frozenset({ - "h94/IP-Adapter/models/image_encoder", - "h94/IP-Adapter/sdxl_models/image_encoder", -}) +_KNOWN_ENCODER_PATHS: frozenset = frozenset( + { + "h94/IP-Adapter/models/image_encoder", + "h94/IP-Adapter/sdxl_models/image_encoder", + } +) def _normalize_model_type(detected_model_type: str, is_sdxl: bool) -> Optional[str]: @@ -183,10 +186,7 @@ def resolve_ipadapter_paths( # Resolve encoder path (only if it's a known HF encoder — custom encoders untouched) if current_encoder_path in _KNOWN_ENCODER_PATHS and current_encoder_path != correct_encoder_path: - logger.info( - f"IP-Adapter: resolving image encoder " - f"'{current_encoder_path}' → '{correct_encoder_path}'." - ) + logger.info(f"IP-Adapter: resolving image encoder '{current_encoder_path}' → '{correct_encoder_path}'.") cfg["image_encoder_path"] = correct_encoder_path return cfg @@ -209,7 +209,9 @@ def build_embedding_hook(self, stream) -> EmbeddingHook: def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx: # Fetch cached image token embeddings (prompt, negative) - cached: Optional[Tuple[torch.Tensor, torch.Tensor]] = stream._param_updater.get_cached_embeddings(style_key) + cached: Optional[Tuple[torch.Tensor, torch.Tensor]] = stream._param_updater.get_cached_embeddings( + style_key + ) image_prompt_tokens: Optional[torch.Tensor] = None image_negative_tokens: Optional[torch.Tensor] = None if cached is not None: @@ -220,7 +222,9 @@ def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx: batch_size = ctx.prompt_embeds.shape[0] if image_prompt_tokens is None: image_prompt_tokens = torch.zeros( - (batch_size, num_tokens, hidden_dim), dtype=ctx.prompt_embeds.dtype, device=ctx.prompt_embeds.device + (batch_size, num_tokens, hidden_dim), + dtype=ctx.prompt_embeds.dtype, + device=ctx.prompt_embeds.device, ) else: if image_prompt_tokens.shape[1] != num_tokens: @@ -242,7 +246,9 @@ def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx: if neg_with_image is not None: if image_negative_tokens is None: image_negative_tokens = torch.zeros( - (neg_with_image.shape[0], num_tokens, hidden_dim), dtype=neg_with_image.dtype, device=neg_with_image.device + (neg_with_image.shape[0], num_tokens, hidden_dim), + dtype=neg_with_image.dtype, + device=neg_with_image.device, ) else: if image_negative_tokens.shape[0] != neg_with_image.shape[0]: @@ -291,14 +297,14 @@ def install(self, stream) -> None: # Create IP-Adapter and install processors into UNet (FaceID-aware) ip_kwargs = { - 'pipe': stream.pipe, - 'ipadapter_ckpt_path': resolved_ip_path, - 'image_encoder_path': resolved_encoder_path, - 'device': stream.device, - 'dtype': stream.dtype, + "pipe": stream.pipe, + "ipadapter_ckpt_path": resolved_ip_path, + "image_encoder_path": resolved_encoder_path, + "device": stream.device, + "dtype": stream.dtype, } if self.config.type == IPAdapterType.FACEID and self.config.insightface_model_name: - ip_kwargs['insightface_model_name'] = self.config.insightface_model_name + ip_kwargs["insightface_model_name"] = self.config.insightface_model_name print( f"IPAdapterModule.install: Initializing FaceID IP-Adapter with InsightFace model: {self.config.insightface_model_name}" ) @@ -311,6 +317,7 @@ def install(self, stream) -> None: # AttnProcessor2_0 which accepts kvo_cache and returns (hidden_states, kvo_cache). 
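The embedding hook above pads with zero image tokens when no style embedding is cached, so the text embedding effectively grows from 77 tokens to 77 + num_image_tokens. The concatenation itself happens in elided code, so treat the layout below as an assumption — though it is consistent with ControlNetModule slicing prompt_embeds back to _expected_text_len (77) earlier in this diff:

import torch

batch, text_len, hidden = 2, 77, 768
num_image_tokens = 4  # 4 for regular adapters, 16 for "plus" variants (per the config above)

prompt_embeds = torch.randn(batch, text_len, hidden)
image_tokens = torch.zeros(batch, num_image_tokens, hidden)  # neutral tokens: no style image set

# Assumed layout: image tokens appended after the text tokens.
combined = torch.cat([prompt_embeds, image_tokens], dim=1)
assert combined.shape == (batch, text_len + num_image_tokens, hidden)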
try: from diffusers.models.attention_processor import AttnProcessor2_0 as NativeAttnProcessor2_0 + attn_procs = stream.pipe.unet.attn_processors for name in attn_procs: if name.endswith("attn1.processor"): @@ -324,6 +331,7 @@ def install(self, stream) -> None: if self.config.type == IPAdapterType.FACEID: try: from streamdiffusion.preprocessing.processors.faceid_embedding import FaceIDEmbeddingPreprocessor + embedding_preprocessor = FaceIDEmbeddingPreprocessor( ipadapter=ipadapter, device=stream.device, @@ -357,11 +365,11 @@ def install(self, stream) -> None: # Expose IPAdapter instance as single source of truth try: - setattr(stream, 'ipadapter', ipadapter) + setattr(stream, "ipadapter", ipadapter) # Extend IPAdapter with our custom attributes since diffusers IPAdapter doesn't expose current state - setattr(ipadapter, 'weight_type', self.config.weight_type) # For build_layer_weights - setattr(ipadapter, 'scale', float(self.config.scale)) # Track current scale - setattr(ipadapter, 'enabled', bool(self.config.enabled)) # Track enabled state + setattr(ipadapter, "weight_type", self.config.weight_type) # For build_layer_weights + setattr(ipadapter, "scale", float(self.config.scale)) # Track current scale + setattr(ipadapter, "enabled", bool(self.config.enabled)) # Track enabled state except Exception: pass @@ -389,7 +397,10 @@ def _resolve_model_path(self, model_path: Optional[str]) -> str: from huggingface_hub import hf_hub_download, snapshot_download except Exception as e: import logging - logging.getLogger(__name__).error(f"IPAdapterModule: huggingface_hub required to resolve '{model_path}': {e}") + + logging.getLogger(__name__).error( + f"IPAdapterModule: huggingface_hub required to resolve '{model_path}': {e}" + ) raise parts = model_path.split("/") @@ -419,28 +430,28 @@ def build_unet_hook(self, stream) -> UnetHook: - For PyTorch UNet with installed IP processors, modulate per-layer processor scale by time factor """ _last_enabled_state = None # Track previous enabled state to avoid redundant updates - + def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # If no IP-Adapter installed, do nothing - if not hasattr(stream, 'ipadapter') or stream.ipadapter is None: + if not hasattr(stream, "ipadapter") or stream.ipadapter is None: return UnetKwargsDelta() # Check if IPAdapter is enabled - enabled = getattr(stream.ipadapter, 'enabled', True) + enabled = getattr(stream.ipadapter, "enabled", True) # Read base weight and weight type from IPAdapter instance try: - base_weight = float(getattr(stream.ipadapter, 'scale', 1.0)) if enabled else 0.0 + base_weight = float(getattr(stream.ipadapter, "scale", 1.0)) if enabled else 0.0 except Exception: base_weight = 0.0 if not enabled else 1.0 - weight_type = getattr(stream.ipadapter, 'weight_type', None) + weight_type = getattr(stream.ipadapter, "weight_type", None) # Determine total steps and current step index for time scheduling total_steps = None try: - if hasattr(stream, 'denoising_steps_num') and isinstance(stream.denoising_steps_num, int): + if hasattr(stream, "denoising_steps_num") and isinstance(stream.denoising_steps_num, int): total_steps = int(stream.denoising_steps_num) - elif hasattr(stream, 't_list') and stream.t_list is not None: + elif hasattr(stream, "t_list") and stream.t_list is not None: total_steps = len(stream.t_list) except Exception: total_steps = None @@ -449,6 +460,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: if total_steps is not None and ctx.step_index is not None: try: from 
diffusers_ipadapter.ip_adapter.attention_processor import build_time_weight_factor + time_factor = float(build_time_weight_factor(weight_type, int(ctx.step_index), int(total_steps))) except Exception: # Do not add fallback mechanisms @@ -456,18 +468,20 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # TensorRT engine path: supply ipadapter_scale vector via extra kwargs try: - is_trt_unet = hasattr(stream, 'unet') and hasattr(stream.unet, 'engine') and hasattr(stream.unet, 'stream') + is_trt_unet = ( + hasattr(stream, "unet") and hasattr(stream.unet, "engine") and hasattr(stream.unet, "stream") + ) except Exception: is_trt_unet = False - if is_trt_unet and getattr(stream.unet, 'use_ipadapter', False): + if is_trt_unet and getattr(stream.unet, "use_ipadapter", False): try: from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights except Exception: # If helper unavailable, do not construct weights here build_layer_weights = None # type: ignore - num_ip_layers = getattr(stream.unet, 'num_ip_layers', None) + num_ip_layers = getattr(stream.unet, "num_ip_layers", None) if isinstance(num_ip_layers, int) and num_ip_layers > 0: weights_tensor = None try: @@ -476,24 +490,26 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: except Exception: weights_tensor = None if weights_tensor is None: - weights_tensor = torch.full((num_ip_layers,), float(base_weight), dtype=torch.float32, device=stream.device) + weights_tensor = torch.full( + (num_ip_layers,), float(base_weight), dtype=torch.float32, device=stream.device + ) # Apply per-step time factor try: weights_tensor = weights_tensor * float(time_factor) except Exception: pass - return UnetKwargsDelta(extra_unet_kwargs={'ipadapter_scale': weights_tensor}) + return UnetKwargsDelta(extra_unet_kwargs={"ipadapter_scale": weights_tensor}) # PyTorch UNet path: modulate installed processor scales by time factor and enabled state try: nonlocal _last_enabled_state # Only process if we need to make changes (time scaling or state transition) - needs_update = (time_factor != 1.0 or enabled != _last_enabled_state) - if needs_update and hasattr(stream.pipe, 'unet') and hasattr(stream.pipe.unet, 'attn_processors'): + needs_update = time_factor != 1.0 or enabled != _last_enabled_state + if needs_update and hasattr(stream.pipe, "unet") and hasattr(stream.pipe.unet, "attn_processors"): _last_enabled_state = enabled for proc in stream.pipe.unet.attn_processors.values(): - if hasattr(proc, 'scale') and hasattr(proc, '_ip_layer_index'): - base_val = getattr(proc, '_base_scale', proc.scale) + if hasattr(proc, "scale") and hasattr(proc, "_ip_layer_index"): + base_val = getattr(proc, "_base_scale", proc.scale) # Apply both enabled state and time factor final_scale = float(base_val) * float(time_factor) if enabled else 0.0 proc.scale = final_scale @@ -503,4 +519,3 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: return UnetKwargsDelta() return _unet_hook - diff --git a/src/streamdiffusion/modules/latent_processing_module.py b/src/streamdiffusion/modules/latent_processing_module.py index 256c66f0..78edf1b3 100644 --- a/src/streamdiffusion/modules/latent_processing_module.py +++ b/src/streamdiffusion/modules/latent_processing_module.py @@ -1,54 +1,55 @@ -from typing import List, Optional, Any, Dict +from typing import Any, Dict, List + import torch -from ..preprocessing.orchestrator_user import OrchestratorUser from ..hooks import LatentCtx, LatentHook +from ..preprocessing.orchestrator_user import OrchestratorUser class 
LatentProcessingModule(OrchestratorUser): """ Shared base class for latent domain processing modules. - + Handles sequential chain execution for both preprocessing and postprocessing timing variants. Processing domain is always latent tensors. """ - + def __init__(self): """Initialize latent processing module.""" self.processors = [] - + def _process_latent_chain(self, input_latent: torch.Tensor) -> torch.Tensor: """Execute sequential chain of processors in latent domain. - + Uses the shared orchestrator's sequential chain processing. """ if not self.processors: return input_latent - + ordered_processors = self._get_ordered_processors() return self._preprocessing_orchestrator.execute_pipeline_chain( input_latent, ordered_processors, processing_domain="latent" ) - + def add_processor(self, proc_config: Dict[str, Any]) -> None: """Add a processor using the existing registry, following ControlNet pattern.""" from streamdiffusion.preprocessing.processors import get_preprocessor - - processor_type = proc_config.get('type') + + processor_type = proc_config.get("type") if not processor_type: raise ValueError("Processor config missing 'type' field") - + # Check if processor is enabled (default to True, same as ControlNet) - enabled = proc_config.get('enabled', True) - + enabled = proc_config.get("enabled", True) + # Create processor using existing registry (same as ControlNet) # LatentProcessingModule uses 'latent' normalization context (works in latent space) - processor = get_preprocessor(processor_type, pipeline_ref=self._stream, normalization_context='latent') - + processor = get_preprocessor(processor_type, pipeline_ref=self._stream, normalization_context="latent") + # Apply parameters (same pattern as ControlNet) - processor_params = proc_config.get('params', {}) + processor_params = proc_config.get("params", {}) if processor_params: - if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict): + if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict): processor.params.update(processor_params) for name, value in processor_params.items(): try: @@ -56,62 +57,66 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None: setattr(processor, name, value) except Exception: pass - + # Set order for sequential execution - order = proc_config.get('order', len(self.processors)) - setattr(processor, 'order', order) - + order = proc_config.get("order", len(self.processors)) + setattr(processor, "order", order) + # Set enabled state - setattr(processor, 'enabled', enabled) - + setattr(processor, "enabled", enabled) + # Pipeline reference is now automatically handled by the factory function - + self.processors.append(processor) - + def _get_ordered_processors(self) -> List[Any]: """Return enabled processors in execution order based on their order attribute.""" # Filter for enabled processors first, then sort by order - enabled_processors = [p for p in self.processors if getattr(p, 'enabled', True)] - return sorted(enabled_processors, key=lambda p: getattr(p, 'order', 0)) + enabled_processors = [p for p in self.processors if getattr(p, "enabled", True)] + return sorted(enabled_processors, key=lambda p: getattr(p, "order", 0)) class LatentPreprocessingModule(LatentProcessingModule): """ Latent domain preprocessing module - executes after VAE encoding, before diffusion. 
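A usage illustration of the ctx-in/ctx-out hook pattern both latent timing variants share (the LatentCtx stand-in below is defined inline for self-containment; the real one lives in streamdiffusion.hooks, and the damping processor is hypothetical):

import torch

class LatentCtx:  # stand-in mirroring the ctx object hooks receive
    def __init__(self, latent: torch.Tensor):
        self.latent = latent

def build_latent_hook(gain: float = 0.98):
    # Same shape as build_latent_hook above: take ctx, transform ctx.latent, return ctx.
    def hook(ctx: LatentCtx) -> LatentCtx:
        ctx.latent = ctx.latent * gain  # e.g., mild latent damping before the UNet
        return ctx
    return hook

ctx = LatentCtx(torch.randn(1, 4, 64, 64))
ctx = build_latent_hook()(ctx)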
- + Timing: After encode_image(), before predict_x0_batch() """ - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrator.""" self.attach_orchestrator(stream) self._stream = stream # Store stream reference like ControlNet module does stream.latent_preprocessing_hooks.append(self.build_latent_hook()) - + def build_latent_hook(self) -> LatentHook: """Build hook function that processes latent context.""" + def hook(ctx: LatentCtx) -> LatentCtx: ctx.latent = self._process_latent_chain(ctx.latent) return ctx + return hook class LatentPostprocessingModule(LatentProcessingModule): """ Latent domain postprocessing module - executes after diffusion, before VAE decoding. - + Timing: After predict_x0_batch(), before decode_image() """ - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrator.""" self.attach_orchestrator(stream) self._stream = stream # Store stream reference like ControlNet module does stream.latent_postprocessing_hooks.append(self.build_latent_hook()) - + def build_latent_hook(self) -> LatentHook: """Build hook function that processes latent context.""" + def hook(ctx: LatentCtx) -> LatentCtx: ctx.latent = self._process_latent_chain(ctx.latent) return ctx + return hook diff --git a/src/streamdiffusion/pip_utils.py b/src/streamdiffusion/pip_utils.py index 6ae3f11c..4a28c0a0 100644 --- a/src/streamdiffusion/pip_utils.py +++ b/src/streamdiffusion/pip_utils.py @@ -27,13 +27,16 @@ def _check_torch_installed(): raise RuntimeError(msg) if not torch.version.cuda: - raise RuntimeError("Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package.") + raise RuntimeError( + "Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package." 
+ ) def get_cuda_version() -> str | None: _check_torch_installed() import torch + return torch.version.cuda @@ -66,7 +69,7 @@ def is_installed(package: str) -> bool: def run_python(command: str, env: Dict[str, str] | None = None) -> str: run_kwargs = { - "args": f"\"{python}\" {command}", + "args": f'"{python}" {command}', "shell": True, "env": os.environ if env is None else env, "encoding": "utf8", diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index b8adb71f..55922e69 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -25,6 +25,7 @@ from streamdiffusion.image_filter import SimilarImageFilter from streamdiffusion.model_detection import detect_model from streamdiffusion.stream_parameter_updater import StreamParameterUpdater +from streamdiffusion.tools.gpu_profiler import profiler logging.basicConfig(level=logging.INFO) @@ -56,6 +57,7 @@ def __init__( self.device = torch.device(device) self.dtype = torch_dtype self.generator = None + self._input_staging: Optional[torch.Tensor] = None self.height = height self.width = width @@ -112,11 +114,11 @@ def __init__( self._prev_image_buf = None # pre-allocated buffer for skip-frame image cache self._combined_latent_buf = None # pre-allocated: avoids torch.cat in predict_x0_batch self._alpha_next = None # pre-computed: cat([alpha_prod_t_sqrt[1:], ones[0:1]]) - self._beta_next = None # pre-computed: cat([beta_prod_t_sqrt[1:], ones[0:1]]) + self._beta_next = None # pre-computed: cat([beta_prod_t_sqrt[1:], ones[0:1]]) self._init_noise_rotated = None # pre-computed: cat([init_noise[1:], init_noise[0:1]]) self._unet_kwargs: dict = {"return_dict": False} # pre-allocated: avoids per-frame dict creation self._cfg_latent_buf = None # pre-allocated: avoids torch.concat for CFG latent doubling - self._cfg_t_buf = None # pre-allocated: avoids torch.concat for CFG timestep doubling + self._cfg_t_buf = None # pre-allocated: avoids torch.concat for CFG timestep doubling self.pipe = pipe self.image_processor = VaeImageProcessor(pipe.vae_scale_factor) @@ -161,10 +163,16 @@ def __init__( self._cached_guidance_scale: Optional[float] = None self.kvo_cache = kvo_cache + self._kvo_buckets = None + self._kvo_outputs_by_bucket = None self.cache_interval = cache_interval self.cache_maxframes = cache_maxframes self.frame_idx = 0 + # Pre-allocated CUDA timing events — reused every frame via .record() + self._timing_start = torch.cuda.Event(enable_timing=True) + self._timing_end = torch.cuda.Event(enable_timing=True) + def _initialize_scheduler(self, scheduler_type: str, sampler_type: str, config): """Initialize scheduler based on type and sampler configuration.""" @@ -346,6 +354,13 @@ def prepare( self.generator = generator self.generator.manual_seed(seed) self.current_seed = seed + + # Pinned CPU staging buffer for async H2D DMA via non_blocking=True. + # Pageable memory silently falls back to synchronous double-buffered staging (~½ bandwidth). + # dtype=self.dtype: avoids a transient GPU fp32 buffer that PyTorch would allocate + # when casting pageable float32 input to fp16 on the device. 
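A minimal sketch of the pinned-staging pattern the comment below describes, assuming a CUDA device is available (sizes are illustrative; real code must also avoid overwriting the staging tensor before the async copy completes, e.g. via frame cadence, stream sync, or double buffering):

import torch

h, w = 512, 512
# One reusable pinned tensor makes every host→device copy eligible for async DMA.
staging = torch.zeros(1, 3, h, w, dtype=torch.float16).pin_memory()

def upload(frame_cpu: torch.Tensor) -> torch.Tensor:
    staging.copy_(frame_cpu)                      # host-side copy; casts fp32 input to fp16
    return staging.to("cuda", non_blocking=True)  # async H2D on the current stream

gpu_frame = upload(torch.rand(1, 3, h, w))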
+ self._input_staging = torch.zeros(1, 3, self.height, self.width, dtype=self.dtype).pin_memory() + # initialize x_t_latent (it can be any random tensor) if self.denoising_steps_num > 1: self.x_t_latent_buffer = torch.zeros( @@ -675,11 +690,11 @@ def unet_step( self._cfg_t_buf[1:].copy_(t_list) t_list = self._cfg_t_buf elif self.guidance_scale > 1.0 and (self.cfg_type == "full"): - self._cfg_latent_buf[:len(x_t_latent)].copy_(x_t_latent) - self._cfg_latent_buf[len(x_t_latent):].copy_(x_t_latent) + self._cfg_latent_buf[: len(x_t_latent)].copy_(x_t_latent) + self._cfg_latent_buf[len(x_t_latent) :].copy_(x_t_latent) x_t_latent_plus_uc = self._cfg_latent_buf - self._cfg_t_buf[:len(t_list)].copy_(t_list) - self._cfg_t_buf[len(t_list):].copy_(t_list) + self._cfg_t_buf[: len(t_list)].copy_(t_list) + self._cfg_t_buf[len(t_list) :].copy_(t_list) t_list = self._cfg_t_buf else: x_t_latent_plus_uc = x_t_latent @@ -717,39 +732,40 @@ def unet_step( unet_kwargs["added_cond_kwargs"] = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # Allow modules to contribute additional UNet kwargs via hooks - try: - step_ctx = StepCtx( - x_t_latent=x_t_latent_plus_uc, - t_list=t_list, - step_index=idx if isinstance(idx, int) else (int(idx) if idx is not None else None), - guidance_mode=self.cfg_type if self.guidance_scale > 1.0 else "none", - sdxl_cond=unet_kwargs.get("added_cond_kwargs", None), - ) - extra_from_hooks = {} - for hook in self.unet_hooks: - delta: UnetKwargsDelta = hook(step_ctx) - if delta is None: - continue - if delta.down_block_additional_residuals is not None: - unet_kwargs["down_block_additional_residuals"] = delta.down_block_additional_residuals - if delta.mid_block_additional_residual is not None: - unet_kwargs["mid_block_additional_residual"] = delta.mid_block_additional_residual - if delta.added_cond_kwargs is not None: - # Merge SDXL cond if both exist - base_added = unet_kwargs.get("added_cond_kwargs", {}) - base_added.update(delta.added_cond_kwargs) - unet_kwargs["added_cond_kwargs"] = base_added - if getattr(delta, "extra_unet_kwargs", None): - # Merge extra kwargs from hooks (e.g., ipadapter_scale) - try: - extra_from_hooks.update(delta.extra_unet_kwargs) - except Exception: - pass - if extra_from_hooks: - unet_kwargs["extra_unet_kwargs"] = extra_from_hooks - except Exception as e: - logger.error(f"unet_step: unet hook failed: {e}") - raise + if self.unet_hooks: + try: + step_ctx = StepCtx( + x_t_latent=x_t_latent_plus_uc, + t_list=t_list, + step_index=idx if isinstance(idx, int) else (int(idx) if idx is not None else None), + guidance_mode=self.cfg_type if self.guidance_scale > 1.0 else "none", + sdxl_cond=unet_kwargs.get("added_cond_kwargs", None), + ) + extra_from_hooks = {} + for hook in self.unet_hooks: + delta: UnetKwargsDelta = hook(step_ctx) + if delta is None: + continue + if delta.down_block_additional_residuals is not None: + unet_kwargs["down_block_additional_residuals"] = delta.down_block_additional_residuals + if delta.mid_block_additional_residual is not None: + unet_kwargs["mid_block_additional_residual"] = delta.mid_block_additional_residual + if delta.added_cond_kwargs is not None: + # Merge SDXL cond if both exist + base_added = unet_kwargs.get("added_cond_kwargs", {}) + base_added.update(delta.added_cond_kwargs) + unet_kwargs["added_cond_kwargs"] = base_added + if getattr(delta, "extra_unet_kwargs", None): + # Merge extra kwargs from hooks (e.g., ipadapter_scale) + try: + extra_from_hooks.update(delta.extra_unet_kwargs) + except Exception: + pass + if 
extra_from_hooks: + unet_kwargs["extra_unet_kwargs"] = extra_from_hooks + except Exception as e: + logger.error(f"unet_step: unet hook failed: {e}") + raise # Extract potential ControlNet residual kwargs and generic extra kwargs (e.g., ipadapter_scale) hook_down_res = unet_kwargs.get("down_block_additional_residuals", None) @@ -881,116 +897,146 @@ def update_kvo_cache(self, kvo_cache_out: List[torch.Tensor]) -> None: # The attention processor reads all slots as an unordered K/V bag, so slot order is irrelevant. # Use self.cache_maxframes (not tensor shape) so that when the buffer is allocated at # max_cache_maxframes but the logical window is smaller, writes stay within the active range. - for i, new_kv in enumerate(kvo_cache_out): - cache_size = self.cache_maxframes - write_slot = (self.frame_idx // self.cache_interval - 1) % cache_size - self.kvo_cache[i][:, write_slot].copy_(new_kv.squeeze(1)) + write_slot = (self.frame_idx // self.cache_interval - 1) % self.cache_maxframes + + if self._kvo_buckets is not None: + # Bucketed path: N stacks + N copies (N≈3) instead of ~70 individual copies. + # Each kvo_cache_out[i] arrives shape (2, 1, B, S, H); squeeze(1) → (2, B, S, H). + for bucket_idx, output_indices in enumerate(self._kvo_outputs_by_bucket): + stacked = torch.stack([kvo_cache_out[i].squeeze(1) for i in output_indices], dim=0) + self._kvo_buckets[bucket_idx][:, :, write_slot].copy_(stacked) + else: + # Fallback path — used when buckets were dropped after a kvo_cache resize + # (stream_parameter_updater rebinds kvo_cache[i] to standalone tensors). + for i, new_kv in enumerate(kvo_cache_out): + self.kvo_cache[i][:, write_slot].copy_(new_kv.squeeze(1)) def encode_image(self, image_tensors: torch.Tensor) -> torch.Tensor: - image_tensors = image_tensors.to( - device=self.device, - dtype=self.vae.dtype, - ) - with torch.autocast("cuda", dtype=self.dtype): - img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator) - - img_latent = img_latent * self.vae.config.scaling_factor + with profiler.region("encode_image"): + image_tensors = image_tensors.to( + device=self.device, + dtype=self.vae.dtype, + non_blocking=True, + ) + with torch.autocast("cuda", dtype=self.dtype): + img_latent = retrieve_latents(self.vae.encode(image_tensors), self.generator) - x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0) + img_latent = img_latent * self.vae.config.scaling_factor - return x_t_latent + x_t_latent = self.add_noise(img_latent, self.init_noise[0], 0) + return x_t_latent def decode_image(self, x_0_pred_out: torch.Tensor) -> torch.Tensor: - scaled_latent = x_0_pred_out / self.vae.config.scaling_factor - with torch.autocast("cuda", dtype=self.dtype): - output_latent = self.vae.decode(scaled_latent, return_dict=False)[0] - return output_latent + with profiler.region("decode_image"): + scaled_latent = x_0_pred_out / self.vae.config.scaling_factor + with torch.autocast("cuda", dtype=self.dtype): + output_latent = self.vae.decode(scaled_latent, return_dict=False)[0] + return output_latent def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: - prev_latent_batch = self.x_t_latent_buffer - - # LCM supports our denoising-batch trick. 
TCD must use standard scheduler.step() sequentially - # but now properly processes ControlNet hooks through unet_step() - if self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): - t_list = self.sub_timesteps_tensor - if self.denoising_steps_num > 1: - # Copy into pre-allocated buffer: eliminates 1 malloc + copy kernel vs torch.cat - self._combined_latent_buf[:self.frame_bff_size].copy_(x_t_latent) - self._combined_latent_buf[self.frame_bff_size:].copy_(prev_latent_batch) - x_t_latent = self._combined_latent_buf - self.stock_noise = torch.cat((self.init_noise[0:1], self.stock_noise[:-1]), dim=0) - x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) - - if self.denoising_steps_num > 1: - x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) - if self.do_add_noise: - self.x_t_latent_buffer = ( - self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] - + self.beta_prod_t_sqrt[1:] * self.init_noise[1:] - ) + with profiler.region("predict_x0_batch"): + prev_latent_batch = self.x_t_latent_buffer + + # LCM supports our denoising-batch trick. TCD must use standard scheduler.step() sequentially + # but now properly processes ControlNet hooks through unet_step() + if self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler): + t_list = self.sub_timesteps_tensor + if self.denoising_steps_num > 1: + # Copy into pre-allocated buffer: eliminates 1 malloc + copy kernel vs torch.cat + self._combined_latent_buf[: self.frame_bff_size].copy_(x_t_latent) + self._combined_latent_buf[self.frame_bff_size :].copy_(prev_latent_batch) + x_t_latent = self._combined_latent_buf + self.stock_noise = torch.cat((self.init_noise[0:1], self.stock_noise[:-1]), dim=0) + with profiler.region("unet_step"): + x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) + + if self.denoising_steps_num > 1: + x_0_pred_out = x_0_pred_batch[-1].unsqueeze(0) + if self.do_add_noise: + self.x_t_latent_buffer = ( + self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] + + self.beta_prod_t_sqrt[1:] * self.init_noise[1:] + ) + else: + self.x_t_latent_buffer = self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] else: - self.x_t_latent_buffer = self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1] + x_0_pred_out = x_0_pred_batch + self.x_t_latent_buffer = None else: - x_0_pred_out = x_0_pred_batch - self.x_t_latent_buffer = None - else: - # Standard scheduler loop for TCD and non-batched LCM - sample = x_t_latent - for idx, timestep in enumerate(self.sub_timesteps_tensor): - # Ensure timestep tensor on device with correct dtype - if not isinstance(timestep, torch.Tensor): - t = torch.tensor(timestep, device=self.device, dtype=torch.long) - else: - t = timestep.to(self.device) - - # For TCD, use the same UNet calling logic as LCM to ensure ControlNet hooks are processed - if isinstance(self.scheduler, TCDScheduler): - # Use unet_step to process ControlNet hooks and get proper noise prediction - t_expanded = t.view( - 1, - ).repeat( - self.frame_bff_size, - ) - x_0_pred, model_pred = self.unet_step(sample, t_expanded, idx) - - # Apply TCD scheduler step to the guided noise prediction - step_out = self.scheduler.step(model_pred, t, sample) - sample = getattr( - step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out - ) - else: - # Original LCM logic for non-batched mode - t = t.view( - 1, - ).repeat( - self.frame_bff_size, - ) - x_0_pred, model_pred = self.unet_step(sample, t, idx) - if idx < len(self.sub_timesteps_tensor) - 1: - if self.do_add_noise: - if self._noise_buf is None: - 
self._noise_buf = torch.empty_like(x_0_pred) - self._noise_buf.normal_() - sample = ( - self.alpha_prod_t_sqrt[idx + 1] * x_0_pred - + self.beta_prod_t_sqrt[idx + 1] * self._noise_buf - ) - else: - sample = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + # Standard scheduler loop for TCD and non-batched LCM + sample = x_t_latent + for idx, timestep in enumerate(self.sub_timesteps_tensor): + # Ensure timestep tensor on device with correct dtype + if not isinstance(timestep, torch.Tensor): + t = torch.tensor(timestep, device=self.device, dtype=torch.long) + else: + t = timestep.to(self.device) + + # For TCD, use the same UNet calling logic as LCM to ensure ControlNet hooks are processed + if isinstance(self.scheduler, TCDScheduler): + # Use unet_step to process ControlNet hooks and get proper noise prediction + t_expanded = t.view( + 1, + ).repeat( + self.frame_bff_size, + ) + with profiler.region("unet_step"): + x_0_pred, model_pred = self.unet_step(sample, t_expanded, idx) + + # Apply TCD scheduler step to the guided noise prediction + with profiler.region("scheduler_step"): + step_out = self.scheduler.step(model_pred, t, sample) + sample = getattr( + step_out, "prev_sample", step_out[0] if isinstance(step_out, (tuple, list)) else step_out + ) else: - sample = x_0_pred + # Original LCM logic for non-batched mode + t = t.view( + 1, + ).repeat( + self.frame_bff_size, + ) + with profiler.region("unet_step"): + x_0_pred, model_pred = self.unet_step(sample, t, idx) + if idx < len(self.sub_timesteps_tensor) - 1: + if self.do_add_noise: + if self._noise_buf is None: + self._noise_buf = torch.empty_like(x_0_pred) + self._noise_buf.normal_() + sample = ( + self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + + self.beta_prod_t_sqrt[idx + 1] * self._noise_buf + ) + else: + sample = self.alpha_prod_t_sqrt[idx + 1] * x_0_pred + else: + sample = x_0_pred - x_0_pred_out = sample - return x_0_pred_out + x_0_pred_out = sample + return x_0_pred_out @torch.inference_mode() def __call__(self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None) -> torch.Tensor: - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) + start = self._timing_start + end = self._timing_end start.record() if x is not None: - x = self.image_processor.preprocess(x, self.height, self.width).to(device=self.device, dtype=self.dtype) + # Fast path: already a normalized GPU tensor with the right shape/dtype. + # Skips image_processor.preprocess (and its per-frame image.min() GPU scan). + # PIL/numpy inputs still take the full preprocessing path below. 
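+ # Caller-side sketch of the fast path (illustrative; assumes the producer can match the stream):
+ #   frame = frame.to("cuda", dtype=stream.dtype)  # NCHW at (self.height, self.width)
+ #   frame = frame * 2.0 - 1.0                     # [0, 1] -> [-1, 1], the range preprocess() emits
+ #   out = stream(frame)                           # passes the check below; no preprocess, no staging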
+ if not ( + isinstance(x, torch.Tensor) + and x.is_cuda + and x.dtype == self.dtype + and x.shape[-2:] == (self.height, self.width) + ): + _raw = self.image_processor.preprocess(x, self.height, self.width) + if not _raw.is_cuda and self._input_staging is not None and _raw.shape == self._input_staging.shape: + self._input_staging.copy_(_raw) # CPU fp32→dtype cast into pinned memory + x = self._input_staging.to(device=self.device, non_blocking=True) + else: + x = _raw.to(device=self.device, dtype=self.dtype, non_blocking=True) # IMAGE PREPROCESSING HOOKS: After built-in preprocessing, before filtering x = self._apply_image_preprocessing_hooks(x) @@ -1182,7 +1228,9 @@ def txt2img_sd_turbo(self, batch_size: int = 1) -> torch.Tensor: return_dict=False, ) - x_0_pred_out = (x_t_latent - self.beta_prod_t_sqrt * model_pred) / self.alpha_prod_t_sqrt + x_0_pred_out = ((x_t_latent - self.beta_prod_t_sqrt * model_pred).float() / self.alpha_prod_t_sqrt.float()).to( + x_t_latent.dtype + ) # LATENT POSTPROCESSING HOOKS: After diffusion, before VAE decoding x_0_pred_out = self._apply_latent_postprocessing_hooks(x_0_pred_out) diff --git a/src/streamdiffusion/preprocessing/__init__.py b/src/streamdiffusion/preprocessing/__init__.py index 4228ee69..c52a8a2e 100644 --- a/src/streamdiffusion/preprocessing/__init__.py +++ b/src/streamdiffusion/preprocessing/__init__.py @@ -1,13 +1,14 @@ -from .preprocessing_orchestrator import PreprocessingOrchestrator -from .postprocessing_orchestrator import PostprocessingOrchestrator -from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator from .base_orchestrator import BaseOrchestrator from .orchestrator_user import OrchestratorUser +from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator +from .postprocessing_orchestrator import PostprocessingOrchestrator +from .preprocessing_orchestrator import PreprocessingOrchestrator + __all__ = [ "PreprocessingOrchestrator", "PostprocessingOrchestrator", "PipelinePreprocessingOrchestrator", "BaseOrchestrator", - "OrchestratorUser" + "OrchestratorUser", ] diff --git a/src/streamdiffusion/preprocessing/base_orchestrator.py b/src/streamdiffusion/preprocessing/base_orchestrator.py index d6d86bf2..e5f6c1b2 100644 --- a/src/streamdiffusion/preprocessing/base_orchestrator.py +++ b/src/streamdiffusion/preprocessing/base_orchestrator.py @@ -1,144 +1,148 @@ -import torch -from typing import List, Optional, Union, Dict, Any, Tuple, Callable, TypeVar, Generic -from abc import ABC, abstractmethod -import numpy as np import concurrent.futures import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, TypeVar + +import torch + logger = logging.getLogger(__name__) # Type variables for generic orchestrator -T = TypeVar('T') # Input type (e.g., ControlImage for preprocessing) -R = TypeVar('R') # Result type (e.g., List[torch.Tensor] for preprocessing) +T = TypeVar("T") # Input type (e.g., ControlImage for preprocessing) +R = TypeVar("R") # Result type (e.g., List[torch.Tensor] for preprocessing) class BaseOrchestrator(Generic[T, R], ABC): """ Generic base orchestrator for parallelized and pipelined processing. - + Handles thread pool management, pipeline state, and inter-frame pipelining while leaving domain-specific processing logic to subclasses. 
- + Type Parameters: T: Input type for processing operations R: Result type returned from processing operations """ - - def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, timeout_ms: float = 10.0, pipeline_ref: Optional[Any] = None): + + def __init__( + self, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + max_workers: int = 4, + timeout_ms: float = 10.0, + pipeline_ref: Optional[Any] = None, + ): self.device = device self.dtype = dtype self.timeout_ms = timeout_ms self.pipeline_ref = pipeline_ref self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) - + # Pipeline state for pipelined processing self._next_frame_future = None self._next_frame_result = None - + # CUDA stream for background processing to avoid GPU contention self._background_stream = None device_str = str(device) if device_str.startswith("cuda") and torch.cuda.is_available(): self._background_stream = torch.cuda.Stream() - - def cleanup(self) -> None: """Cleanup thread pool and CUDA stream resources""" - if hasattr(self, '_executor'): + if hasattr(self, "_executor"): self._executor.shutdown(wait=True) - + # Cleanup CUDA stream if it exists - if hasattr(self, '_background_stream') and self._background_stream is not None: + if hasattr(self, "_background_stream") and self._background_stream is not None: # Synchronize the stream before cleanup torch.cuda.synchronize() self._background_stream = None - + def __del__(self): """Cleanup on destruction""" try: self.cleanup() except: pass - + @abstractmethod def _should_use_sync_processing(self, *args, **kwargs) -> bool: """ Determine if synchronous processing should be used instead of pipelined. - + Subclasses implement domain-specific logic (e.g., feedback preprocessor detection). - + Returns: True if sync processing should be used, False for pipelined processing """ pass - + @abstractmethod def _process_frame_background(self, *args, **kwargs) -> Dict[str, Any]: """ Process a frame in the background thread. - + Subclasses implement their specific processing logic here. - + Returns: Dictionary containing processing results and status """ pass - + def process_pipelined(self, input_data: T, *args, **kwargs) -> R: """ Process input with intelligent pipelining. - + Automatically falls back to sync processing when required by domain logic, otherwise uses pipelined processing for performance. - + Args: input_data: Input data to process *args, **kwargs: Additional arguments passed to processing methods - + Returns: Processing results """ # Check if sync processing is required (domain-specific logic) if self._should_use_sync_processing(*args, **kwargs): return self.process_sync(input_data, *args, **kwargs) - + # Use pipelined processing # Wait for previous frame processing; non-blocking with short timeout self._wait_for_previous_processing() - + # Start next frame processing in background self._start_next_frame_processing(input_data, *args, **kwargs) - + # Apply current frame processing results if available; otherwise signal no update return self._apply_current_frame_processing(*args, **kwargs) - + @abstractmethod def process_sync(self, input_data: T, *args, **kwargs) -> R: """ Process input synchronously. - + Subclasses implement their specific synchronous processing logic. 
- + Args: input_data: Input data to process *args, **kwargs: Additional arguments passed to processing methods - + Returns: Processing results """ pass - + def _start_next_frame_processing(self, input_data: T, *args, **kwargs) -> None: """Start processing for next frame in background thread""" # Submit background processing - self._next_frame_future = self._executor.submit( - self._process_frame_background, input_data, *args, **kwargs - ) - + self._next_frame_future = self._executor.submit(self._process_frame_background, input_data, *args, **kwargs) + def _wait_for_previous_processing(self) -> None: """Wait for previous frame processing with configurable timeout""" - if hasattr(self, '_next_frame_future') and self._next_frame_future is not None: + if hasattr(self, "_next_frame_future") and self._next_frame_future is not None: try: # Use configurable timeout based on orchestrator type self._next_frame_result = self._next_frame_future.result(timeout=self.timeout_ms / 1000.0) @@ -150,52 +154,52 @@ def _wait_for_previous_processing(self) -> None: self._next_frame_result = None else: self._next_frame_result = None - + def _apply_current_frame_processing(self, processors=None, *args, **kwargs) -> R: """ Apply processing results from previous iteration. - + Default implementation provides common fallback logic for tensor-to-tensor orchestrators. Subclasses can override this method for specialized behavior. - + Args: processors: List of processors/postprocessors to apply (parameter name varies by subclass) *args, **kwargs: Additional arguments - + Returns: Processing results, or processed current input if no results available """ - if not hasattr(self, '_next_frame_result') or self._next_frame_result is None: + if not hasattr(self, "_next_frame_result") or self._next_frame_result is None: # First frame or no background results - process current input synchronously - if hasattr(self, '_current_input_tensor') and self._current_input_tensor is not None: + if hasattr(self, "_current_input_tensor") and self._current_input_tensor is not None: if processors: return self.process_sync(self._current_input_tensor, processors) else: return self._current_input_tensor - + # If we don't have current input stored, we have an issue class_name = self.__class__.__name__ logger.error(f"{class_name}: No background results and no current input tensor available") raise RuntimeError(f"{class_name}: No processing results available") - + result = self._next_frame_result - if result['status'] != 'success': + if result["status"] != "success": class_name = self.__class__.__name__ logger.warning(f"{class_name}: Background processing failed: {result.get('error', 'Unknown error')}") # Process current input synchronously on error - if hasattr(self, '_current_input_tensor') and self._current_input_tensor is not None: + if hasattr(self, "_current_input_tensor") and self._current_input_tensor is not None: if processors: return self.process_sync(self._current_input_tensor, processors) else: return self._current_input_tensor raise RuntimeError(f"{class_name}: Background processing failed and no fallback available") - - return result['result'] - + + return result["result"] + def _set_background_stream_context(self): """ Set CUDA stream context for background processing. 
- + Returns: The original stream to restore later, or None if no background stream """ @@ -204,11 +208,11 @@ def _set_background_stream_context(self): torch.cuda.set_stream(self._background_stream) return original_stream return None - + def _restore_stream_context(self, original_stream): """ Restore the original CUDA stream context. - + Args: original_stream: The stream to restore, or None to do nothing """ diff --git a/src/streamdiffusion/preprocessing/orchestrator_user.py b/src/streamdiffusion/preprocessing/orchestrator_user.py index 2503c14e..e731540b 100644 --- a/src/streamdiffusion/preprocessing/orchestrator_user.py +++ b/src/streamdiffusion/preprocessing/orchestrator_user.py @@ -2,9 +2,9 @@ from typing import Optional -from .preprocessing_orchestrator import PreprocessingOrchestrator -from .postprocessing_orchestrator import PostprocessingOrchestrator from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator +from .postprocessing_orchestrator import PostprocessingOrchestrator +from .preprocessing_orchestrator import PreprocessingOrchestrator class OrchestratorUser: @@ -20,32 +20,36 @@ class OrchestratorUser: def attach_orchestrator(self, stream) -> None: """Attach preprocessing orchestrator (backward compatibility).""" self.attach_preprocessing_orchestrator(stream) - + def attach_preprocessing_orchestrator(self, stream) -> None: """Attach shared preprocessing orchestrator from stream.""" - orchestrator = getattr(stream, 'preprocessing_orchestrator', None) + orchestrator = getattr(stream, "preprocessing_orchestrator", None) if orchestrator is None: # Lazy-create on stream once, on first user that needs it - orchestrator = PreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream) - setattr(stream, 'preprocessing_orchestrator', orchestrator) + orchestrator = PreprocessingOrchestrator( + device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream + ) + setattr(stream, "preprocessing_orchestrator", orchestrator) self._preprocessing_orchestrator = orchestrator - + def attach_postprocessing_orchestrator(self, stream) -> None: """Attach shared postprocessing orchestrator from stream.""" - orchestrator = getattr(stream, 'postprocessing_orchestrator', None) + orchestrator = getattr(stream, "postprocessing_orchestrator", None) if orchestrator is None: # Lazy-create on stream once, on first user that needs it - orchestrator = PostprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream) - setattr(stream, 'postprocessing_orchestrator', orchestrator) + orchestrator = PostprocessingOrchestrator( + device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream + ) + setattr(stream, "postprocessing_orchestrator", orchestrator) self._postprocessing_orchestrator = orchestrator - + def attach_pipeline_preprocessing_orchestrator(self, stream) -> None: """Attach shared pipeline preprocessing orchestrator from stream.""" - orchestrator = getattr(stream, 'pipeline_preprocessing_orchestrator', None) + orchestrator = getattr(stream, "pipeline_preprocessing_orchestrator", None) if orchestrator is None: # Lazy-create on stream once, on first user that needs it - orchestrator = PipelinePreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream) - setattr(stream, 'pipeline_preprocessing_orchestrator', orchestrator) + orchestrator = PipelinePreprocessingOrchestrator( + device=stream.device, dtype=stream.dtype, max_workers=4, 
pipeline_ref=stream + ) + setattr(stream, "pipeline_preprocessing_orchestrator", orchestrator) self._pipeline_preprocessing_orchestrator = orchestrator - - diff --git a/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py index 8cf4e717..382e874f 100644 --- a/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py @@ -1,32 +1,42 @@ -import torch -from typing import List, Dict, Any, Optional import logging +from typing import Any, Dict, List, Optional + +import torch + from .base_orchestrator import BaseOrchestrator + logger = logging.getLogger(__name__) + class PipelinePreprocessingOrchestrator(BaseOrchestrator[torch.Tensor, torch.Tensor]): """ Orchestrates pipeline input preprocessing with parallelization and pipelining. - + Handles preprocessing of input tensors before they enter the diffusion pipeline. - + Tensor ranges: - Input: Receives [-1, 1] tensors from image_processor.preprocess() - Processors: Work in [-1, 1] space when normalization_context='pipeline' - Output: Returns [-1, 1] tensors for pipeline processing - + Note: Processors created with normalization_context='pipeline' expect and preserve [-1, 1] range. No automatic conversion happens in this orchestrator. """ - - def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, pipeline_ref: Optional[Any] = None): + + def __init__( + self, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + max_workers: int = 4, + pipeline_ref: Optional[Any] = None, + ): # Pipeline preprocessing: 10ms timeout for responsive processing super().__init__(device, dtype, max_workers, timeout_ms=10.0, pipeline_ref=pipeline_ref) - + # Pipeline preprocessing specific state self._current_input_tensor = None # For BaseOrchestrator fallback logic - + def _should_use_sync_processing(self, *args, **kwargs) -> bool: """ Determine if synchronous processing should be used instead of pipelined. @@ -44,123 +54,102 @@ def _should_use_sync_processing(self, *args, **kwargs) -> bool: if not processors: return False for proc in processors: - if proc is not None and getattr(proc, 'requires_sync_processing', False): + if proc is not None and getattr(proc, "requires_sync_processing", False): return True return False - - def process_pipelined(self, - input_tensor: torch.Tensor, - processors: List[Any], - *args, **kwargs) -> torch.Tensor: + + def process_pipelined(self, input_tensor: torch.Tensor, processors: List[Any], *args, **kwargs) -> torch.Tensor: """ Process input with intelligent pipelining. - + Overrides base method to store current input tensor for fallback logic. 
""" # Store current input for fallback logic self._current_input_tensor = input_tensor - + # RACE CONDITION FIX: Check if there are actually enabled processors # Filter to only enabled processors (same logic as _get_ordered_processors) - enabled_processors = [p for p in processors if getattr(p, 'enabled', True)] if processors else [] - + enabled_processors = [p for p in processors if getattr(p, "enabled", True)] if processors else [] + if not enabled_processors: return input_tensor - + # Call parent implementation return super().process_pipelined(input_tensor, processors, *args, **kwargs) - - def process_sync(self, - input_tensor: torch.Tensor, - processors: List[Any]) -> torch.Tensor: + + def process_sync(self, input_tensor: torch.Tensor, processors: List[Any]) -> torch.Tensor: """ Process pipeline input tensor synchronously through preprocessors. - + Implementation of BaseOrchestrator.process_sync for pipeline preprocessing. - + Args: input_tensor: Input tensor to preprocess (already normalized) processors: List of preprocessor instances - + Returns: Preprocessed tensor ready for pipeline processing """ if not processors: return input_tensor - + # Sequential application of processors current_tensor = input_tensor for processor in processors: if processor is not None: current_tensor = self._apply_single_processor(current_tensor, processor) - + return current_tensor - - def _process_frame_background(self, - input_tensor: torch.Tensor, - processors: List[Any]) -> Dict[str, Any]: + + def _process_frame_background(self, input_tensor: torch.Tensor, processors: List[Any]) -> Dict[str, Any]: """ Process a frame in the background thread. - + Implementation of BaseOrchestrator._process_frame_background for pipeline preprocessing. - + Returns: Dictionary containing processing results and status """ try: # Set CUDA stream for background processing original_stream = self._set_background_stream_context() - + if not processors: - return { - 'result': input_tensor, - 'status': 'success' - } - + return {"result": input_tensor, "status": "success"} + # Process processors sequentially (most pipeline preprocessing is dependent) current_tensor = input_tensor for processor in processors: if processor is not None: current_tensor = self._apply_single_processor(current_tensor, processor) - - return { - 'result': current_tensor, - 'status': 'success' - } - + + return {"result": current_tensor, "status": "success"} + except Exception as e: logger.error(f"PipelinePreprocessingOrchestrator: Background processing failed: {e}") # Return original input tensor on error - return { - 'result': input_tensor, - 'error': str(e), - 'status': 'error' - } + return {"result": input_tensor, "error": str(e), "status": "error"} finally: # Restore original CUDA stream self._restore_stream_context(original_stream) - - - - def _apply_single_processor(self, - input_tensor: torch.Tensor, - processor: Any) -> torch.Tensor: + + def _apply_single_processor(self, input_tensor: torch.Tensor, processor: Any) -> torch.Tensor: """ Apply a single processor to the input tensor. 
- + Args: input_tensor: Input tensor to process processor: Processor instance - + Returns: Processed tensor """ try: # Apply processor - if hasattr(processor, 'process_tensor'): + if hasattr(processor, "process_tensor"): # Prefer tensor processing method result = processor.process_tensor(input_tensor) - elif hasattr(processor, 'process'): + elif hasattr(processor, "process"): # Use general process method result = processor.process(input_tensor) elif callable(processor): @@ -169,18 +158,18 @@ def _apply_single_processor(self, else: logger.warning(f"PipelinePreprocessingOrchestrator: Unknown processor type: {type(processor)}") return input_tensor - + # Ensure result is a tensor if isinstance(result, torch.Tensor): return result else: logger.warning(f"PipelinePreprocessingOrchestrator: Processor returned non-tensor: {type(result)}") return input_tensor - + except Exception as e: logger.error(f"PipelinePreprocessingOrchestrator: Processor failed: {e}") return input_tensor # Return original on error - + def clear_cache(self) -> None: """Clear preprocessing cache""" pass diff --git a/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py index 742a80e6..ef5dceb3 100644 --- a/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py @@ -73,7 +73,7 @@ def _should_use_sync_processing(self, *args, **kwargs) -> bool: if not processors: return False for proc in processors: - if proc is not None and getattr(proc, 'requires_sync_processing', False): + if proc is not None and getattr(proc, "requires_sync_processing", False): return True return False diff --git a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py index 96c247c4..fd554ac6 100644 --- a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py @@ -1,13 +1,15 @@ -import torch -from typing import List, Optional, Union, Dict, Any, Tuple, Callable -from PIL import Image -import numpy as np -import concurrent.futures import logging -from diffusers.utils import load_image +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch import torchvision.transforms as transforms +from diffusers.utils import load_image +from PIL import Image + from .base_orchestrator import BaseOrchestrator + logger = logging.getLogger(__name__) # Type alias for control image input @@ -17,40 +19,45 @@ class PreprocessingOrchestrator(BaseOrchestrator[ControlImage, List[Optional[torch.Tensor]]]): """ Orchestrates module preprocessing with typical orchestrator pipelining, but with additional intraframe parallelization, caching, and optimization. - Modules (IPAdapter, Controlnet) share intraframe parallelism. + Modules (IPAdapter, Controlnet) share intraframe parallelism. Handles image format conversion (while most are GPU native, some preprocessors are CPU only), preprocessor execution, and result caching.
""" - - def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, pipeline_ref: Optional[Any] = None): + + def __init__( + self, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + max_workers: int = 4, + pipeline_ref: Optional[Any] = None, + ): # Preprocessing: 10ms timeout for fast frame-skipping behavior super().__init__(device, dtype, max_workers, timeout_ms=10.0, pipeline_ref=pipeline_ref) - + # Caching self._preprocessed_cache: Dict[str, torch.Tensor] = {} self._last_input_frame = None - + # Optimized transforms self._cached_transform = transforms.ToTensor() - + # Cache pipelining decision to avoid hot path checks self._preprocessors_cache_key = None self._has_feedback_cache = False - - - - - #Abstract method implementations - def process_sync(self, - control_image: ControlImage, - preprocessors: List[Optional[Any]], - scales: List[float] = None, - stream_width: int = None, - stream_height: int = None, - index: Optional[int] = None, - processing_type: str = "controlnet") -> Union[List[Optional[torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor]]]: + + # Abstract method implementations + def process_sync( + self, + control_image: ControlImage, + preprocessors: List[Optional[Any]], + scales: List[float] = None, + stream_width: int = None, + stream_height: int = None, + index: Optional[int] = None, + processing_type: str = "controlnet", + ) -> Union[List[Optional[torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor]]]: """ Process images synchronously for ControlNet or IPAdapter preprocessing. - + Args: control_image: Input image to process preprocessors: List of preprocessor instances @@ -59,7 +66,7 @@ def process_sync(self, stream_height: Target height for processing index: If specified, only process this single ControlNet index (ControlNet only) processing_type: "controlnet" or "ipadapter" to specify processing mode - + Returns: ControlNet: List of processed tensors for each ControlNet IPAdapter: List of (positive_embeds, negative_embeds) tuples @@ -80,410 +87,390 @@ def process_sync(self, ) else: raise ValueError(f"Invalid processing_type: {processing_type}. Must be 'controlnet' or 'ipadapter'") - + def _should_use_sync_processing(self, *args, **kwargs) -> bool: """ Check for pipeline-aware preprocessors that require sync processing. - - Pipeline-aware preprocessors (feedback, temporal, etc.) need synchronous processing + + Pipeline-aware preprocessors (feedback, temporal, etc.) need synchronous processing to avoid temporal artifacts and ensure access to previous pipeline outputs. - + Args: *args: Arguments from process_pipelined call (preprocessors, scales, stream_width, stream_height) **kwargs: Keyword arguments - + Returns: True if pipeline-aware preprocessors detected, False otherwise """ # Extract preprocessors from args - they're the first argument after control_image if len(args) < 1: return False - + preprocessors = args[0] # preprocessors is first arg after control_image return self._check_pipeline_aware_cached(preprocessors) - def _process_frame_background(self, - control_image: ControlImage, - *args, **kwargs) -> Dict[str, Any]: + def _process_frame_background(self, control_image: ControlImage, *args, **kwargs) -> Dict[str, Any]: """ Process a frame in the background thread. - + Implementation of BaseOrchestrator._process_frame_background for ControlNet preprocessing. Automatically detects processing mode based on current state. 
- + Returns: Dictionary containing processing results and status """ try: # Set CUDA stream for background processing original_stream = self._set_background_stream_context() - + # Check if last argument is "ipadapter" processing type if args and len(args) >= 5 and args[4] == "ipadapter": # Handle embedding preprocessing embedding_preprocessors = args[0] - stream_width = args[2] + stream_width = args[2] stream_height = args[3] - + # Prepare processing data control_variants = self._prepare_input_variants(control_image, thread_safe=True) - + # Process using existing IPAdapter logic try: results = self._process_ipadapter_preprocessors_parallel( embedding_preprocessors, control_variants, stream_width, stream_height ) - return { - 'results': results, - 'status': 'success' - } + return {"results": results, "status": "success"} except Exception as e: import traceback + traceback.print_exc() - return { - 'error': str(e), - 'status': 'error' - } - elif hasattr(self, '_current_processing_mode') and self._current_processing_mode == "embedding": + return {"error": str(e), "status": "error"} + elif hasattr(self, "_current_processing_mode") and self._current_processing_mode == "embedding": # Handle embedding preprocessing (legacy path) embedding_preprocessors = args[0] - stream_width = args[2] + stream_width = args[2] stream_height = args[3] - + # Prepare processing data control_variants = self._prepare_input_variants(control_image, thread_safe=True) - + # Process using existing IPAdapter logic try: results = self._process_ipadapter_preprocessors_parallel( embedding_preprocessors, control_variants, stream_width, stream_height ) - return { - 'results': results, - 'status': 'success' - } + return {"results": results, "status": "success"} except Exception as e: import traceback + traceback.print_exc() - return { - 'error': str(e), - 'status': 'error' - } + return {"error": str(e), "status": "error"} else: # Handle ControlNet preprocessing (default mode) preprocessors = args[0] scales = args[1] stream_width = args[2] stream_height = args[3] - + # Check if any processing is needed if not any(scale > 0 for scale in scales): - return {'status': 'success', 'results': [None] * len(preprocessors)} - #TODO: can we reuse similarity filter here? - if (self._last_input_frame is not None and - isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) and - control_image is self._last_input_frame): - return {'status': 'success', 'results': []} # Signal no update needed - + return {"status": "success", "results": [None] * len(preprocessors)} + # TODO: can we reuse similarity filter here? 
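+ # The guard below is an identity (`is`) check, not value equality (illustrative):
+ #   orch.process_pipelined(img, ...)          # processed normally
+ #   orch.process_pipelined(img, ...)          # same object again -> "no update" signal
+ #   orch.process_pipelined(img.clone(), ...)  # equal pixels, new object -> reprocessed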
+ if ( + self._last_input_frame is not None + and isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) + and control_image is self._last_input_frame + ): + return {"status": "success", "results": []} # Signal no update needed + self._last_input_frame = control_image - + # Prepare processing data preprocessor_groups = self._group_preprocessors(preprocessors, scales) active_indices = [i for i, scale in enumerate(scales) if scale > 0] - + if not active_indices: - return {'status': 'success', 'results': [None] * len(preprocessors)} - + return {"status": "success", "results": [None] * len(preprocessors)} + # Optimize input preparation control_variants = self._prepare_input_variants(control_image, thread_safe=True) - + # Process using unified parallel logic processed_images = self._process_controlnet_preprocessors_parallel( preprocessor_groups, control_variants, stream_width, stream_height, preprocessors ) - - return { - 'results': processed_images, - 'status': 'success' - } - + + return {"results": processed_images, "status": "success"} + except Exception as e: logger.error(f"PreprocessingOrchestrator: Background processing failed: {e}") - return { - 'error': str(e), - 'status': 'error' - } + return {"error": str(e), "status": "error"} finally: # Restore original CUDA stream self._restore_stream_context(original_stream) - - def _apply_current_frame_processing(self, - preprocessors: List[Optional[Any]] = None, - scales: List[float] = None, - *args, **kwargs) -> List[Optional[torch.Tensor]]: + + def _apply_current_frame_processing( + self, preprocessors: List[Optional[Any]] = None, scales: List[float] = None, *args, **kwargs + ) -> List[Optional[torch.Tensor]]: """ Apply processing results from previous iteration. - + Overrides BaseOrchestrator._apply_current_frame_processing for module preprocessing. - + Returns: List of processed tensors, or empty list to signal no update needed """ - if not hasattr(self, '_next_frame_result') or self._next_frame_result is None: + if not hasattr(self, "_next_frame_result") or self._next_frame_result is None: # Return empty list to signal no update needed return [] - + # Handle case where preprocessors is None if preprocessors is None: return [] - + processed_images = [None] * len(preprocessors) - + result = self._next_frame_result - if result['status'] != 'success': + if result["status"] != "success": # Return empty list to signal no update needed on error return [] - + # Handle case where no update is needed (cached input) - if 'results' in result and len(result['results']) == 0: + if "results" in result and len(result["results"]) == 0: return [] - + # Get the processed results directly - processed_images = result.get('results', []) + processed_images = result.get("results", []) if not processed_images: return [] - + return processed_images - - #Controlnet methods - def prepare_control_image(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - preprocessor: Optional[Any], - target_width: int, - target_height: int) -> torch.Tensor: + + # Controlnet methods + def prepare_control_image( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + preprocessor: Optional[Any], + target_width: int, + target_height: int, + ) -> torch.Tensor: """ Prepare a single control image for ControlNet input with format conversion and preprocessing. 
- + Args: control_image: Input image in various formats preprocessor: Optional preprocessor to apply target_width: Target width for the output tensor target_height: Target height for the output tensor - + Returns: Processed tensor ready for ControlNet """ # Load image if path if isinstance(control_image, str): control_image = load_image(control_image) - + # Fast tensor processing path if isinstance(control_image, torch.Tensor): return self._process_tensor_input(control_image, preprocessor, target_width, target_height) - + # Apply preprocessor to non-tensor inputs if preprocessor is not None: control_image = preprocessor.process(control_image) - + # Convert to tensor return self._convert_to_tensor(control_image, target_width, target_height) - - def _process_multiple_controlnets_sync(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - preprocessors: List[Optional[Any]], - scales: List[float], - stream_width: int, - stream_height: int) -> List[Optional[torch.Tensor]]: + + def _process_multiple_controlnets_sync( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + preprocessors: List[Optional[Any]], + scales: List[float], + stream_width: int, + stream_height: int, + ) -> List[Optional[torch.Tensor]]: """Process multiple ControlNets synchronously with parallel execution""" # Check if any processing is needed if not any(scale > 0 for scale in scales): return [None] * len(preprocessors) - - #TODO: can we reuse similarity filter here? + + # TODO: can we reuse similarity filter here? # Check cache for same input - return early without changing anything - if (self._last_input_frame is not None and - isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) and - control_image is self._last_input_frame): + if ( + self._last_input_frame is not None + and isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) + and control_image is self._last_input_frame + ): # Return empty list to signal no update needed return [] - + self._last_input_frame = control_image self.clear_cache() - + # Prepare input variants for optimal processing control_variants = self._prepare_input_variants(control_image, stream_width, stream_height) - + # Group preprocessors to avoid duplicate work preprocessor_groups = self._group_preprocessors(preprocessors, scales) - + if not preprocessor_groups: return [None] * len(preprocessors) - + # Process groups using parallel logic (efficient for 1 or many items) return self._process_controlnet_preprocessors_parallel( preprocessor_groups, control_variants, stream_width, stream_height, preprocessors ) - - def _process_single_controlnet(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - preprocessors: List[Optional[Any]], - scales: List[float], - stream_width: int, - stream_height: int, - index: int) -> List[Optional[torch.Tensor]]: + + def _process_single_controlnet( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + preprocessors: List[Optional[Any]], + scales: List[float], + stream_width: int, + stream_height: int, + index: int, + ) -> List[Optional[torch.Tensor]]: """Process a single ControlNet by index""" if not (0 <= index < len(preprocessors)): raise IndexError(f"ControlNet index {index} out of range") - + if scales[index] == 0: return [None] * len(preprocessors) - + processed_images = [None] * len(preprocessors) - processed_image = self.prepare_control_image( - control_image, preprocessors[index], stream_width, stream_height - ) + processed_image = 
self.prepare_control_image(control_image, preprocessors[index], stream_width, stream_height) processed_images[index] = processed_image - + return processed_images - - def _process_controlnet_preprocessors_parallel(self, - preprocessor_groups: Dict[str, Dict[str, Any]], - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int, - preprocessors: List[Optional[Any]]) -> List[Optional[torch.Tensor]]: + + def _process_controlnet_preprocessors_parallel( + self, + preprocessor_groups: Dict[str, Dict[str, Any]], + control_variants: Dict[str, Any], + stream_width: int, + stream_height: int, + preprocessors: List[Optional[Any]], + ) -> List[Optional[torch.Tensor]]: """Process ControlNet preprocessor groups in parallel""" futures = [ self._executor.submit( - self._process_single_preprocessor_group, - prep_key, group, control_variants, stream_width, stream_height + self._process_single_preprocessor_group, prep_key, group, control_variants, stream_width, stream_height ) for prep_key, group in preprocessor_groups.items() ] - + processed_images = [None] * len(preprocessors) - + for future in futures: result = future.result() - if result and result['processed_image'] is not None: - prep_key = result['prep_key'] - processed_image = result['processed_image'] - indices = result['indices'] - + if result and result["processed_image"] is not None: + prep_key = result["prep_key"] + processed_image = result["processed_image"] + indices = result["indices"] + # Cache and assign cache_key = f"prep_{prep_key}" self._preprocessed_cache[cache_key] = processed_image for index in indices: processed_images[index] = processed_image - + return processed_images - - #IPAdapter methods - def _process_multiple_ipadapters_sync(self, - control_image: ControlImage, - preprocessors: List[Optional[Any]], - stream_width: int, - stream_height: int) -> List[Tuple[torch.Tensor, torch.Tensor]]: + + # IPAdapter methods + def _process_multiple_ipadapters_sync( + self, control_image: ControlImage, preprocessors: List[Optional[Any]], stream_width: int, stream_height: int + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """ Process IPAdapter preprocessors synchronously. - + This is the implementation that was previously in process_ipadapter_preprocessors(). 
""" if not preprocessors: return [] - + # For IPAdapter preprocessing, we don't skip on cache hits - we need the actual embeddings # (Unlike spatial preprocessing where empty list means "no update needed") - + # Prepare input variants for processing control_variants = self._prepare_input_variants(control_image, stream_width, stream_height) - + # Process using parallel logic (efficient for 1 or many items) results = self._process_ipadapter_preprocessors_parallel( preprocessors, control_variants, stream_width, stream_height ) - + return results - - def _process_ipadapter_preprocessors_parallel(self, - ipadapter_preprocessors: List[Any], - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int) -> List[Tuple[torch.Tensor, torch.Tensor]]: + + def _process_ipadapter_preprocessors_parallel( + self, + ipadapter_preprocessors: List[Any], + control_variants: Dict[str, Any], + stream_width: int, + stream_height: int, + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Process multiple IPAdapter preprocessors in parallel""" futures = [ self._executor.submit( - self._process_single_ipadapter, - i, preprocessor, control_variants, stream_width, stream_height + self._process_single_ipadapter, i, preprocessor, control_variants, stream_width, stream_height ) for i, preprocessor in enumerate(ipadapter_preprocessors) ] - + results = [None] * len(ipadapter_preprocessors) - + for future in futures: result = future.result() - if result and result['embeddings'] is not None: - index = result['index'] - embeddings = result['embeddings'] + if result and result["embeddings"] is not None: + index = result["index"] + embeddings = result["embeddings"] results[index] = embeddings - + return results - - def _process_single_ipadapter(self, - index: int, - preprocessor: Any, - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int) -> Optional[Dict[str, Any]]: + + def _process_single_ipadapter( + self, index: int, preprocessor: Any, control_variants: Dict[str, Any], stream_width: int, stream_height: int + ) -> Optional[Dict[str, Any]]: """Process a single IPAdapter preprocessor""" try: # Use tensor processing if available and input is tensor - if (hasattr(preprocessor, 'process_tensor') and - control_variants['tensor'] is not None): - embeddings = preprocessor.process_tensor(control_variants['tensor']) - return { - 'index': index, - 'embeddings': embeddings - } - + if hasattr(preprocessor, "process_tensor") and control_variants["tensor"] is not None: + embeddings = preprocessor.process_tensor(control_variants["tensor"]) + return {"index": index, "embeddings": embeddings} + # Use PIL processing for non-tensor inputs - if control_variants['image'] is not None: - embeddings = preprocessor.process(control_variants['image']) - return { - 'index': index, - 'embeddings': embeddings - } - + if control_variants["image"] is not None: + embeddings = preprocessor.process(control_variants["image"]) + return {"index": index, "embeddings": embeddings} + return None - - except Exception as e: + + except Exception: import traceback + traceback.print_exc() return None - #Helper methods + # Helper methods def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bool: """ Efficiently check for pipeline-aware preprocessors using caching - + Only performs expensive isinstance checks when preprocessor list actually changes. 
""" # Create cache key from preprocessor identities cache_key = tuple(id(p) for p in preprocessors) - + # Return cached result if preprocessors haven't changed if cache_key == self._preprocessors_cache_key: return self._has_feedback_cache # Reuse cache variable for backward compatibility - + # Preprocessors changed - recompute and cache self._preprocessors_cache_key = cache_key self._has_feedback_cache = False - + try: # Check for the mixin or class attribute first for prep in preprocessors: - if prep is not None and getattr(prep, 'requires_sync_processing', False): + if prep is not None and getattr(prep, "requires_sync_processing", False): self._has_feedback_cache = True break except Exception: @@ -491,6 +478,7 @@ def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bo try: from .processors.feedback import FeedbackPreprocessor from .processors.temporal_net import TemporalNetPreprocessor + for prep in preprocessors: if isinstance(prep, (FeedbackPreprocessor, TemporalNetPreprocessor)): self._has_feedback_cache = True @@ -500,44 +488,43 @@ def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bo for prep in preprocessors: if prep is not None: class_name = prep.__class__.__name__.lower() - if any(name in class_name for name in ['feedback', 'temporal']): + if any(name in class_name for name in ["feedback", "temporal"]): self._has_feedback_cache = True break - + return self._has_feedback_cache def clear_cache(self) -> None: """Clear preprocessing cache""" self._preprocessed_cache.clear() self._last_input_frame = None - + # ========================================================================= # Pipeline Chain Processing Methods (For Hook System Compatibility) # ========================================================================= - - def execute_pipeline_chain(self, - input_data: torch.Tensor, - processors: List[Any], - processing_domain: str = "image") -> torch.Tensor: + + def execute_pipeline_chain( + self, input_data: torch.Tensor, processors: List[Any], processing_domain: str = "image" + ) -> torch.Tensor: """Execute ordered sequential chain of processors for pipeline hooks. - + This method provides compatibility with the hook system modules that expect sequential processor execution rather than pipelined processing. - + Args: input_data: Input tensor (image or latent domain) processors: List of processor instances to execute in sequence processing_domain: "image" or "latent" to determine processing path - + Returns: Processed tensor in same domain as input """ if not processors: return input_data - + result = input_data ordered_processors = self._order_processors(processors) - + for processor in ordered_processors: try: if processing_domain == "image": @@ -550,58 +537,58 @@ def execute_pipeline_chain(self, logger.error(f"execute_pipeline_chain: Processor {type(processor).__name__} failed: {e}") # Continue with next processor rather than failing entire chain continue - + return result - + def _order_processors(self, processors: List[Any]) -> List[Any]: """Order processors based on their configuration. - + Processors can define an 'order' attribute to control execution sequence. """ - return sorted(processors, key=lambda p: getattr(p, 'order', 0)) - + return sorted(processors, key=lambda p: getattr(p, "order", 0)) + def _process_image_processor_chain(self, image_tensor: torch.Tensor, processor: Any) -> torch.Tensor: """Process single image processor in chain, handling tensor<->PIL conversion. 
- + Leverages existing format conversion and processing logic. """ # Convert tensor to PIL for processor (reuse existing conversion logic) try: # Use existing tensor to PIL conversion from prepare_control_image logic pil_image = self._tensor_to_pil_safe(image_tensor) - + # Process using existing processor execution pattern - if hasattr(processor, 'process'): + if hasattr(processor, "process"): processed_pil = processor.process(pil_image) else: processed_pil = processor(pil_image) - + # Convert back to tensor (reuse existing PIL to tensor logic) result_tensor = self._pil_to_tensor_safe(processed_pil, image_tensor.device, image_tensor.dtype) return result_tensor - + except Exception as e: logger.error(f"_process_image_processor_chain: Failed processing {type(processor).__name__}: {e}") return image_tensor # Return input unchanged on failure - + def _process_latent_processor_chain(self, latent_tensor: torch.Tensor, processor: Any) -> torch.Tensor: """Process single latent processor in chain. - + Direct tensor processing - no format conversion needed for latent domain. """ try: # Latent processors work directly on tensors - if hasattr(processor, 'process_tensor'): + if hasattr(processor, "process_tensor"): return processor.process_tensor(latent_tensor) - elif hasattr(processor, 'process'): + elif hasattr(processor, "process"): return processor.process(latent_tensor) else: return processor(latent_tensor) - + except Exception as e: logger.error(f"_process_latent_processor_chain: Failed processing {type(processor).__name__}: {e}") return latent_tensor # Return input unchanged on failure - + def _tensor_to_pil_safe(self, tensor: torch.Tensor) -> Image.Image: """Convert tensor to PIL Image safely (reuse existing conversion logic).""" # Leverage existing tensor conversion from prepare_control_image @@ -610,50 +597,48 @@ def _tensor_to_pil_safe(self, tensor: torch.Tensor) -> Image.Image: if tensor.dim() == 3 and tensor.shape[0] == 3: # Convert from CHW to HWC tensor = tensor.permute(1, 2, 0) - + # CRITICAL FIX: Handle VAE output range [-1, 1] -> [0, 1] -> [0, 255] # VAE decode_image() outputs in [-1, 1] range, need to convert to [0, 1] first if tensor.min() < 0: - logger.debug(f"_tensor_to_pil_safe: Converting from VAE range [-1, 1] to [0, 1]") + logger.debug("_tensor_to_pil_safe: Converting from VAE range [-1, 1] to [0, 1]") tensor = (tensor / 2.0 + 0.5).clamp(0, 1) # Convert [-1, 1] -> [0, 1] - + # Ensure proper range [0, 1] -> [0, 255] if tensor.max() <= 1.0: tensor = tensor * 255.0 - + # Convert to numpy and then PIL numpy_image = tensor.detach().cpu().numpy().astype(np.uint8) return Image.fromarray(numpy_image) - + def _pil_to_tensor_safe(self, pil_image: Image.Image, device: str, dtype: torch.dtype) -> torch.Tensor: """Convert PIL Image to tensor safely (reuse existing conversion logic).""" # Convert PIL to numpy numpy_image = np.array(pil_image) - + # Convert to tensor and normalize to [0, 1] tensor = torch.from_numpy(numpy_image).float() / 255.0 - + # Convert HWC to CHW if tensor.dim() == 3: tensor = tensor.permute(2, 0, 1) - + # Add batch dimension and move to device tensor = tensor.unsqueeze(0).to(device=device, dtype=dtype) - + # CRITICAL: Convert back to VAE input range [-1, 1] for postprocessing # VAE expects inputs in [-1, 1] range, so convert [0, 1] -> [-1, 1] tensor = (tensor - 0.5) * 2.0 # Convert [0, 1] -> [-1, 1] - + return tensor - - def _process_tensor_input(self, - control_tensor: torch.Tensor, - preprocessor: Optional[Any], - target_width: int, - target_height: int) -> 
torch.Tensor: + + def _process_tensor_input( + self, control_tensor: torch.Tensor, preprocessor: Optional[Any], target_width: int, target_height: int + ) -> torch.Tensor: """Process tensor input with GPU acceleration when possible""" # Fast path for tensor input with GPU preprocessor - if preprocessor is not None and hasattr(preprocessor, 'process_tensor'): + if preprocessor is not None and hasattr(preprocessor, "process_tensor"): try: processed_tensor = preprocessor.process_tensor(control_tensor) # Ensure NCHW shape @@ -662,155 +647,139 @@ def _process_tensor_input(self, return processed_tensor.to(device=self.device, dtype=self.dtype) except Exception: pass # Fall through to standard processing - + # Direct tensor passthrough (no preprocessor) - preprocessors handle their own sizing if preprocessor is None: # For passthrough, we still need basic format handling if control_tensor.dim() == 3: control_tensor = control_tensor.unsqueeze(0) return control_tensor.to(device=self.device, dtype=self.dtype) - + # Convert to PIL for preprocessor, then back to tensor if control_tensor.dim() == 4: control_tensor = control_tensor[0] if control_tensor.dim() == 3 and control_tensor.shape[0] in [1, 3]: control_tensor = control_tensor.permute(1, 2, 0) - + if control_tensor.is_cuda: control_tensor = control_tensor.cpu() - + control_array = control_tensor.numpy() if control_array.max() <= 1.0: control_array = (control_array * 255).astype(np.uint8) - + control_image = Image.fromarray(control_array.astype(np.uint8)) return self.prepare_control_image(control_image, preprocessor, target_width, target_height) - - def _convert_to_tensor(self, - control_image: Union[Image.Image, np.ndarray], - target_width: int, - target_height: int) -> torch.Tensor: + + def _convert_to_tensor( + self, control_image: Union[Image.Image, np.ndarray], target_width: int, target_height: int + ) -> torch.Tensor: """Convert PIL Image or numpy array to tensor - preprocessors handle their own sizing""" # Handle PIL Images - no resizing here, preprocessors handle their target size if isinstance(control_image, Image.Image): control_tensor = self._cached_transform(control_image).unsqueeze(0) return control_tensor.to(device=self.device, dtype=self.dtype) - + # Handle numpy arrays if isinstance(control_image, np.ndarray): if control_image.max() <= 1.0: control_image = (control_image * 255).astype(np.uint8) control_image = Image.fromarray(control_image) return self._convert_to_tensor(control_image, target_width, target_height) - + raise ValueError(f"Unsupported control image type: {type(control_image)}") - + def _to_tensor_safe(self, image: Image.Image) -> torch.Tensor: """Thread-safe tensor conversion from PIL Image""" return self._cached_transform(image).unsqueeze(0).to(device=self.device, dtype=self.dtype) - - def _prepare_input_variants(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - stream_width: int = None, - stream_height: int = None, - thread_safe: bool = False) -> Dict[str, Any]: + + def _prepare_input_variants( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + stream_width: int = None, + stream_height: int = None, + thread_safe: bool = False, + ) -> Dict[str, Any]: """Prepare optimized input variants for different processing paths - + Args: control_image: Input image in various formats stream_width: Target width (unused, kept for backward compatibility) stream_height: Target height (unused, kept for backward compatibility) thread_safe: If True, use thread-safe key naming for 
background processing - + Returns: Dictionary with 'tensor' and 'image'/'image_safe' keys """ - image_key = 'image_safe' if thread_safe else 'image' - + image_key = "image_safe" if thread_safe else "image" + if isinstance(control_image, torch.Tensor): return { - 'tensor': control_image, - image_key: None # Will create if needed + "tensor": control_image, + image_key: None, # Will create if needed } elif isinstance(control_image, Image.Image): image_copy = control_image.copy() - return { - image_key: image_copy, - 'tensor': self._to_tensor_safe(image_copy) - } + return {image_key: image_copy, "tensor": self._to_tensor_safe(image_copy)} elif isinstance(control_image, str): image_loaded = load_image(control_image) - return { - image_key: image_loaded, - 'tensor': self._to_tensor_safe(image_loaded) - } + return {image_key: image_loaded, "tensor": self._to_tensor_safe(image_loaded)} else: - return { - image_key: control_image, - 'tensor': None - } - - def _group_preprocessors(self, - preprocessors: List[Optional[Any]], - scales: List[float]) -> Dict[str, Dict[str, Any]]: + return {image_key: control_image, "tensor": None} + + def _group_preprocessors( + self, preprocessors: List[Optional[Any]], scales: List[float] + ) -> Dict[str, Dict[str, Any]]: """Group preprocessors by type to avoid duplicate processing""" preprocessor_groups = {} - + for i, scale in enumerate(scales): if scale > 0: preprocessor = preprocessors[i] - preprocessor_key = id(preprocessor) if preprocessor is not None else 'passthrough' - + preprocessor_key = id(preprocessor) if preprocessor is not None else "passthrough" + if preprocessor_key not in preprocessor_groups: - preprocessor_groups[preprocessor_key] = { - 'preprocessor': preprocessor, - 'indices': [] - } - preprocessor_groups[preprocessor_key]['indices'].append(i) - + preprocessor_groups[preprocessor_key] = {"preprocessor": preprocessor, "indices": []} + preprocessor_groups[preprocessor_key]["indices"].append(i) + return preprocessor_groups - def _process_single_preprocessor_group(self, - prep_key: str, - group: Dict[str, Any], - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int) -> Optional[Dict[str, Any]]: + def _process_single_preprocessor_group( + self, + prep_key: str, + group: Dict[str, Any], + control_variants: Dict[str, Any], + stream_width: int, + stream_height: int, + ) -> Optional[Dict[str, Any]]: """Process a single preprocessor group with optimal input selection""" try: - preprocessor = group['preprocessor'] - indices = group['indices'] - + preprocessor = group["preprocessor"] + indices = group["indices"] + # Try tensor processing first (fastest path) - if (preprocessor is not None and - hasattr(preprocessor, 'process_tensor') and - control_variants['tensor'] is not None): + if ( + preprocessor is not None + and hasattr(preprocessor, "process_tensor") + and control_variants["tensor"] is not None + ): try: processed_image = self.prepare_control_image( - control_variants['tensor'], preprocessor, stream_width, stream_height + control_variants["tensor"], preprocessor, stream_width, stream_height ) - return { - 'prep_key': prep_key, - 'indices': indices, - 'processed_image': processed_image - } + return {"prep_key": prep_key, "indices": indices, "processed_image": processed_image} except Exception: pass # Fall through to PIL processing - + # PIL processing fallback - if control_variants['image'] is not None: + if control_variants["image"] is not None: processed_image = self.prepare_control_image( - control_variants['image'], 
preprocessor, stream_width, stream_height + control_variants["image"], preprocessor, stream_width, stream_height ) - return { - 'prep_key': prep_key, - 'indices': indices, - 'processed_image': processed_image - } - + return {"prep_key": prep_key, "indices": indices, "processed_image": processed_image} + return None - + except Exception as e: logger.error(f"PreprocessingOrchestrator: Preprocessor {prep_key} failed: {e}") return None - diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index b8f68b67..e7c67e45 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -1,15 +1,18 @@ -from typing import List, Optional, Dict, Tuple, Literal, Any, Callable +import logging import threading +from typing import Any, Dict, List, Literal, Optional, Tuple + import torch import torch.nn.functional as F -import gc -import logging + logger = logging.getLogger(__name__) from .preprocessing.orchestrator_user import OrchestratorUser + class CacheStats: """Helper class to track cache statistics""" + def __init__(self): self.hits = 0 self.misses = 0 @@ -22,7 +25,13 @@ def record_miss(self): class StreamParameterUpdater(OrchestratorUser): - def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True): + def __init__( + self, + stream_diffusion, + wrapper=None, + normalize_prompt_weights: bool = True, + normalize_seed_weights: bool = True, + ): self.stream = stream_diffusion self.wrapper = wrapper # Reference to wrapper for accessing pipeline structure self.normalize_prompt_weights = normalize_prompt_weights @@ -39,8 +48,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._seed_cache: Dict[int, Dict] = {} self._current_seed_list: List[Tuple[int, float]] = [] self._seed_cache_stats = CacheStats() - - + # Attach shared orchestrator once (lazy-creates on stream if absent) self.attach_orchestrator(self.stream) @@ -50,6 +58,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._current_style_images: Dict[str, Any] = {} # Use the shared orchestrator attached via OrchestratorUser self._embedding_orchestrator = self._preprocessing_orchestrator + def get_cache_info(self) -> Dict: """Get cache statistics for monitoring performance.""" total_requests = self._prompt_cache_stats.hits + self._prompt_cache_stats.misses @@ -68,7 +77,7 @@ def get_cache_info(self) -> Dict: "seed_cache_hits": self._seed_cache_stats.hits, "seed_cache_misses": self._seed_cache_stats.misses, "seed_hit_rate": f"{seed_hit_rate:.2%}", - "current_seeds": len(self._current_seed_list) + "current_seeds": len(self._current_seed_list), } def clear_caches(self) -> None: @@ -81,7 +90,7 @@ def clear_caches(self) -> None: self._seed_cache.clear() self._current_seed_list.clear() self._seed_cache_stats = CacheStats() - + # Clear embedding caches self._embedding_cache.clear() self._current_style_images.clear() @@ -93,13 +102,13 @@ def get_normalize_prompt_weights(self) -> bool: def get_normalize_seed_weights(self) -> bool: """Get the current seed weight normalization setting.""" return self.normalize_seed_weights - + # Deprecated enhancer registration removed; embedding composition is handled via stream.embedding_hooks def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: str) -> None: """ Register an embedding preprocessor for parallel processing. 
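A typical lifecycle, assuming `prep` is an IPAdapterEmbeddingPreprocessor and `pil_image` a PIL style image (the key name is arbitrary):

    updater.register_embedding_preprocessor(prep, style_image_key="style_a")
    updater.update_style_image("style_a", pil_image, is_stream=False)  # sync: embeddings ready on return
    embeds = updater.get_cached_embeddings("style_a")                  # tensor pair, or None if never processed
    updater.unregister_embedding_preprocessor("style_a")               # also evicts the cached embeddings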
- + Args: preprocessor: IPAdapterEmbeddingPreprocessor instance style_image_key: Unique key for the style image this preprocessor handles @@ -108,28 +117,27 @@ def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: st # Ensure orchestrator is present self.attach_orchestrator(self.stream) self._embedding_orchestrator = self._preprocessing_orchestrator - + self._embedding_preprocessors.append((preprocessor, style_image_key)) - + def unregister_embedding_preprocessor(self, style_image_key: str) -> None: """Unregister an embedding preprocessor by style image key.""" original_count = len(self._embedding_preprocessors) self._embedding_preprocessors = [ - (preprocessor, key) for preprocessor, key in self._embedding_preprocessors - if key != style_image_key + (preprocessor, key) for preprocessor, key in self._embedding_preprocessors if key != style_image_key ] removed_count = original_count - len(self._embedding_preprocessors) - + # Clear cached embeddings for this key if style_image_key in self._embedding_cache: del self._embedding_cache[style_image_key] if style_image_key in self._current_style_images: del self._current_style_images[style_image_key] - + def update_style_image(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None: """ Update a style image and trigger embedding preprocessing. - + Args: style_image_key: Unique key for the style image style_image: The style image (PIL Image, path, etc.) @@ -138,14 +146,16 @@ def update_style_image(self, style_image_key: str, style_image: Any, is_stream: """ # Store the style image self._current_style_images[style_image_key] = style_image - + # Trigger preprocessing for this style image self._preprocess_style_image_parallel(style_image_key, style_image, is_stream) - - def _preprocess_style_image_parallel(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None: + + def _preprocess_style_image_parallel( + self, style_image_key: str, style_image: Any, is_stream: bool = False + ) -> None: """ Preprocessing for a specific style image with mode selection - + Args: style_image_key: Unique key for the style image style_image: The style image to process @@ -153,57 +163,47 @@ def _preprocess_style_image_parallel(self, style_image_key: str, style_image: An """ if not self._embedding_preprocessors or self._embedding_orchestrator is None: return - + # Find preprocessors for this key relevant_preprocessors = [ - preprocessor for preprocessor, key in self._embedding_preprocessors - if key == style_image_key + preprocessor for preprocessor, key in self._embedding_preprocessors if key == style_image_key ] - + if not relevant_preprocessors: return - + # Choose processing mode based on is_stream parameter try: if is_stream: # Pipelined processing - optimized for throughput with 1-frame lag embedding_results = self._embedding_orchestrator.process_pipelined( - style_image, - relevant_preprocessors, - None, - self.stream.width, - self.stream.height, - "ipadapter" + style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, "ipadapter" ) else: # Synchronous processing - immediate results for discrete updates embedding_results = self._embedding_orchestrator.process_sync( - style_image, - relevant_preprocessors, - None, - self.stream.width, - self.stream.height, - None, - "ipadapter" + style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, None, "ipadapter" ) - + # Cache results for this style image key if embedding_results and embedding_results[0] is not 
None: self._embedding_cache[style_image_key] = embedding_results[0] else: # This is an error condition - we should always have results - raise RuntimeError(f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'") - - except Exception as e: + raise RuntimeError( + f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'" + ) + + except Exception: import traceback + traceback.print_exc() - + def get_cached_embeddings(self, style_image_key: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: """Get cached embeddings for a style image key""" cached_result = self._embedding_cache.get(style_image_key, None) return cached_result - def _normalize_weights(self, weights: List[float], normalize: bool) -> torch.Tensor: """Generic weight normalization helper""" weights_tensor = torch.tensor(weights, device=self.stream.device, dtype=self.stream.dtype) @@ -218,7 +218,7 @@ def _validate_index(self, index: int, item_list: List, operation_name: str) -> b return False if index < 0 or index >= len(item_list): - logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list)-1})") + logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list) - 1})") return False return True @@ -281,28 +281,27 @@ def update_stream_params( f"provided t_index_list (max index: {max_t_index}). Adjusting to {max_t_index + 1}." ) num_inference_steps = max_t_index + 1 - + old_num_steps = len(self.stream.timesteps) self.stream.scheduler.set_timesteps(num_inference_steps, self.stream.device) self.stream.timesteps = self.stream.scheduler.timesteps.to(self.stream.device) - + # If t_index_list wasn't explicitly provided, rescale existing t_list proportionally if t_index_list is None and old_num_steps > 0: # Rescale each index proportionally to the new number of steps # e.g., if t_list = [0, 16, 32, 45] with 50 steps -> [0, 3, 5, 7] with 9 steps scale_factor = (num_inference_steps - 1) / (old_num_steps - 1) if old_num_steps > 1 else 1.0 - t_index_list = [ - min(round(t * scale_factor), num_inference_steps - 1) - for t in self.stream.t_list - ] - + t_index_list = [min(round(t * scale_factor), num_inference_steps - 1) for t in self.stream.t_list] + # Now update timestep-dependent parameters with the correct t_index_list if t_index_list is not None: self._recalculate_timestep_dependent_params(t_index_list) if guidance_scale is not None: if self.stream.cfg_type == "none" and guidance_scale > 1.0: - logger.warning("update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect") + logger.warning( + "update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect" + ) self.stream.guidance_scale = guidance_scale if delta is not None: @@ -310,7 +309,7 @@ def update_stream_params( if seed is not None: self._update_seed(seed) - + if normalize_prompt_weights is not None: self.normalize_prompt_weights = normalize_prompt_weights logger.info(f"update_stream_params: Prompt weight normalization set to {normalize_prompt_weights}") @@ -324,44 +323,42 @@ def update_stream_params( self._update_blended_prompts( prompt_list=prompt_list, negative_prompt=negative_prompt or self._current_negative_prompt, - prompt_interpolation_method=prompt_interpolation_method + prompt_interpolation_method=prompt_interpolation_method, ) # Handle seed blending if seed_list is provided if seed_list is not None: - self._update_blended_seeds( - seed_list=seed_list, -
interpolation_method=seed_interpolation_method - ) - + self._update_blended_seeds(seed_list=seed_list, interpolation_method=seed_interpolation_method) # Handle ControlNet configuration updates if controlnet_config is not None: - #TODO: happy path for control images + # TODO: happy path for control images self._update_controlnet_config(controlnet_config) - + # Handle IPAdapter configuration updates if ipadapter_config is not None: - logger.info(f"update_stream_params: Updating IPAdapter configuration") + logger.info("update_stream_params: Updating IPAdapter configuration") self._update_ipadapter_config(ipadapter_config) - + # Handle Hook configuration updates if image_preprocessing_config is not None: - logger.info(f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors") + logger.info( + f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors" + ) logger.info(f"update_stream_params: image_preprocessing_config = {image_preprocessing_config}") - self._update_hook_config('image_preprocessing', image_preprocessing_config) - + self._update_hook_config("image_preprocessing", image_preprocessing_config) + if image_postprocessing_config is not None: - logger.info(f"update_stream_params: Updating image postprocessing configuration") - self._update_hook_config('image_postprocessing', image_postprocessing_config) - + logger.info("update_stream_params: Updating image postprocessing configuration") + self._update_hook_config("image_postprocessing", image_postprocessing_config) + if latent_preprocessing_config is not None: - logger.info(f"update_stream_params: Updating latent preprocessing configuration") - self._update_hook_config('latent_preprocessing', latent_preprocessing_config) - + logger.info("update_stream_params: Updating latent preprocessing configuration") + self._update_hook_config("latent_preprocessing", latent_preprocessing_config) + if latent_postprocessing_config is not None: - logger.info(f"update_stream_params: Updating latent postprocessing configuration") - self._update_hook_config('latent_postprocessing', latent_postprocessing_config) + logger.info("update_stream_params: Updating latent postprocessing configuration") + self._update_hook_config("latent_postprocessing", latent_postprocessing_config) if self.stream.kvo_cache: if cache_interval is not None: @@ -375,9 +372,7 @@ def update_stream_params( # runtime — resizing one-at-a-time races with TRT inference (causes "Dimensions # with name C must be equal" errors). cache_maxframes is a logical write window. 
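# Illustration: for an engine built with a 2-frame cache, each kvo_cache[i] is
# allocated as (2, 2, batch_size, seq_len, hidden), so actual_cache_size == 2.
# Requesting cache_maxframes=4 at runtime then only triggers the warning below
# (no reallocation), while cache_maxframes=1 simply narrows the logical write
# window within the already-allocated frames.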
actual_cache_size = ( - self.stream.kvo_cache[0].shape[1] - if self.stream.kvo_cache - else cache_maxframes + self.stream.kvo_cache[0].shape[1] if self.stream.kvo_cache else cache_maxframes ) if cache_maxframes > actual_cache_size: logger.warning( @@ -395,9 +390,7 @@ def update_stream_params( @torch.inference_mode() def update_prompt_weights( - self, - prompt_weights: List[float], - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, prompt_weights: List[float], prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Update weights for current prompt list without re-encoding prompts.""" if not self._current_prompt_list: @@ -405,7 +398,9 @@ def update_prompt_weights( return if len(prompt_weights) != len(self._current_prompt_list): - logger.warning(f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}") + logger.warning( + f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}" + ) return # Update the current prompt list with new weights @@ -420,9 +415,7 @@ def update_prompt_weights( @torch.inference_mode() def update_seed_weights( - self, - seed_weights: List[float], - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed_weights: List[float], interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update weights for current seed list without regenerating noise.""" if not self._current_seed_list: @@ -430,7 +423,9 @@ def update_seed_weights( return if len(seed_weights) != len(self._current_seed_list): - logger.warning(f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}") + logger.warning( + f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}" + ) return # Update the current seed list with new weights @@ -448,7 +443,7 @@ def _update_blended_prompts( self, prompt_list: List[Tuple[str, float]], negative_prompt: str = "", - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", ) -> None: """Update prompt embeddings using multiple weighted prompts.""" # Store current state @@ -461,14 +456,10 @@ def _update_blended_prompts( # Apply blending self._apply_prompt_blending(prompt_interpolation_method) - def _cache_prompt_embeddings( - self, - prompt_list: List[Tuple[str, float]], - negative_prompt: str - ) -> None: + def _cache_prompt_embeddings(self, prompt_list: List[Tuple[str, float]], negative_prompt: str) -> None: """Cache prompt embeddings for efficient reuse.""" for idx, (prompt_text, weight) in enumerate(prompt_list): - if idx not in self._prompt_cache or self._prompt_cache[idx]['text'] != prompt_text: + if idx not in self._prompt_cache or self._prompt_cache[idx]["text"] != prompt_text: # Cache miss - encode the prompt self._prompt_cache_stats.record_miss() encoder_output = self.stream.pipe.encode_prompt( @@ -482,10 +473,7 @@ def _cache_prompt_embeddings( if len(self._prompt_cache) >= 32: oldest_key = next(iter(self._prompt_cache)) del self._prompt_cache[oldest_key] - self._prompt_cache[idx] = { - 'embed': encoder_output[0], - 'text': prompt_text - } + self._prompt_cache[idx] = {"embed": encoder_output[0], "text": prompt_text} else: # Cache hit self._prompt_cache_stats.record_hit() @@ -500,7 +488,7 @@ def _apply_prompt_blending(self, 
prompt_interpolation_method: Literal["linear", for idx, (prompt_text, weight) in enumerate(self._current_prompt_list): if idx in self._prompt_cache: - embeddings.append(self._prompt_cache[idx]['embed']) + embeddings.append(self._prompt_cache[idx]["embed"]) weights.append(weight) if not embeddings: @@ -545,13 +533,14 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", # No CFG, just use the blended embeddings final_prompt_embeds = combined_embeds.repeat(self.stream.batch_size, 1, 1) final_negative_embeds = None # Will be set by enhancers if needed - + # Enhancer mechanism removed in favor of embedding_hooks # Run embedding hooks to compose final embeddings (e.g., append IP-Adapter tokens) try: - if hasattr(self.stream, 'embedding_hooks') and self.stream.embedding_hooks: + if hasattr(self.stream, "embedding_hooks") and self.stream.embedding_hooks: from .hooks import EmbedsCtx # local import to avoid cycles + embeds_ctx = EmbedsCtx( prompt_embeds=final_prompt_embeds, negative_prompt_embeds=final_negative_embeds, @@ -562,8 +551,9 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", final_negative_embeds = embeds_ctx.negative_prompt_embeds except Exception as e: import logging + logging.getLogger(__name__).error(f"_apply_prompt_blending: embedding hook failed: {e}") - + # Set final embeddings on stream self.stream.prompt_embeds = final_prompt_embeds if final_negative_embeds is not None: @@ -604,9 +594,7 @@ def _slerp(self, embed1: torch.Tensor, embed2: torch.Tensor, t: float) -> torch. @torch.inference_mode() def _update_blended_seeds( - self, - seed_list: List[Tuple[int, float]], - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed_list: List[Tuple[int, float]], interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update seed tensors using multiple weighted seeds.""" # Store current state @@ -621,7 +609,7 @@ def _update_blended_seeds( def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None: """Cache seed noise tensors for efficient reuse.""" for idx, (seed_value, weight) in enumerate(seed_list): - if idx not in self._seed_cache or self._seed_cache[idx]['seed'] != seed_value: + if idx not in self._seed_cache or self._seed_cache[idx]["seed"] != seed_value: # Cache miss - generate noise for the seed self._seed_cache_stats.record_miss() generator = torch.Generator(device=self.stream.device) @@ -631,13 +619,10 @@ def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None: (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[idx] = { - 'noise': noise, - 'seed': seed_value - } + self._seed_cache[idx] = {"noise": noise, "seed": seed_value} else: # Cache hit self._seed_cache_stats.record_hit() @@ -652,7 +637,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) for idx, (seed_value, weight) in enumerate(self._current_seed_list): if idx in self._seed_cache: - noise_tensors.append(self._seed_cache[idx]['noise']) + noise_tensors.append(self._seed_cache[idx]["noise"]) weights.append(weight) if not noise_tensors: @@ -673,7 +658,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) combined_noise = torch.zeros_like(noise_tensors[0]) for noise, weight in zip(noise_tensors, weights): combined_noise += weight * noise - + # Preserve noise magnitude when weights are 
normalized if self.normalize_seed_weights and len(noise_tensors) > 1: original_magnitude = torch.mean(torch.stack([torch.norm(noise) for noise in noise_tensors])) @@ -748,6 +733,7 @@ def _update_seed(self, seed: int) -> None: def _get_scheduler_scalings(self, timestep): """Get LCM/TCD-specific scaling factors for boundary conditions.""" from diffusers import LCMScheduler + if isinstance(self.stream.scheduler, LCMScheduler): c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) return c_skip, c_out @@ -765,9 +751,7 @@ def _update_timestep_calculations(self) -> None: for t in self.stream.t_list: self.stream.sub_timesteps.append(self.stream.timesteps[t]) - sub_timesteps_tensor = torch.tensor( - self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device - ) + sub_timesteps_tensor = torch.tensor(self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device) self.stream.sub_timesteps_tensor = torch.repeat_interleave( sub_timesteps_tensor, repeats=self.stream.frame_bff_size if self.stream.use_denoising_batch else 1, @@ -793,12 +777,8 @@ def _update_timestep_calculations(self) -> None: ) if self.stream.use_denoising_batch: - self.stream.c_skip = torch.repeat_interleave( - self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0 - ) - self.stream.c_out = torch.repeat_interleave( - self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0 - ) + self.stream.c_skip = torch.repeat_interleave(self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0) + self.stream.c_out = torch.repeat_interleave(self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0) # Update alpha_prod_t_sqrt and beta_prod_t_sqrt alpha_prod_t_sqrt_list = [] @@ -838,29 +818,25 @@ def _update_timestep_values_only(self, t_index_list: List[int]) -> None: def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> None: """Recalculate all parameters that depend on t_index_list.""" - + # Check if this is a structural change (length) or just value change if len(t_index_list) == len(self.stream.t_list): # Same length - only the values changed, so take the lightweight update path self._update_timestep_values_only(t_index_list) return - + # Length changed - do a full recalculation, including batch-dependent parameters
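# Worked example (illustrative, frame_bff_size=1): t_index_list=[16, 35] gives
# denoising_steps_num=2, so cfg_type "self"/"none" -> trt_unet_batch_size=2,
# "initialize" -> (2 + 1) * 1 = 3, and "full" -> 2 * 2 * 1 = 4. The result must
# stay inside the min/max batch range the TensorRT engine was built with.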
self.stream.t_list = t_index_list self.stream.denoising_steps_num = len(self.stream.t_list) old_batch_size = self.stream.batch_size - + if self.stream.use_denoising_batch: self.stream.batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size if self.stream.cfg_type == "initialize": - self.stream.trt_unet_batch_size = ( - self.stream.denoising_steps_num + 1 - ) * self.stream.frame_bff_size + self.stream.trt_unet_batch_size = (self.stream.denoising_steps_num + 1) * self.stream.frame_bff_size elif self.stream.cfg_type == "full": - self.stream.trt_unet_batch_size = ( - 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size - ) + self.stream.trt_unet_batch_size = 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size else: self.stream.trt_unet_batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size else: @@ -891,23 +867,36 @@ def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> Non # Resize kvo_cache tensors if batch size changed if self.stream.kvo_cache and old_batch_size != self.stream.batch_size: - logger.info(f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}") + logger.info( + f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}" + ) for i, cache_tensor in enumerate(self.stream.kvo_cache): # KVO cache shape: (2, cache_maxframes, batch_size, seq_length, hidden_dim) current_shape = cache_tensor.shape - new_shape = (current_shape[0], current_shape[1], self.stream.batch_size, current_shape[3], current_shape[4]) - new_cache_tensor = torch.zeros( - new_shape, - dtype=cache_tensor.dtype, - device=cache_tensor.device + new_shape = ( + current_shape[0], + current_shape[1], + self.stream.batch_size, + current_shape[3], + current_shape[4], ) - + new_cache_tensor = torch.zeros(new_shape, dtype=cache_tensor.dtype, device=cache_tensor.device) + # Copy over as much data as possible from old cache min_batch = min(old_batch_size, self.stream.batch_size) new_cache_tensor[:, :, :min_batch, :, :] = cache_tensor[:, :, :min_batch, :, :] - + self.stream.kvo_cache[i] = new_cache_tensor - logger.info(f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}") + + # Resize replaces per-layer tensors with fresh standalone allocations, + # detaching them from the bucketed storage. Drop the bucket refs so + # update_kvo_cache falls back to per-layer writes against the new tensors. 
+ self.stream._kvo_buckets = None + self.stream._kvo_outputs_by_bucket = None + + logger.info( + f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}" + ) # Update timestep-dependent calculations (shared with value-only path) self._update_timestep_calculations() @@ -926,10 +915,7 @@ def _recalculate_controlnet_inputs(self, width: int, height: int) -> None: @torch.inference_mode() def update_prompt_at_index( - self, - index: int, - new_prompt: str, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, index: int, new_prompt: str, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Update a single prompt at the specified index without re-encoding others.""" if not self._validate_index(index, self._current_prompt_list, "update_prompt_at_index"): @@ -943,11 +929,11 @@ def update_prompt_at_index( self._cache_prompt_embeddings([(new_prompt, weight)], self._current_negative_prompt) # Update cache index to point to the new prompt - if index in self._prompt_cache and self._prompt_cache[index]['text'] != new_prompt: + if index in self._prompt_cache and self._prompt_cache[index]["text"] != new_prompt: # Find if this prompt is already cached elsewhere existing_cache_key = None for cache_idx, cache_data in self._prompt_cache.items(): - if cache_data['text'] == new_prompt: + if cache_data["text"] == new_prompt: existing_cache_key = cache_idx break @@ -965,10 +951,7 @@ def update_prompt_at_index( do_classifier_free_guidance=False, negative_prompt=self._current_negative_prompt, ) - self._prompt_cache[index] = { - 'embed': encoder_output[0], - 'text': new_prompt - } + self._prompt_cache[index] = {"embed": encoder_output[0], "text": new_prompt} # Recompute blended embeddings with updated prompt self._apply_prompt_blending(prompt_interpolation_method) @@ -980,16 +963,12 @@ def get_current_prompts(self) -> List[Tuple[str, float]]: @torch.inference_mode() def add_prompt( - self, - prompt: str, - weight: float = 1.0, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, prompt: str, weight: float = 1.0, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Add a new prompt to the current list.""" new_index = len(self._current_prompt_list) self._current_prompt_list.append((prompt, weight)) - # Cache the new prompt encoder_output = self.stream.pipe.encode_prompt( prompt=prompt, @@ -998,10 +977,7 @@ def add_prompt( do_classifier_free_guidance=False, negative_prompt=self._current_negative_prompt, ) - self._prompt_cache[new_index] = { - 'embed': encoder_output[0], - 'text': prompt - } + self._prompt_cache[new_index] = {"embed": encoder_output[0], "text": prompt} self._prompt_cache_stats.record_miss() # Recompute blended embeddings @@ -1009,9 +985,7 @@ def add_prompt( @torch.inference_mode() def remove_prompt_at_index( - self, - index: int, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, index: int, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Remove a prompt at the specified index.""" if not self._validate_index(index, self._current_prompt_list, "remove_prompt_at_index"): @@ -1036,10 +1010,7 @@ def remove_prompt_at_index( @torch.inference_mode() def update_seed_at_index( - self, - index: int, - new_seed: int, - interpolation_method: Literal["linear", "slerp"] = "linear" + self, index: int, new_seed: int, interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update a single seed at 
the specified index without regenerating others.""" if not self._validate_index(index, self._current_seed_list, "update_seed_at_index"): @@ -1049,16 +1020,15 @@ def update_seed_at_index( old_seed, weight = self._current_seed_list[index] self._current_seed_list[index] = (new_seed, weight) - # Cache the new seed noise self._cache_seed_noise([(new_seed, weight)]) # Update cache index to point to the new seed - if index in self._seed_cache and self._seed_cache[index]['seed'] != new_seed: + if index in self._seed_cache and self._seed_cache[index]["seed"] != new_seed: # Find if this seed is already cached elsewhere existing_cache_key = None for cache_idx, cache_data in self._seed_cache.items(): - if cache_data['seed'] == new_seed: + if cache_data["seed"] == new_seed: existing_cache_key = cache_idx break @@ -1076,13 +1046,10 @@ def update_seed_at_index( (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[index] = { - 'noise': noise, - 'seed': new_seed - } + self._seed_cache[index] = {"noise": noise, "seed": new_seed} # Recompute blended noise with updated seed self._apply_seed_blending(interpolation_method) @@ -1094,10 +1061,7 @@ def get_current_seeds(self) -> List[Tuple[int, float]]: @torch.inference_mode() def add_seed( - self, - seed: int, - weight: float = 1.0, - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed: int, weight: float = 1.0, interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Add a new seed to the current list.""" new_index = len(self._current_seed_list) @@ -1113,24 +1077,17 @@ def add_seed( (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[new_index] = { - 'noise': noise, - 'seed': seed - } + self._seed_cache[new_index] = {"noise": noise, "seed": seed} self._seed_cache_stats.record_miss() # Recompute blended noise self._apply_seed_blending(interpolation_method) @torch.inference_mode() - def remove_seed_at_index( - self, - index: int, - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: + def remove_seed_at_index(self, index: int, interpolation_method: Literal["linear", "slerp"] = "linear") -> None: """Remove a seed at the specified index.""" if not self._validate_index(index, self._current_seed_list, "remove_seed_at_index"): return @@ -1155,7 +1112,7 @@ def remove_seed_at_index( def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> None: """ Update ControlNet configuration by diffing current vs desired state. - + Args: desired_config: Complete ControlNet configuration list defining the desired state. Each dict contains: model_id, preprocessor, conditioning_scale, enabled, etc. 
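The reconciliation below boils down to a three-phase diff against the desired state. A condensed sketch (error handling, stable reordering, and preprocessor-param updates omitted; `ControlNetConfig` as imported in the add path below):

    def reconcile_controlnets(pipeline, desired_config):
        desired_ids = {cfg["model_id"] for cfg in desired_config}

        # 1) Remove controlnets absent from the desired state (reverse order keeps indices valid)
        for i in reversed(range(len(pipeline.controlnets))):
            if getattr(pipeline.controlnets[i], "model_id", f"controlnet_{i}") not in desired_ids:
                pipeline.remove_controlnet(i)

        # Re-derive the model_id -> index map after removals
        current = {getattr(cn, "model_id", f"controlnet_{i}"): i
                   for i, cn in enumerate(pipeline.controlnets)}

        for cfg in desired_config:
            if cfg["model_id"] not in current:
                # 2) Add newly requested controlnets
                cn_cfg = ControlNetConfig(
                    model_id=cfg["model_id"],
                    preprocessor=cfg.get("preprocessor"),
                    conditioning_scale=cfg.get("conditioning_scale", 1.0),
                    enabled=cfg.get("enabled", True),
                )
                pipeline.add_controlnet(cn_cfg, cfg.get("control_image"))
            else:
                # 3) Update survivors in place: scale / enabled flip without reloading weights
                idx = current[cfg["model_id"]]
                if "conditioning_scale" in cfg:
                    pipeline.controlnet_scales[idx] = float(cfg["conditioning_scale"])
                if "enabled" in cfg:
                    pipeline.enabled_list[idx] = bool(cfg["enabled"])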
@@ -1163,41 +1120,47 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non # Find the ControlNet pipeline/module (module-aware) controlnet_pipeline = self._get_controlnet_pipeline() if not controlnet_pipeline: - logger.debug("_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)") + logger.debug( + "_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)" + ) return - + current_config = self._get_current_controlnet_config() - + # Simple approach: detect what changed and apply minimal updates - current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} - desired_models = {cfg['model_id']: cfg for cfg in desired_config} - + current_models = { + i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets) + } + desired_models = {cfg["model_id"]: cfg for cfg in desired_config} + # Reorder to match desired order (module supports stable reordering) try: - desired_order = [cfg['model_id'] for cfg in desired_config if 'model_id' in cfg] - if hasattr(controlnet_pipeline, 'reorder_controlnets_by_model_ids'): + desired_order = [cfg["model_id"] for cfg in desired_config if "model_id" in cfg] + if hasattr(controlnet_pipeline, "reorder_controlnets_by_model_ids"): controlnet_pipeline.reorder_controlnets_by_model_ids(desired_order) except Exception: pass # Recompute current models after potential reorder - current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} + current_models = { + i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets) + } # Remove controlnets not in desired config for i in reversed(range(len(controlnet_pipeline.controlnets))): - model_id = current_models.get(i, f'controlnet_{i}') + model_id = current_models.get(i, f"controlnet_{i}") if model_id not in desired_models: logger.info(f"_update_controlnet_config: Removing ControlNet {model_id}") try: controlnet_pipeline.remove_controlnet(i) except Exception: raise - + # Add new controlnets and update existing ones for desired_cfg in desired_config: - model_id = desired_cfg['model_id'] + model_id = desired_cfg["model_id"] existing_index = next((i for i, mid in current_models.items() if mid == model_id), None) - + if existing_index is None: # Add new controlnet logger.info(f"_update_controlnet_config: Adding ControlNet {model_id}") @@ -1205,15 +1168,16 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non # Prefer module path: construct ControlNetConfig try: from .modules.controlnet_module import ControlNetConfig # type: ignore + cn_cfg = ControlNetConfig( - model_id=desired_cfg.get('model_id'), - preprocessor=desired_cfg.get('preprocessor'), - conditioning_scale=desired_cfg.get('conditioning_scale', 1.0), - enabled=desired_cfg.get('enabled', True), - conditioning_channels=desired_cfg.get('conditioning_channels'), - preprocessor_params=desired_cfg.get('preprocessor_params'), + model_id=desired_cfg.get("model_id"), + preprocessor=desired_cfg.get("preprocessor"), + conditioning_scale=desired_cfg.get("conditioning_scale", 1.0), + enabled=desired_cfg.get("enabled", True), + conditioning_channels=desired_cfg.get("conditioning_channels"), + preprocessor_params=desired_cfg.get("preprocessor_params"), ) - controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get('control_image')) + 
controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get("control_image")) except Exception: # No fallback raise @@ -1221,114 +1185,136 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non logger.error(f"_update_controlnet_config: add_controlnet failed for {model_id}: {e}") else: # Update existing controlnet - if 'conditioning_scale' in desired_cfg: - current_scale = current_config[existing_index].get('conditioning_scale', 1.0) - desired_scale = desired_cfg['conditioning_scale'] - + if "conditioning_scale" in desired_cfg: + current_scale = current_config[existing_index].get("conditioning_scale", 1.0) + desired_scale = desired_cfg["conditioning_scale"] + if current_scale != desired_scale: - logger.info(f"_update_controlnet_config: Updating {model_id} scale: {current_scale} → {desired_scale}") - if hasattr(controlnet_pipeline, 'controlnet_scales') and 0 <= existing_index < len(controlnet_pipeline.controlnet_scales): + logger.info( + f"_update_controlnet_config: Updating {model_id} scale: {current_scale} → {desired_scale}" + ) + if hasattr(controlnet_pipeline, "controlnet_scales") and 0 <= existing_index < len( + controlnet_pipeline.controlnet_scales + ): controlnet_pipeline.controlnet_scales[existing_index] = float(desired_scale) - + # Enable/disable toggle - if 'enabled' in desired_cfg and hasattr(controlnet_pipeline, 'enabled_list'): + if "enabled" in desired_cfg and hasattr(controlnet_pipeline, "enabled_list"): if 0 <= existing_index < len(controlnet_pipeline.enabled_list): - controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg['enabled']) + controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg["enabled"]) - if 'preprocessor_params' in desired_cfg and hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[existing_index]: + if ( + "preprocessor_params" in desired_cfg + and hasattr(controlnet_pipeline, "preprocessors") + and controlnet_pipeline.preprocessors[existing_index] + ): preprocessor = controlnet_pipeline.preprocessors[existing_index] - preprocessor.params.update(desired_cfg['preprocessor_params']) - for param_name, param_value in desired_cfg['preprocessor_params'].items(): + preprocessor.params.update(desired_cfg["preprocessor_params"]) + for param_name, param_value in desired_cfg["preprocessor_params"].items(): if hasattr(preprocessor, param_name): setattr(preprocessor, param_name, param_value) - + # Pipeline references are now automatically managed during preprocessor creation # No need to manually re-establish pipeline references for pipeline-aware processors - def _get_controlnet_pipeline(self): """ Get the ControlNet module or legacy pipeline from the structure (module-aware). 
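Resolution order, first match wins (all probes are plain hasattr checks, so this is safe before any module is installed): stream._controlnet_module -> stream.controlnets -> stream.stream.controlnets -> the same probes through self.wrapper.stream. Returns None when no ControlNet is loaded, which callers treat as "nothing to update":

    pipeline = self._get_controlnet_pipeline()
    if pipeline is None:
        return  # expected when running without ControlNet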
""" # Module-installed path - if hasattr(self.stream, '_controlnet_module'): + if hasattr(self.stream, "_controlnet_module"): return self.stream._controlnet_module # Legacy paths - if hasattr(self.stream, 'controlnets'): + if hasattr(self.stream, "controlnets"): return self.stream - if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'controlnets'): + if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "controlnets"): return self.stream.stream - if self.wrapper and hasattr(self.wrapper, 'stream'): - if hasattr(self.wrapper.stream, '_controlnet_module'): + if self.wrapper and hasattr(self.wrapper, "stream"): + if hasattr(self.wrapper.stream, "_controlnet_module"): return self.wrapper.stream._controlnet_module - if hasattr(self.wrapper.stream, 'controlnets'): + if hasattr(self.wrapper.stream, "controlnets"): return self.wrapper.stream - if hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'controlnets'): + if hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "controlnets"): return self.wrapper.stream.stream return None def _get_current_controlnet_config(self) -> List[Dict[str, Any]]: """ Get current ControlNet configuration state. - + Returns: List of current ControlNet configurations """ controlnet_pipeline = self._get_controlnet_pipeline() - if not controlnet_pipeline or not hasattr(controlnet_pipeline, 'controlnets') or not controlnet_pipeline.controlnets: + if ( + not controlnet_pipeline + or not hasattr(controlnet_pipeline, "controlnets") + or not controlnet_pipeline.controlnets + ): return [] - + current_config = [] for i, controlnet in enumerate(controlnet_pipeline.controlnets): - model_id = getattr(controlnet, 'model_id', f'controlnet_{i}') - scale = controlnet_pipeline.controlnet_scales[i] if hasattr(controlnet_pipeline, 'controlnet_scales') and i < len(controlnet_pipeline.controlnet_scales) else 1.0 + model_id = getattr(controlnet, "model_id", f"controlnet_{i}") + scale = ( + controlnet_pipeline.controlnet_scales[i] + if hasattr(controlnet_pipeline, "controlnet_scales") and i < len(controlnet_pipeline.controlnet_scales) + else 1.0 + ) enabled_val = True try: - if hasattr(controlnet_pipeline, 'enabled_list') and i < len(controlnet_pipeline.enabled_list): + if hasattr(controlnet_pipeline, "enabled_list") and i < len(controlnet_pipeline.enabled_list): enabled_val = bool(controlnet_pipeline.enabled_list[i]) except Exception: enabled_val = True config = { - 'model_id': model_id, - 'conditioning_scale': scale, - 'preprocessor_params': getattr(controlnet_pipeline.preprocessors[i], 'params', {}) if hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[i] else {}, - 'enabled': enabled_val, + "model_id": model_id, + "conditioning_scale": scale, + "preprocessor_params": getattr(controlnet_pipeline.preprocessors[i], "params", {}) + if hasattr(controlnet_pipeline, "preprocessors") and controlnet_pipeline.preprocessors[i] + else {}, + "enabled": enabled_val, } current_config.append(config) - + return current_config def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: """ Update IPAdapter configuration. - + Args: - desired_config: IPAdapter configuration dict containing: + desired_config: IPAdapter configuration dict containing: ipadapter_model_path, image_encoder_path, style_image, scale, enabled, etc. 
""" # Find the IPAdapter pipeline ipadapter_pipeline = self._get_ipadapter_pipeline() - + if not ipadapter_pipeline: - logger.warning(f"_update_ipadapter_config: No IPAdapter pipeline found") + logger.warning("_update_ipadapter_config: No IPAdapter pipeline found") return - - if 'scale' in desired_config and desired_config['scale'] is not None: - desired_scale = float(desired_config['scale']) + + if "scale" in desired_config and desired_config["scale"] is not None: + desired_scale = float(desired_config["scale"]) # Get current scale from IPAdapter instance - current_scale = getattr(self.stream.ipadapter, 'scale', 1.0) if hasattr(self.stream, 'ipadapter') else 1.0 - + current_scale = getattr(self.stream.ipadapter, "scale", 1.0) if hasattr(self.stream, "ipadapter") else 1.0 + if current_scale != desired_scale: logger.info(f"_update_ipadapter_config: Updating scale: {current_scale} → {desired_scale}") - + # Get weight_type from IPAdapter instance - weight_type = getattr(self.stream.ipadapter, 'weight_type', None) if hasattr(self.stream, 'ipadapter') else None - + weight_type = ( + getattr(self.stream.ipadapter, "weight_type", None) if hasattr(self.stream, "ipadapter") else None + ) + # Apply scale with weight type consideration - if weight_type is not None and hasattr(self.stream, 'ipadapter'): + if weight_type is not None and hasattr(self.stream, "ipadapter"): try: from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights - ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + + ip_procs = [ + p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index") + ] num_layers = len(ip_procs) weights = build_layer_weights(num_layers, desired_scale, weight_type) if weights is not None: @@ -1336,47 +1322,51 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: else: self.stream.ipadapter.set_scale(desired_scale) # Update our tracking attribute - setattr(self.stream.ipadapter, 'scale', desired_scale) + setattr(self.stream.ipadapter, "scale", desired_scale) except Exception: # Do not add fallback mechanisms raise else: # Simple uniform scale - if hasattr(self.stream, 'ipadapter'): + if hasattr(self.stream, "ipadapter"): # Tell diffusers_ipadapter to set the scale self.stream.ipadapter.set_scale(desired_scale) # Update our tracking attribute - setattr(self.stream.ipadapter, 'scale', desired_scale) - + setattr(self.stream.ipadapter, "scale", desired_scale) # Update enabled state if provided - if 'enabled' in desired_config and desired_config['enabled'] is not None: - enabled_state = bool(desired_config['enabled']) + if "enabled" in desired_config and desired_config["enabled"] is not None: + enabled_state = bool(desired_config["enabled"]) # Update IPAdapter instance - if hasattr(self.stream, 'ipadapter'): - current_enabled = getattr(self.stream.ipadapter, 'enabled', True) + if hasattr(self.stream, "ipadapter"): + current_enabled = getattr(self.stream.ipadapter, "enabled", True) if current_enabled != enabled_state: - logger.info(f"_update_ipadapter_config: Updating enabled state: {current_enabled} → {enabled_state}") - setattr(self.stream.ipadapter, 'enabled', enabled_state) + logger.info( + f"_update_ipadapter_config: Updating enabled state: {current_enabled} → {enabled_state}" + ) + setattr(self.stream.ipadapter, "enabled", enabled_state) # Update weight type if provided (affects per-layer distribution and/or per-step factor) - if 'weight_type' in desired_config and 
desired_config['weight_type'] is not None: - weight_type = desired_config['weight_type'] + if "weight_type" in desired_config and desired_config["weight_type"] is not None: + weight_type = desired_config["weight_type"] # Update IPAdapter instance - if hasattr(self.stream, 'ipadapter'): - setattr(self.stream.ipadapter, 'weight_type', weight_type) - + if hasattr(self.stream, "ipadapter"): + setattr(self.stream.ipadapter, "weight_type", weight_type) + # For PyTorch UNet, immediately apply a per-layer scale vector so layers reflect selection types try: - is_tensorrt_engine = hasattr(self.stream.unet, 'engine') and hasattr(self.stream.unet, 'stream') + is_tensorrt_engine = hasattr(self.stream.unet, "engine") and hasattr(self.stream.unet, "stream") if not is_tensorrt_engine: # Compute per-layer vector using Diffusers_IPAdapter helper from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights + # Count installed IP layers by scanning processors with _ip_layer_index - ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + ip_procs = [ + p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index") + ] num_layers = len(ip_procs) # Get base weight from IPAdapter instance - base_weight = float(getattr(self.stream.ipadapter, 'scale', 1.0)) + base_weight = float(getattr(self.stream.ipadapter, "scale", 1.0)) weights = build_layer_weights(num_layers, base_weight, weight_type) # If None, keep uniform base scale; else set per-layer vector if weights is not None: @@ -1384,7 +1374,7 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: else: self.stream.ipadapter.set_scale(base_weight) # Keep our tracking attribute in sync - setattr(self.stream.ipadapter, 'scale', base_weight) + setattr(self.stream.ipadapter, "scale", base_weight) except Exception: # Do not add fallback mechanisms raise @@ -1392,191 +1382,207 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: def _get_ipadapter_pipeline(self): """ Get the IPAdapter pipeline from the pipeline structure (following ControlNet pattern). - + Returns: IPAdapter pipeline object or None if not found """ # Check if stream is IPAdapter pipeline directly - if hasattr(self.stream, 'ipadapter'): + if hasattr(self.stream, "ipadapter"): return self.stream - + # Check if stream has nested stream (ControlNet wrapper) - if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'ipadapter'): + if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "ipadapter"): return self.stream.stream - + # Check if we have a wrapper reference and can access through it - if self.wrapper and hasattr(self.wrapper, 'stream'): - if hasattr(self.wrapper.stream, 'ipadapter'): + if self.wrapper and hasattr(self.wrapper, "stream"): + if hasattr(self.wrapper.stream, "ipadapter"): return self.wrapper.stream - elif hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'ipadapter'): + elif hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "ipadapter"): return self.wrapper.stream.stream - + return None def _get_current_ipadapter_config(self) -> Optional[Dict[str, Any]]: """ Get current IPAdapter configuration by introspecting the IPAdapter instance. 
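Example return value (values illustrative; "style_image_key", "num_image_tokens", and "type" only appear when an _ipadapter_module with a static config is installed):

    {
        "scale": 0.25,
        "weight_type": None,
        "enabled": True,
        "style_image_key": "ipadapter_main",  # hypothetical key
        "num_image_tokens": 4,
        "type": "regular",
        "has_style_image": True,
    }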
- + Returns: Current IPAdapter configuration dict or None if no IPAdapter """ # Get config from IPAdapter instance - if hasattr(self.stream, 'ipadapter') and self.stream.ipadapter is not None: + if hasattr(self.stream, "ipadapter") and self.stream.ipadapter is not None: ipadapter = self.stream.ipadapter - + config = { - 'scale': getattr(ipadapter, 'scale', 1.0), - 'weight_type': getattr(ipadapter, 'weight_type', None), - 'enabled': getattr(ipadapter, 'enabled', True), # Check actual enabled state + "scale": getattr(ipadapter, "scale", 1.0), + "weight_type": getattr(ipadapter, "weight_type", None), + "enabled": getattr(ipadapter, "enabled", True), # Check actual enabled state } - + # Add static initialization fields - if hasattr(self.stream, '_ipadapter_module'): + if hasattr(self.stream, "_ipadapter_module"): module_config = self.stream._ipadapter_module.config - config.update({ - 'style_image_key': module_config.style_image_key, - 'num_image_tokens': module_config.num_image_tokens, - 'type': module_config.type.value, - }) - + config.update( + { + "style_image_key": module_config.style_image_key, + "num_image_tokens": module_config.num_image_tokens, + "type": module_config.type.value, + } + ) + # Check if style image is set ipadapter_pipeline = self._get_ipadapter_pipeline() - if ipadapter_pipeline and hasattr(ipadapter_pipeline, 'style_image') and ipadapter_pipeline.style_image: - config['has_style_image'] = True + if ipadapter_pipeline and hasattr(ipadapter_pipeline, "style_image") and ipadapter_pipeline.style_image: + config["has_style_image"] = True else: - config['has_style_image'] = False - + config["has_style_image"] = False + return config - + # No IPAdapter instance found return None def _get_current_hook_config(self, hook_type: str) -> List[Dict[str, Any]]: """ Get current hook configuration by introspecting the hook module state. - + Args: hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.) - + Returns: List of processor configurations or empty list if no module """ # Get the hook module module_attr_name = f"_{hook_type}_module" hook_module = getattr(self.stream, module_attr_name, None) - + if not hook_module: return [] - + # Get processors from the module - processors = getattr(hook_module, 'processors', []) - + processors = getattr(hook_module, "processors", []) + config = [] for i, processor in enumerate(processors): proc_config = { - 'type': getattr(processor, '__class__').__name__, - 'order': getattr(processor, 'order', i), - 'enabled': getattr(processor, 'enabled', True), + "type": getattr(processor, "__class__").__name__, + "order": getattr(processor, "order", i), + "enabled": getattr(processor, "enabled", True), } - + # Try to get processor parameters - if hasattr(processor, 'params'): - proc_config['params'] = dict(processor.params) - + if hasattr(processor, "params"): + proc_config["params"] = dict(processor.params) + config.append(proc_config) - + return config def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any]]) -> None: """ Update hook configuration by modifying existing processors in-place instead of recreating them. - + Args: hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.) 
desired_config: List of processor configurations """ logger.info(f"_update_hook_config: Updating {hook_type} with {len(desired_config)} processors") - + # Get or create the hook module module_attr_name = f"_{hook_type}_module" hook_module = getattr(self.stream, module_attr_name, None) - + if not hook_module: logger.info(f"_update_hook_config: No existing {hook_type} module, creating new one") # Create the appropriate hook module try: if hook_type in ["image_preprocessing", "image_postprocessing"]: - from streamdiffusion.modules.image_processing_module import ImagePreprocessingModule, ImagePostprocessingModule + from streamdiffusion.modules.image_processing_module import ( + ImagePostprocessingModule, + ImagePreprocessingModule, + ) + if hook_type == "image_preprocessing": hook_module = ImagePreprocessingModule() else: hook_module = ImagePostprocessingModule() elif hook_type in ["latent_preprocessing", "latent_postprocessing"]: - from streamdiffusion.modules.latent_processing_module import LatentPreprocessingModule, LatentPostprocessingModule + from streamdiffusion.modules.latent_processing_module import ( + LatentPostprocessingModule, + LatentPreprocessingModule, + ) + if hook_type == "latent_preprocessing": hook_module = LatentPreprocessingModule() else: hook_module = LatentPostprocessingModule() else: raise ValueError(f"Unknown hook type: {hook_type}") - + # Install the module hook_module.install(self.stream) setattr(self.stream, module_attr_name, hook_module) logger.info(f"_update_hook_config: Created and installed {hook_type} module") - + except Exception as e: logger.error(f"_update_hook_config: Failed to create {hook_type} module: {e}") return - - logger.info(f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors") - + + logger.info( + f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors" + ) + # Modify existing processors in-place instead of clearing and recreating for i, proc_config in enumerate(desired_config): - processor_type = proc_config.get('type', 'unknown') - enabled = proc_config.get('enabled', True) - params = proc_config.get('params', {}) - + processor_type = proc_config.get("type", "unknown") + enabled = proc_config.get("enabled", True) + params = proc_config.get("params", {}) + logger.info(f"_update_hook_config: Processing config {i}: type={processor_type}, enabled={enabled}") - + if i < len(hook_module.processors): # Modify existing processor existing_processor = hook_module.processors[i] - + # Get the current processor type from registry name if available, otherwise use class name - current_type = existing_processor.params.get('_registry_name') if hasattr(existing_processor, 'params') else None + current_type = ( + existing_processor.params.get("_registry_name") if hasattr(existing_processor, "params") else None + ) if not current_type: current_type = existing_processor.__class__.__name__ - - logger.info(f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}") - + + logger.info( + f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}" + ) + # If processor type changed, replace it if current_type.lower() != processor_type.lower(): logger.info(f"_update_hook_config: Type changed, replacing processor {i}") try: from streamdiffusion.preprocessing.processors import get_preprocessor - + # Determine normalization context from hook type - if 'latent' in hook_type: - normalization_context = 'latent' + if 
"latent" in hook_type: + normalization_context = "latent" else: # Image preprocessing/postprocessing uses 'pipeline' context - normalization_context = 'pipeline' - + normalization_context = "pipeline" + new_processor = get_preprocessor( - processor_type, - pipeline_ref=getattr(self, 'stream', None), - normalization_context=normalization_context + processor_type, + pipeline_ref=getattr(self, "stream", None), + normalization_context=normalization_context, ) - + # Copy attributes from old processor - setattr(new_processor, 'order', getattr(existing_processor, 'order', i)) - setattr(new_processor, 'enabled', enabled) - + setattr(new_processor, "order", getattr(existing_processor, "order", i)) + setattr(new_processor, "enabled", enabled) + # Set parameters - if hasattr(new_processor, 'params'): + if hasattr(new_processor, "params"): new_processor.params.update(params) - + hook_module.processors[i] = new_processor logger.info(f"_update_hook_config: Successfully replaced processor {i} with {processor_type}") except Exception as e: @@ -1584,15 +1590,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any else: # Same type, just update attributes logger.info(f"_update_hook_config: Same type, updating attributes for processor {i}") - setattr(existing_processor, 'enabled', enabled) - + setattr(existing_processor, "enabled", enabled) + # Update parameters - if hasattr(existing_processor, 'params'): + if hasattr(existing_processor, "params"): existing_processor.params.update(params) for param_name, param_value in params.items(): if hasattr(existing_processor, param_name): setattr(existing_processor, param_name, param_value) - + logger.info(f"_update_hook_config: Updated processor {i} enabled={enabled}, params={params}") else: # Add new processor @@ -1602,12 +1608,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any logger.info(f"_update_hook_config: Successfully added processor {i}: {processor_type}") except Exception as e: logger.error(f"_update_hook_config: Failed to add processor {i}: {e}") - + # Remove extra processors if config is shorter while len(hook_module.processors) > len(desired_config): removed_idx = len(hook_module.processors) - 1 removed_processor = hook_module.processors.pop() - logger.info(f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}") - - logger.info(f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors") + logger.info( + f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}" + ) + logger.info( + f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors" + ) diff --git a/src/streamdiffusion/tools/compile_raft_tensorrt.py b/src/streamdiffusion/tools/compile_raft_tensorrt.py index 8734987e..d3faef95 100644 --- a/src/streamdiffusion/tools/compile_raft_tensorrt.py +++ b/src/streamdiffusion/tools/compile_raft_tensorrt.py @@ -1,21 +1,24 @@ -import torch import logging from pathlib import Path -from typing import Optional + import fire +import torch + -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) try: import tensorrt as trt + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False logger.error("TensorRT not available. 
Please install it first.") try: - from torchvision.models.optical_flow import raft_small, Raft_Small_Weights + from torchvision.models.optical_flow import Raft_Small_Weights, raft_small + TORCHVISION_AVAILABLE = True except ImportError: TORCHVISION_AVAILABLE = False @@ -28,11 +31,11 @@ def export_raft_to_onnx( min_width: int = 512, max_height: int = 512, max_width: int = 512, - device: str = "cuda" + device: str = "cuda", ) -> bool: """ Export RAFT model to ONNX format - + Args: onnx_path: Path to save the ONNX model min_height: Minimum input height for the model @@ -40,41 +43,41 @@ def export_raft_to_onnx( max_height: Maximum input height for the model max_width: Maximum input width for the model device: Device to use for export - + Returns: True if successful, False otherwise """ if not TORCHVISION_AVAILABLE: logger.error("torchvision is required but not installed") return False - + logger.info(f"Exporting RAFT model to ONNX: {onnx_path}") logger.info(f"Resolution range: {min_height}x{min_width} - {max_height}x{max_width}") - + try: # Load RAFT model logger.info("Loading RAFT Small model...") raft_model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=True) raft_model = raft_model.to(device=device) raft_model.eval() - + # Create dummy inputs using max resolution for export dummy_frame1 = torch.randn(1, 3, max_height, max_width).to(device) dummy_frame2 = torch.randn(1, 3, max_height, max_width).to(device) - + # Apply RAFT preprocessing if available weights = Raft_Small_Weights.DEFAULT - if hasattr(weights, 'transforms') and weights.transforms is not None: + if hasattr(weights, "transforms") and weights.transforms is not None: transforms = weights.transforms() dummy_frame1, dummy_frame2 = transforms(dummy_frame1, dummy_frame2) - + # Make batch, height, and width dimensions dynamic dynamic_axes = { "frame1": {0: "batch_size", 2: "height", 3: "width"}, "frame2": {0: "batch_size", 2: "height", 3: "width"}, "flow": {0: "batch_size", 2: "height", 3: "width"}, } - + logger.info("Exporting to ONNX...") with torch.no_grad(): torch.onnx.export( @@ -82,22 +85,23 @@ def export_raft_to_onnx( (dummy_frame1, dummy_frame2), str(onnx_path), verbose=False, - input_names=['frame1', 'frame2'], - output_names=['flow'], + input_names=["frame1", "frame2"], + output_names=["flow"], opset_version=17, export_params=True, dynamic_axes=dynamic_axes, ) - + del raft_model torch.cuda.empty_cache() - + logger.info(f"Successfully exported ONNX model to {onnx_path}") return True - + except Exception as e: logger.error(f"Failed to export ONNX model: {e}") import traceback + traceback.print_exc() return False @@ -110,11 +114,11 @@ def build_tensorrt_engine( max_height: int = 512, max_width: int = 512, fp16: bool = True, - workspace_size_gb: int = 4 + workspace_size_gb: int = 4, ) -> bool: """ Build TensorRT engine from ONNX model - + Args: onnx_path: Path to the ONNX model engine_path: Path to save the TensorRT engine @@ -124,74 +128,74 @@ def build_tensorrt_engine( max_width: Maximum input width for optimization fp16: Enable FP16 precision mode workspace_size_gb: Maximum workspace size in GB - + Returns: True if successful, False otherwise """ if not TENSORRT_AVAILABLE: logger.error("TensorRT is required but not installed") return False - + if not onnx_path.exists(): logger.error(f"ONNX model not found: {onnx_path}") return False - + logger.info(f"Building TensorRT engine from ONNX model: {onnx_path}") logger.info(f"Output path: {engine_path}") logger.info(f"Resolution range: {min_height}x{min_width} - 
{max_height}x{max_width}") logger.info(f"FP16 mode: {fp16}") logger.info("This may take several minutes...") - + try: builder = trt.Builder(trt.Logger(trt.Logger.INFO)) network = builder.create_network() # EXPLICIT_BATCH deprecated/ignored in TRT 10.x parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING)) - + logger.info("Parsing ONNX model...") - with open(onnx_path, 'rb') as model: + with open(onnx_path, "rb") as model: if not parser.parse(model.read()): logger.error("Failed to parse ONNX model") for error in range(parser.num_errors): logger.error(f"Parser error: {parser.get_error(error)}") return False - + logger.info("Configuring TensorRT builder...") config = builder.create_builder_config() - + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size_gb * (1 << 30)) - + if fp16: config.set_flag(trt.BuilderFlag.FP16) logger.info("FP16 mode enabled") - + # Calculate optimal resolution (middle point) opt_height = (min_height + max_height) // 2 opt_width = (min_width + max_width) // 2 - + profile = builder.create_optimization_profile() min_shape = (1, 3, min_height, min_width) opt_shape = (1, 3, opt_height, opt_width) max_shape = (1, 3, max_height, max_width) - + profile.set_shape("frame1", min_shape, opt_shape, max_shape) profile.set_shape("frame2", min_shape, opt_shape, max_shape) config.add_optimization_profile(profile) - + logger.info("Building TensorRT engine... (this will take a while)") engine = builder.build_serialized_network(network, config) - + if engine is None: logger.error("Failed to build TensorRT engine") return False - + logger.info(f"Saving engine to {engine_path}") engine_path.parent.mkdir(parents=True, exist_ok=True) - with open(engine_path, 'wb') as f: + with open(engine_path, "wb") as f: f.write(engine) - + logger.info(f"Successfully built and saved TensorRT engine: {engine_path}") - logger.info(f"Engine size: {engine_path.stat().st_size / (1024*1024):.2f} MB") - + logger.info(f"Engine size: {engine_path.stat().st_size / (1024 * 1024):.2f} MB") + # Delete ONNX file after successful engine creation try: if onnx_path.exists(): @@ -199,12 +203,13 @@ def build_tensorrt_engine( logger.info(f"Deleted ONNX file: {onnx_path}") except Exception as e: logger.warning(f"Failed to delete ONNX file: {e}") - + return True - + except Exception as e: logger.error(f"Failed to build TensorRT engine: {e}") import traceback + traceback.print_exc() return False @@ -216,11 +221,11 @@ def compile_raft( device: str = "cuda", fp16: bool = True, workspace_size_gb: int = 4, - force_rebuild: bool = False + force_rebuild: bool = False, ): """ Main function to compile RAFT model to TensorRT engine - + Args: min_resolution: Minimum input resolution as "HxW" (e.g., "512x512") (default: "512x512") max_resolution: Maximum input resolution as "HxW" (e.g., "1024x1024") (default: "512x512") @@ -234,46 +239,46 @@ def compile_raft( logger.error("TensorRT is not available. Please install it first using:") logger.error(" python -m streamdiffusion.tools.install-tensorrt") return - + if not TORCHVISION_AVAILABLE: logger.error("torchvision is not available. Please install it first using:") logger.error(" pip install torchvision") return - + # Parse resolution strings try: - min_height, min_width = map(int, min_resolution.split('x')) + min_height, min_width = map(int, min_resolution.split("x")) except: logger.error(f"Invalid min_resolution format: {min_resolution}. 
Expected format: HxW (e.g., 512x512)") return - + try: - max_height, max_width = map(int, max_resolution.split('x')) + max_height, max_width = map(int, max_resolution.split("x")) except: logger.error(f"Invalid max_resolution format: {max_resolution}. Expected format: HxW (e.g., 1024x1024)") return - + output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) - + # Add resolution suffix to filenames onnx_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.onnx" engine_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.engine" - - logger.info("="*80) + + logger.info("=" * 80) logger.info("RAFT TensorRT Compilation") - logger.info("="*80) + logger.info("=" * 80) logger.info(f"Output directory: {output_path.absolute()}") logger.info(f"Resolution range: {min_resolution} - {max_resolution}") logger.info(f"ONNX path: {onnx_path}") logger.info(f"Engine path: {engine_path}") - logger.info("="*80) - + logger.info("=" * 80) + if engine_path.exists() and not force_rebuild: logger.info(f"TensorRT engine already exists: {engine_path}") logger.info("Use --force_rebuild to rebuild it") return - + if not onnx_path.exists() or force_rebuild: logger.info("\n[Step 1/2] Exporting RAFT to ONNX...") if not export_raft_to_onnx(onnx_path, min_height, min_width, max_height, max_width, device): @@ -281,21 +286,22 @@ def compile_raft( return else: logger.info(f"\n[Step 1/2] ONNX model already exists: {onnx_path}") - + logger.info("\n[Step 2/2] Building TensorRT engine...") - if not build_tensorrt_engine(onnx_path, engine_path, min_height, min_width, max_height, max_width, fp16, workspace_size_gb): + if not build_tensorrt_engine( + onnx_path, engine_path, min_height, min_width, max_height, max_width, fp16, workspace_size_gb + ): logger.error("Failed to build TensorRT engine") return - - logger.info("\n" + "="*80) + + logger.info("\n" + "=" * 80) logger.info("✓ Compilation completed successfully!") - logger.info("="*80) + logger.info("=" * 80) logger.info(f"Engine path: {engine_path.absolute()}") logger.info("\nYou can now use this engine in TemporalNetTensorRTPreprocessor:") logger.info(f' engine_path="{engine_path.absolute()}"') - logger.info("="*80) + logger.info("=" * 80) if __name__ == "__main__": fire.Fire(compile_raft) - diff --git a/src/streamdiffusion/tools/cuda_l2_cache.py b/src/streamdiffusion/tools/cuda_l2_cache.py index cdafcaa9..176a158d 100644 --- a/src/streamdiffusion/tools/cuda_l2_cache.py +++ b/src/streamdiffusion/tools/cuda_l2_cache.py @@ -10,6 +10,7 @@ Environment variables: SDTD_L2_PERSIST=1 Enable L2 persistence (default: 1) SDTD_L2_PERSIST_MB=64 MB of L2 to reserve for persistent data (default: 64) + SDTD_L2_PERSIST_TIER2=0 Enable per-tensor access policy window (default: 0, nn.Module UNet only) SDTD_L2_PERSIST_LAYERS= Comma-separated layer names for access policy (default: auto) Expected impact: 5-16% on memory-bandwidth-bound layers (normalization, small GEMMs). @@ -30,6 +31,9 @@ L2_PERSIST_ENABLED = os.environ.get("SDTD_L2_PERSIST", "1") == "1" L2_PERSIST_MB = int(os.environ.get("SDTD_L2_PERSIST_MB", "64")) +# Tier 2 (per-tensor access policy window) is opt-in: only works for PyTorch nn.Module +# UNets (not TRT engines), and only the single largest tensor window is active per stream. +L2_PERSIST_TIER2 = os.environ.get("SDTD_L2_PERSIST_TIER2", "0") == "1" # Hot layer prefixes — these contain the most attention + FF hook computation. 
# mid_block: 1 transformer block, seq_len=1024, 16 FF hooks @@ -112,17 +116,11 @@ def _get_cudart() -> Optional[ctypes.CDLL]: # Option 1: PyTorch ships cudart in torch/lib/ torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib") - candidates = sorted( - glob.glob(os.path.join(torch_lib, "cudart64_*.dll")), reverse=True - ) + candidates = sorted(glob.glob(os.path.join(torch_lib, "cudart64_*.dll")), reverse=True) # Option 2: CUDA toolkit installation - cuda_path = os.environ.get( - "CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" - ) - candidates += sorted( - glob.glob(os.path.join(cuda_path, "bin", "cudart64_*.dll")), reverse=True - ) + cuda_path = os.environ.get("CUDA_PATH", r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8") + candidates += sorted(glob.glob(os.path.join(cuda_path, "bin", "cudart64_*.dll")), reverse=True) for dll_path in candidates: try: @@ -162,9 +160,7 @@ def reserve_l2_persisting_cache(persist_mb: int = L2_PERSIST_MB) -> bool: props = torch.cuda.get_device_properties(device) major, minor = props.major, props.minor if major < 8: - print( - f"[L2] L2 persistence skipped — compute {major}.{minor} < 8.0 (Ampere required)" - ) + print(f"[L2] L2 persistence skipped — compute {major}.{minor} < 8.0 (Ampere required)") return False l2_total_mb = props.L2_cache_size // (1024 * 1024) @@ -296,22 +292,28 @@ def pin_hot_unet_weights( persist_mb: int = L2_PERSIST_MB, ) -> int: """ - Mark hot UNet layer weights as L2-persistent. + Mark the single largest hot UNet attention weight as L2-persistent. - Identifies attention Q/K/V/out projection weights in the hottest layers - (mid_block, up_blocks.1) and requests they persist in L2 cache. + CUDA allows only one cudaAccessPolicyWindow per stream at a time — registering + multiple tensors silently replaces the previous window. This function correctly + picks the single largest hot attention weight (by byte size) and registers exactly + one window for it. Args: - unet: The UNet model (already on CUDA). + unet: The UNet model (already on CUDA, must be torch.nn.Module). hot_prefixes: Layer name prefixes to target. Defaults to mid_block + up_blocks.1. persist_mb: MB of L2 to reserve (passed to reserve_l2_persisting_cache). Returns: - Number of weight tensors successfully pinned. + 1 if a tensor was pinned, 0 otherwise. """ if not L2_PERSIST_ENABLED: return 0 + if not isinstance(unet, torch.nn.Module): + print("[L2] Tier 2 skipped — model is not nn.Module (e.g. TRT engine). Use Tier 1 only.") + return 0 + if hot_prefixes is None: hot_prefixes = _DEFAULT_HOT_LAYER_PREFIXES @@ -321,12 +323,14 @@ def pin_hot_unet_weights( if not tier1_ok: return 0 - # Tier 2: Set access policy on hot attention weights - # Target: to_q, to_k, to_v, to_out weights in hot transformer blocks. - # These are small-to-medium GEMMs that benefit most from L2 hits. + # Tier 2: Find the single largest hot attention weight. + # CUDA allows only one cudaAccessPolicyWindow per stream — registering N tensors + # results in only the Nth window being active (each call replaces the previous). + # Pinning the largest tensor maximises L2 utilization for the one permitted window. 
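# [Editor's note — usage sketch, not part of the patch. setup_l2_persistence() and
# the SDTD_* flags are the ones defined in this file; the model and call site are
# illustrative. The flags are read at module import time, so set them before
# importing cuda_l2_cache.]
#
#     import os
#     os.environ["SDTD_L2_PERSIST"] = "1"        # Tier 1: reserve L2 set-aside
#     os.environ["SDTD_L2_PERSIST_MB"] = "64"
#     os.environ["SDTD_L2_PERSIST_TIER2"] = "1"  # Tier 2 opt-in: nn.Module UNet only
#
#     from streamdiffusion.tools.cuda_l2_cache import setup_l2_persistence
#     unet = pipe.unet.to("cuda")                # torch.nn.Module, not a TRT engine
#     setup_l2_persistence(unet)                 # after load, before torch.compile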
_hot_weight_keywords = ["to_q", "to_k", "to_v", "to_out"] - pinned_count = 0 - pinned_bytes = 0 + best_tensor = None + best_bytes = 0 + candidate_count = 0 for name, param in unet.named_parameters(): if not param.is_cuda: @@ -334,21 +338,21 @@ def pin_hot_unet_weights( is_hot = any(prefix in name for prefix in hot_prefixes) is_attn_weight = any(kw in name for kw in _hot_weight_keywords) if is_hot and is_attn_weight: - if set_tensor_persisting(param.data): - pinned_count += 1 - pinned_bytes += param.data.nbytes + candidate_count += 1 + if param.data.nbytes > best_bytes: + best_bytes = param.data.nbytes + best_tensor = param.data - if pinned_count > 0: + if best_tensor is not None and set_tensor_persisting(best_tensor): print( - f"[L2] Pinned {pinned_count} attention weight tensors " - f"({pinned_bytes / 1024 / 1024:.1f}MB) in L2 persisting cache" - ) - else: - print( - "[L2] No tensors pinned (params may require_grad=True before compile — call after freeze)" + f"[L2] Pinned 1 of {candidate_count} hot tensors (largest, " + f"{best_bytes / 1024 / 1024:.1f}MB) — single-window CUDA limit applies" ) + return 1 - return pinned_count + if candidate_count == 0: + print("[L2] No tensors pinned (params may require_grad=True before compile — call after freeze)") + return 0 def setup_l2_persistence(unet: torch.nn.Module) -> bool: @@ -358,8 +362,15 @@ def setup_l2_persistence(unet: torch.nn.Module) -> bool: Call this AFTER model is loaded and BEFORE torch.compile. For best results with frozen weights, call AFTER torch.compile with freezing=True. + Tier 1 (L2 set-aside via cudaDeviceSetLimit) is always attempted — this reserves + a portion of L2 for hot data and is correct for all GPU modes including TRT engines. + + Tier 2 (per-tensor access policy window) is opt-in via SDTD_L2_PERSIST_TIER2=1. + It only works for PyTorch nn.Module UNets (not TRT engines), and CUDA allows only + one window per stream — this function registers only the single largest hot tensor. + Args: - unet: The UNet model on CUDA. + unet: The UNet model on CUDA (nn.Module for Tier-2 to apply; TRT Engine for Tier-1 only). Returns: True if at least Tier 1 (L2 reservation) succeeded. @@ -367,21 +378,18 @@ def setup_l2_persistence(unet: torch.nn.Module) -> bool: if not L2_PERSIST_ENABLED: return False - print( - f"\n[L2] Setting up L2 cache persistence " - f"(SDTD_L2_PERSIST_MB={L2_PERSIST_MB})..." - ) + print(f"\n[L2] Setting up L2 cache persistence (SDTD_L2_PERSIST_MB={L2_PERSIST_MB})...") - # Tier 1 is the reliable baseline — always attempt + # Tier 1: Reserve L2 persisting region — works for all GPU modes, always attempt. tier1_ok = reserve_l2_persisting_cache(L2_PERSIST_MB) if tier1_ok: - # Tier 2: per-tensor access policy (best-effort) - pinned = pin_hot_unet_weights(unet, persist_mb=0) # Tier 1 already reserved - if pinned == 0: + if L2_PERSIST_TIER2: + # Tier 2: per-tensor access policy window — opt-in, nn.Module only. 
+ pin_hot_unet_weights(unet, persist_mb=0) # Tier 1 already reserved above + else: print( - "[L2] Tier 2 access policy skipped (call pin_hot_unet_weights() " - "after compile+freeze for per-tensor control)" + "[L2] Tier 2 access policy disabled (set SDTD_L2_PERSIST_TIER2=1 to enable; nn.Module UNet required)" ) return tier1_ok diff --git a/src/streamdiffusion/tools/gpu_profiler.py b/src/streamdiffusion/tools/gpu_profiler.py new file mode 100644 index 00000000..3f4efbbd --- /dev/null +++ b/src/streamdiffusion/tools/gpu_profiler.py @@ -0,0 +1,630 @@ +""" +gpu_profiler.py — Portable GPU profiling module for CUDA/PyTorch projects. + +PORTABILITY: Copy this single file into any project. Only stdlib is required +when profiling is disabled. PyTorch is imported lazily when enabled. + +USAGE: + from streamdiffusion.tools.gpu_profiler import profiler, configure + + configure(enabled=True, nvtx=True, events=True) + + with profiler.region("inference"): + output = model(input) + + profiler.report() + +CUDA GRAPH COMPATIBILITY: + NVTX push/pop calls break CUDA graph replay — any push/pop recorded during + graph capture fires only once (at capture time), not on each replay step. + When CUDA graphs are active set nvtx=False (or GPU_PROFILER_NVTX=0) so + only CUDA-event timing is collected — events are always graph-safe. + + For StreamDiffusion's TRT engine path: set GPU_PROFILER_NVTX=0 to use + events-only mode during graph-replayed inference, and GPU_PROFILER_NVTX=1 + for a non-graph capture run (STREAMDIFFUSION_PROFILE_TRT=1 disables graphs + when you need per-layer IProfiler timing instead). + +NSIGHT SYSTEMS COMMAND: + nsys profile --trace=cuda,nvtx,cublas,cudnn --cuda-memory-usage=true \\ + -o profiles/sdtd_out --force-overwrite true \\ + .venv/Scripts/python scripts/profiling/profile_nsys.py --target benchmark +""" + +from __future__ import annotations + +import json +import os +import pickle +import time +from contextlib import contextmanager +from functools import wraps +from typing import Any, Callable, Dict, Generator, List, Optional + + +# ───────────────────────────────────────────────────────────────────────────── +# RegionStats — per-region histogram with percentile support +# ───────────────────────────────────────────────────────────────────────────── + + +class RegionStats: + """Histogram-based timing statistics for a named profiling region.""" + + __slots__ = ("name", "samples", "count", "total_ms") + + MAX_SAMPLES = 10_000 # cap to avoid unbounded memory + + def __init__(self, name: str) -> None: + self.name = name + self.samples: List[float] = [] + self.count: int = 0 + self.total_ms: float = 0.0 + + def record(self, ms: float) -> None: + self.count += 1 + self.total_ms += ms + if len(self.samples) < self.MAX_SAMPLES: + self.samples.append(ms) + + @property + def mean(self) -> float: + return self.total_ms / self.count if self.count else 0.0 + + @property + def p50(self) -> float: + return self._percentile(50) + + @property + def p95(self) -> float: + return self._percentile(95) + + @property + def p99(self) -> float: + return self._percentile(99) + + @property + def min(self) -> float: + return min(self.samples) if self.samples else 0.0 + + @property + def max(self) -> float: + return max(self.samples) if self.samples else 0.0 + + def _percentile(self, p: int) -> float: + if not self.samples: + return 0.0 + s = sorted(self.samples) + idx = int(len(s) * p / 100) + return s[min(idx, len(s) - 1)] + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "count": 
self.count, + "mean_ms": round(self.mean, 3), + "p50_ms": round(self.p50, 3), + "p95_ms": round(self.p95, 3), + "p99_ms": round(self.p99, 3), + "min_ms": round(self.min, 3), + "max_ms": round(self.max, 3), + "total_ms": round(self.total_ms, 3), + } + + +# ───────────────────────────────────────────────────────────────────────────── +# _RegionCtx — context manager for a single profiled region +# ───────────────────────────────────────────────────────────────────────────── + + +class _RegionCtx: + """Context manager that records one entry for a named region. + + On entry: optional NVTX range_push + CUDA event start record. + On exit: optional NVTX range_pop + CUDA event elapsed_time -> RegionStats. + """ + + __slots__ = ("_profiler", "_name", "_nvtx", "_start_evt", "_end_evt") + + def __init__(self, profiler: "GPUProfiler", name: str) -> None: + self._profiler = profiler + self._name = name + self._nvtx = profiler._nvtx_enabled + self._start_evt = None + self._end_evt = None + + def __enter__(self) -> "_RegionCtx": + p = self._profiler + if self._nvtx: + p._torch.cuda.nvtx.range_push(self._name) + if p._events_enabled: + self._start_evt = p._torch.cuda.Event(enable_timing=True) + self._end_evt = p._torch.cuda.Event(enable_timing=True) + self._start_evt.record() + return self + + def __exit__(self, *_: object) -> None: + p = self._profiler + if p._events_enabled and self._start_evt is not None: + self._end_evt.record() + # Synchronize lazily — elapsed_time blocks only when read. + # We defer the sync to avoid stalling the GPU here. + p._pending.append((self._name, self._start_evt, self._end_evt)) + if self._nvtx: + p._torch.cuda.nvtx.range_pop() + + +class _NullCtx: + """Zero-overhead context manager used when profiler is disabled.""" + + __slots__ = () + + def __enter__(self) -> "_NullCtx": + return self + + def __exit__(self, *_: object) -> None: + pass + + +# ───────────────────────────────────────────────────────────────────────────── +# GPUProfiler — the real profiler (only instantiated when enabled) +# ───────────────────────────────────────────────────────────────────────────── + + +class GPUProfiler: + """ + Unified GPU profiling singleton. + + Activate via module-level ``configure(enabled=True, ...)``. + All methods are safe to call from any thread/process. + """ + + def __init__(self) -> None: + self._nvtx_enabled: bool = False + self._events_enabled: bool = False + self._memory_enabled: bool = False + self._trace_path: Optional[str] = None + + self._regions: Dict[str, RegionStats] = {} + self._pending: List[tuple] = [] # (name, start_evt, end_evt) awaiting sync + + self._torch_profiler = None # active torch.profiler.profile instance + self._profiler_step: int = 0 + + self._torch = None # lazy torch reference + self._cudart = None # lazy cudart reference + + def configure( + self, + enabled: bool = True, + nvtx: bool = True, + events: bool = True, + memory: bool = False, + trace_path: Optional[str] = None, + ) -> None: + """Configure the profiler. 
Must be called once before any region().""" + import torch as _torch + + self._torch = _torch + self._nvtx_enabled = nvtx and _torch.cuda.is_available() + self._events_enabled = events and _torch.cuda.is_available() + self._memory_enabled = memory + self._trace_path = trace_path + + # ── Core API ───────────────────────────────────────────────────────────── + + def region(self, name: str) -> _RegionCtx: + """Return a context manager that profiles one execution of ``name``.""" + if name not in self._regions: + self._regions[name] = RegionStats(name) + return _RegionCtx(self, name) + + def trace(self, name: str) -> Callable: + """Decorator that wraps a function in a named profiler region. + + Usage:: + + @profiler.trace("cupy_rgba_to_rgb") + def my_kernel(src, dst): ... + """ + + def decorator(fn: Callable) -> Callable: + @wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + with self.region(name): + return fn(*args, **kwargs) + + return wrapper + + return decorator + + def mark(self, name: str) -> None: + """Place an NVTX instant marker (zero-duration annotation).""" + if self._nvtx_enabled and self._torch is not None: + self._torch.cuda.nvtx.range_push(name) + self._torch.cuda.nvtx.range_pop() + + def begin(self, name: str) -> None: + """Open a named NVTX range without a context manager (pair with end()).""" + if name not in self._regions: + self._regions[name] = RegionStats(name) + if self._nvtx_enabled and self._torch is not None: + self._torch.cuda.nvtx.range_push(name) + + def end(self, name: str) -> None: + """Close a previously opened named NVTX range.""" + if self._nvtx_enabled and self._torch is not None: + self._torch.cuda.nvtx.range_pop() + + # ── Nsight Systems gated capture ───────────────────────────────────────── + + def nsys_start(self) -> None: + """Signal Nsight Systems to begin capture (cudaProfilerStart). + + Run your script under nsys: ``nsys profile --trace=cuda,nvtx ...`` + Capture only starts when this is called — useful to skip warmup. + """ + if self._torch is not None and self._torch.cuda.is_available(): + try: + self._torch.cuda.cudart().cudaProfilerStart() + except Exception: # broad: CUDA profiler API may not be available (no nsys, profiling disabled) + pass + + def nsys_stop(self) -> None: + """Signal Nsight Systems to stop capture (cudaProfilerStop).""" + if self._torch is not None and self._torch.cuda.is_available(): + try: + self._torch.cuda.cudart().cudaProfilerStop() + except Exception: # broad: CUDA profiler API may not be available (no nsys, profiling disabled) + pass + + # ── torch.profiler integration ──────────────────────────────────────────── + + @contextmanager + def torch_trace( + self, + path: Optional[str] = None, + warmup: int = 1, + active: int = 5, + ) -> Generator[None, None, None]: + """Context manager wrapping torch.profiler.profile. + + Schedule: wait=0, warmup=``warmup``, active=``active``. + Exports Chrome trace to ``path`` (or self._trace_path if not specified). + Also prints top-30 ops by CUDA time to stdout. 
+ + Usage:: + + with profiler.torch_trace("trace.json", warmup=1, active=5): + for i in range(warmup + active): + profiler.step() + run_inference() + """ + out_path = path or self._trace_path or "gpu_profile_trace.json" + if self._torch is None: + yield + return + + torch_profiler_mod = self._torch.profiler + + def _on_trace_ready(prof: Any) -> None: + prof.export_chrome_trace(out_path) + print(f"\n[gpu_profiler] Chrome trace -> {out_path}") + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) + + with torch_profiler_mod.profile( + activities=[ + torch_profiler_mod.ProfilerActivity.CPU, + torch_profiler_mod.ProfilerActivity.CUDA, + ], + schedule=torch_profiler_mod.schedule(wait=0, warmup=warmup, active=active), + on_trace_ready=_on_trace_ready, + record_shapes=True, + with_stack=True, + ) as prof: + self._torch_profiler = prof + try: + yield + finally: + self._torch_profiler = None + + def step(self) -> None: + """Advance the torch.profiler schedule by one step. + + Call once per iteration inside a ``torch_trace`` context. + """ + if self._torch_profiler is not None: + self._torch_profiler.step() + + # ── Memory profiling ────────────────────────────────────────────────────── + + @contextmanager + def memory_trace(self, path: str = "mem_snapshot.pkl") -> Generator[None, None, None]: + """Context manager that captures a VRAM allocation snapshot. + + The resulting ``.pkl`` file can be converted to interactive HTML via:: + + python -c " + import pickle, torch + with open('mem_snapshot.pkl','rb') as f: + snap = pickle.load(f) + html = torch.cuda._memory_viz.trace_plot(snap) + open('memory.html','w').write(html) + " + """ + if self._torch is None or not self._torch.cuda.is_available(): + yield + return + + self._torch.cuda.synchronize() + self._torch.cuda.memory._record_memory_history( + True, + trace_alloc_max_entries=100_000, + trace_alloc_record_context=True, + ) + try: + yield + finally: + self._torch.cuda.synchronize() + snapshot = self._torch.cuda.memory._snapshot() + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + with open(path, "wb") as fh: + pickle.dump(snapshot, fh) + print(f"[gpu_profiler] Memory snapshot -> {path}") + try: + self._torch.cuda.memory._record_memory_history(False) + except Exception: # broad: may fail if history was never started, or on older PyTorch without the API + pass + + # ── Statistics ──────────────────────────────────────────────────────────── + + def flush(self) -> None: + """Resolve all pending CUDA event timings (forces GPU sync). + + Called automatically by report() and export_stats(). You can also + call it manually after a batch of iterations to get up-to-date stats + without printing. + """ + if not self._pending: + return + if self._torch is not None: + self._torch.cuda.synchronize() + for name, start_evt, end_evt in self._pending: + try: + ms = start_evt.elapsed_time(end_evt) + self._regions[name].record(ms) + except Exception: # broad: CUDA event timing fails if event was never recorded (e.g., boundary skipped) + pass + self._pending.clear() + + def report(self, top_n: int = 30) -> None: + """Print a summary table sorted by total CUDA time. + + Flushes pending events first. 
+ """ + self.flush() + if not self._regions: + print("[gpu_profiler] No regions recorded.") + return + + rows = sorted( + self._regions.values(), + key=lambda s: s.total_ms, + reverse=True, + )[:top_n] + + col_w = max(len(r.name) for r in rows) + 2 + header = ( + f"{'Region':<{col_w}} {'Count':>6} " + f"{'Mean':>8} {'P50':>8} {'P95':>8} {'P99':>8} " + f"{'Min':>8} {'Max':>8} {'Total':>10}" + ) + sep = "-" * len(header) + print(f"\n[gpu_profiler] Timing Report (top {top_n} by total ms)") + print(sep) + print(header) + print(sep) + for r in rows: + print( + f"{r.name:<{col_w}} {r.count:>6} " + f"{r.mean:>7.2f}ms {r.p50:>7.2f}ms " + f"{r.p95:>7.2f}ms {r.p99:>7.2f}ms " + f"{r.min:>7.2f}ms {r.max:>7.2f}ms " + f"{r.total_ms:>9.1f}ms" + ) + print(sep) + + def export_stats(self, path: str = "gpu_profile_stats.json") -> None: + """Write region statistics to a JSON file. + + Flushes pending events first. + """ + self.flush() + data = { + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), + "regions": [s.to_dict() for s in self._regions.values()], + } + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + with open(path, "w") as fh: + json.dump(data, fh, indent=2) + print(f"[gpu_profiler] Stats -> {path}") + + def reset(self) -> None: + """Clear all accumulated timing data and pending events.""" + self._regions.clear() + self._pending.clear() + + +# ───────────────────────────────────────────────────────────────────────────── +# _NullProfiler — all methods are no-ops, zero imports, zero overhead +# ───────────────────────────────────────────────────────────────────────────── + +_NULL_CTX = _NullCtx() + + +class _NullProfiler: + """Drop-in replacement for GPUProfiler when profiling is disabled. + + Every method is a no-op. region() returns a shared _NullCtx singleton + that has bare __enter__/__exit__ bodies — no attribute lookups, no CUDA + calls, no allocation. + """ + + __slots__ = () + + def region(self, name: str) -> _NullCtx: # noqa: ARG002 + return _NULL_CTX + + def trace(self, name: str) -> Callable: # noqa: ARG002 + """Return identity decorator — function is NOT wrapped.""" + + def decorator(fn: Callable) -> Callable: + return fn + + return decorator + + def mark(self, name: str) -> None: + pass # noqa: E704 + + def begin(self, name: str) -> None: + pass # noqa: E704 + + def end(self, name: str) -> None: + pass # noqa: E704 + + def nsys_start(self) -> None: + pass # noqa: E704 + + def nsys_stop(self) -> None: + pass # noqa: E704 + + def step(self) -> None: + pass # noqa: E704 + + def flush(self) -> None: + pass # noqa: E704 + + def report(self, top_n: int = 30) -> None: + pass # noqa: E704, ARG002 + + def reset(self) -> None: + pass # noqa: E704 + + @contextmanager + def torch_trace(self, path: Optional[str] = None, warmup: int = 1, active: int = 5) -> Generator[None, None, None]: # noqa: ARG002 + yield + + @contextmanager + def memory_trace(self, path: str = "mem_snapshot.pkl") -> Generator[None, None, None]: # noqa: ARG002 + yield + + def export_stats(self, path: str = "gpu_profile_stats.json") -> None: # noqa: ARG002 + pass + + def configure(self, **kwargs: Any) -> None: + pass # noqa: E704, ARG002 + + +# ───────────────────────────────────────────────────────────────────────────── +# _ProfilerProxy — stable singleton; configure() mutates the delegate in-place +# ───────────────────────────────────────────────────────────────────────────── + + +class _ProfilerProxy: + """Stable proxy that delegates every call to the active inner profiler. 
+ + Importing ``from streamdiffusion.tools.gpu_profiler import profiler`` is + safe to do once at module load time. ``configure()`` updates the inner + delegate in-place, so stale import references keep working correctly + without requiring a re-import after configure. + + Pattern: profiler._set_inner(new_instance) → all future calls forwarded + """ + + __slots__ = ("_inner",) + + def __init__(self) -> None: + object.__setattr__(self, "_inner", _NullProfiler()) + + def _set_inner(self, inner: Any) -> None: + object.__setattr__(self, "_inner", inner) + + def __getattr__(self, name: str) -> Any: + return getattr(object.__getattribute__(self, "_inner"), name) + + +# ───────────────────────────────────────────────────────────────────────────── +# Module-level singleton + configure() +# ───────────────────────────────────────────────────────────────────────────── + +# Stable proxy — safe to import once. configure() swaps the inner delegate. +profiler: Any = _ProfilerProxy() + + +def configure( + enabled: bool = False, + nvtx: bool = True, + events: bool = True, + memory: bool = False, + trace_path: Optional[str] = None, +) -> None: + """Configure the module-level profiler singleton. + + Args: + enabled: Master switch. When False, all profiler calls are no-ops. + nvtx: Emit NVTX ranges (visible in Nsight Systems timeline). + Disable when CUDA graphs are active (GPU_PROFILER_NVTX=0) + to avoid incorrect timeline positions on graph replay. + events: Collect CUDA-event timing into RegionStats histograms. + Always safe with CUDA graphs. + memory: Enable torch.cuda.memory._record_memory_history when + memory_trace() context is entered. + trace_path: Default Chrome trace output path for torch_trace(). + + Also reads environment variables (these take priority over the arguments when GPU_PROFILER=1): + GPU_PROFILER=1 → enabled=True; resets nvtx/events defaults to True + GPU_PROFILER_NVTX=0 → nvtx=False (only when GPU_PROFILER=1) + GPU_PROFILER_EVENTS=0 → events=False (only when GPU_PROFILER=1) + """ + _env_enabled = os.environ.get("GPU_PROFILER", "0") == "1" + enabled = enabled or _env_enabled + if _env_enabled: + # Env activation: use env vars as the sole source for nvtx/events so + # config "events: false" can't silently suppress CUDA-event collection. + nvtx = os.environ.get("GPU_PROFILER_NVTX", "1") != "0" + events = os.environ.get("GPU_PROFILER_EVENTS", "1") != "0" + else: + nvtx = nvtx and os.environ.get("GPU_PROFILER_NVTX", "1") != "0" + events = events and os.environ.get("GPU_PROFILER_EVENTS", "1") != "0" + + if not enabled: + profiler._set_inner(_NullProfiler()) + return + + p = GPUProfiler() + p.configure(enabled=enabled, nvtx=nvtx, events=events, memory=memory, trace_path=trace_path) + profiler._set_inner(p) + + +def configure_from_dict(cfg: Dict[str, Any]) -> None: + """Convenience: read profiling settings from a config sub-dict. 
+ + Expected keys (all optional):: + + { + "profiling": { + "enabled": false, + "nvtx": true, + "events": true, + "memory": false, + "trace_path": "profiler_logs/trace.json" + } + } + """ + prof_cfg = cfg.get("profiling", {}) + configure( + enabled=prof_cfg.get("enabled", False), + nvtx=prof_cfg.get("nvtx", True), + events=prof_cfg.get("events", True), + memory=prof_cfg.get("memory", False), + trace_path=prof_cfg.get("trace_path", None), + ) diff --git a/src/streamdiffusion/tools/install-tensorrt.py b/src/streamdiffusion/tools/install-tensorrt.py index 5862e931..2aa61df6 100644 --- a/src/streamdiffusion/tools/install-tensorrt.py +++ b/src/streamdiffusion/tools/install-tensorrt.py @@ -1,10 +1,10 @@ +import platform from typing import Literal, Optional import fire from packaging.version import Version -from ..pip_utils import is_installed, run_pip, version, get_cuda_major -import platform +from ..pip_utils import get_cuda_major, is_installed, run_pip, version def install(cu: Optional[Literal["11", "12"]] = get_cuda_major()): @@ -13,50 +13,42 @@ def install(cu: Optional[Literal["11", "12"]] = get_cuda_major()): print("Installing TensorRT requirements...") - min_trt_version = Version("10.12.0") if cu == "12" else Version("9.0.0") + min_trt_version = Version("10.16.0") if cu == "12" else Version("9.0.0") trt_version = version("tensorrt") if trt_version and trt_version < min_trt_version: run_pip("uninstall -y tensorrt") cudnn_package, trt_package = ( - ("nvidia-cudnn-cu12==9.7.1.26", "tensorrt==10.12.0.36") - if cu == "12" else - ("nvidia-cudnn-cu11==8.9.7.29", "tensorrt==9.0.1.post11.dev4") + ("nvidia-cudnn-cu12==9.7.1.26", "tensorrt==10.16.1.11") + if cu == "12" + else ("nvidia-cudnn-cu11==8.9.7.29", "tensorrt==9.0.1.post11.dev4") ) if not is_installed(trt_package): run_pip(f"install {cudnn_package} --no-cache-dir") run_pip(f"install --extra-index-url https://pypi.nvidia.com {trt_package} --no-cache-dir") if not is_installed("polygraphy"): - run_pip( - "install polygraphy==0.49.26 --extra-index-url https://pypi.ngc.nvidia.com" - ) + run_pip("install polygraphy==0.49.26 --extra-index-url https://pypi.ngc.nvidia.com") if not is_installed("onnx_graphsurgeon"): - run_pip( - "install onnx-graphsurgeon==0.5.8 --extra-index-url https://pypi.ngc.nvidia.com" - ) - if platform.system() == 'Windows' and not is_installed("pywin32"): - run_pip( - "install pywin32==311" - ) - if platform.system() == 'Windows' and not is_installed("triton"): - run_pip( - "install triton-windows==3.4.0.post21" - ) - - # Pin onnx 1.18 + onnxruntime-gpu 1.24 together: - # - onnx 1.18 exports IR 11; modelopt needs FLOAT4E2M1 added in 1.18 - # - onnx 1.19+ exports IR 12 (ORT 1.24 max) and removes float32_to_bfloat16 (onnx-gs needs it) - # - onnxruntime-gpu 1.24 supports IR 11; never co-install CPU onnxruntime (shared files conflict) - run_pip("install onnx==1.18.0 onnxruntime-gpu==1.24.4 --no-cache-dir") + run_pip("install onnx-graphsurgeon==0.6.1 --extra-index-url https://pypi.ngc.nvidia.com") + if platform.system() == "Windows" and not is_installed("pywin32"): + run_pip("install pywin32==311") + if platform.system() == "Windows" and not is_installed("triton"): + run_pip("install triton-windows==3.4.0.post21") + + # ONNX stack aligned with FLUX for TRT 10.16: + # - onnx 1.19.1 (IR 11); modelopt's FLOAT4E2M1 support landed in 1.18 and stays in 1.19 + # - onnx-gs 0.6.1 no longer needs float32_to_bfloat16 (previously forced onnx==1.18) + # - onnxruntime-gpu 1.24.4 supports IR 11; never co-install CPU onnxruntime (shared files 
conflict) + # - onnxoptimizer/onnxslim/onnxscript pair with the onnxoptimizer.optimize_from_path pipeline + run_pip( + "install onnx==1.19.1 onnxruntime-gpu==1.24.4 onnxoptimizer==0.4.2 onnxslim==0.1.91 onnxscript==0.6.2 --no-cache-dir" + ) # FP8 quantization dependencies (CUDA 12 only) # nvidia-modelopt requires cupy; pin cupy 13.x + numpy<2 for mediapipe compat if cu == "12": - run_pip( - 'install "nvidia-modelopt[onnx]" "cupy-cuda12x==13.6.0" "numpy==1.26.4"' - " --no-cache-dir" - ) + run_pip("install nvidia-modelopt[onnx] cupy-cuda12x==13.6.0 numpy==1.26.4 --no-cache-dir") if __name__ == "__main__": diff --git a/src/streamdiffusion/utils/__init__.py b/src/streamdiffusion/utils/__init__.py index 00ff7cf7..b40413d2 100644 --- a/src/streamdiffusion/utils/__init__.py +++ b/src/streamdiffusion/utils/__init__.py @@ -1,5 +1,6 @@ from .reporting import report_error + __all__ = [ "report_error", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/utils/reporting.py b/src/streamdiffusion/utils/reporting.py index 44838d9c..25e650c6 100644 --- a/src/streamdiffusion/utils/reporting.py +++ b/src/streamdiffusion/utils/reporting.py @@ -25,5 +25,3 @@ def report_error( stacklevel=stacklevel, extra={"report_error": True}, ) - - diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 997208b3..71fca816 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -11,6 +11,8 @@ from .image_utils import postprocess_image from .model_detection import detect_model from .pipeline import StreamDiffusion +from .tools.gpu_profiler import configure as _configure_profiler +from .tools.gpu_profiler import profiler logger = logging.getLogger(__name__) @@ -127,6 +129,8 @@ def __init__( max_cache_maxframes: int = 4, fp8: bool = False, static_shapes: bool = False, + fp8_allow_fp16_fallback: bool = False, + builder_optimization_level: Optional[int] = None, ): """ Initializes the StreamDiffusionWrapper. @@ -245,6 +249,31 @@ def __init__( The maximum number of frames to cache, by default 1. cache_interval : int, optional The interval to cache the frames, by default 1. + builder_optimization_level : Optional[int], optional + TensorRT IBuilderConfig.builder_optimization_level (range 0-5, + TRT default 3). When set, overrides the per-GPU auto-detect default + in ``acceleration/tensorrt/utilities.py::detect_gpu_profile()``. + + TouchDesigner TrtProfile mapping aligned with NVIDIA reference + pipelines (demoDiffusion: level 3 for FP16; TensorRT-Model-Optimizer: + level 4 for FP8/INT8 quantized):: + + 0 = Flexible static_shapes=False + level 3 — FP16 dynamic; + matches NVIDIA demoDiffusion default. + 2 = Fast Build static_shapes=True + level 2 — heuristic-sorted + fastest tactics; ~30-40% faster build with minimal + runtime loss (build-time tradeoff). + 4 = Quality static_shapes=True + level 3 — FP16 static; + matches NVIDIA demoDiffusion default (level 4 has + no NVIDIA-validated benefit for unquantized FP16). + Performance static_shapes=True + level 4 + fp8=True — + matches NVIDIA TensorRT-Model-Optimizer default + for quantized diffusion (RTX 40+ only). + + Levels 1 and 5 are valid TRT values but not exposed via TrtProfile + UI (1 = degraded; 5 = used by no NVIDIA reference pipeline). Set to + None to auto-detect per GPU (Ada/Ampere/Blackwell → 4, pre-Ampere + → 3). Default None. 
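# [Editor's note — illustrative sketch, not part of the patch. The TrtProfile
# mapping above reduces to wrapper kwargs like these; the model and t_index
# values are placeholders, and builder_optimization_level=None keeps the
# per-GPU auto-detect.]
#
#     # "Fast Build": static shapes + heuristic-sorted tactics (TRT level 2)
#     wrapper = StreamDiffusionWrapper(
#         model_id_or_path="stabilityai/sdxl-turbo",
#         t_index_list=[16, 35],
#         mode="img2img",
#         acceleration="tensorrt",
#         static_shapes=True,
#         builder_optimization_level=2,
#     )
#     # "Performance" (quantized, RTX 40+): static_shapes=True, fp8=True,
#     # builder_optimization_level=4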
""" if compile_engines_only: logger.info("compile_engines_only is True, will only compile engines and not load the model") @@ -274,6 +303,7 @@ def __init__( if not use_denoising_batch: raise NotImplementedError("img2img mode must use denoising batch for now.") + _configure_profiler() # activates via GPU_PROFILER=1 env var; no-op otherwise self.device = device self.dtype = dtype self.width = width @@ -281,6 +311,9 @@ def __init__( self.mode = mode self.output_type = output_type self.frame_buffer_size = frame_buffer_size + self._output_pin_buf: Optional[torch.Tensor] = None # pinned CPU buffer for async D2H output + self._output_gpu_buf: Optional[torch.Tensor] = None # persistent GPU fp32 staging (avoids per-frame alloc) + self._d2h_event: Optional[torch.cuda.Event] = None # event for fine-grained D2H sync self.batch_size = len(t_index_list) * frame_buffer_size if use_denoising_batch else frame_buffer_size self.min_batch_size = min_batch_size self.max_batch_size = max_batch_size @@ -293,6 +326,8 @@ def __init__( self.safety_checker_threshold = safety_checker_threshold self.fp8 = fp8 self.static_shapes = static_shapes + self.fp8_allow_fp16_fallback = fp8_allow_fp16_fallback + self.builder_optimization_level = builder_optimization_level self.stream: StreamDiffusion = self._load_model( model_id_or_path=model_id_or_path, @@ -879,9 +914,25 @@ def postprocess_image( # Denormalize on GPU, return tensor return self._denormalize_on_gpu(image_tensor) elif output_type == "np": - # Denormalize on GPU, then single efficient CPU transfer + # GPU uint8 conversion + single async DMA to pinned host buffer. + # uint8 is 4× smaller than fp32, so PCIe transfer time is 4× shorter. + # Eliminates the intermediate fp32 GPU staging buffer and the PIL round-trip + # that was needed when callers immediately called np.array(pil_image). 
denormalized = self._denormalize_on_gpu(image_tensor) - return denormalized.cpu().permute(0, 2, 3, 1).float().numpy() + uint8_nhwc = (denormalized * 255).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).contiguous() + if ( + self._output_pin_buf is None + or self._output_pin_buf.shape != uint8_nhwc.shape + or self._output_pin_buf.dtype != torch.uint8 + ): + self._output_pin_buf = torch.empty(uint8_nhwc.shape, dtype=torch.uint8, pin_memory=True) + self._d2h_event = torch.cuda.Event() + self._output_pin_buf.copy_(uint8_nhwc, non_blocking=True) + with profiler.region("d2h_sync"): + self._d2h_event.record() + self._d2h_event.synchronize() + out = self._output_pin_buf.numpy() + return out if self.frame_buffer_size > 1 else out[0] # PIL output path (optimized) if output_type == "pil": @@ -1312,7 +1363,7 @@ def _load_model( if use_cached_attn: from streamdiffusion.acceleration.tensorrt.models.utils import create_kvo_cache - kvo_cache, _ = create_kvo_cache( + kvo_cache, _, kvo_buckets, kvo_outputs_by_bucket = create_kvo_cache( pipe.unet, batch_size=stream.trt_unet_batch_size, cache_maxframes=max_cache_maxframes, # Allocate at max to avoid runtime resize race @@ -1322,6 +1373,8 @@ def _load_model( dtype=self.dtype, ) stream.kvo_cache = kvo_cache + stream._kvo_buckets = kvo_buckets + stream._kvo_outputs_by_bucket = kvo_outputs_by_bucket # Load and properly merge LoRA weights using the standard diffusers approach lora_adapters_to_merge = [] @@ -1362,7 +1415,7 @@ def _load_model( # Clean up any partial state try: stream.pipe.unload_lora_weights() - except: + except Exception: pass if use_tiny_vae: @@ -1513,6 +1566,7 @@ def _load_model( use_controlnet=use_controlnet_trt, fp8=fp8, resolution=(self.height, self.width), + builder_optimization_level=self.builder_optimization_level, ) vae_encoder_path = engine_manager.get_engine_path( EngineType.VAE_ENCODER, @@ -1526,6 +1580,7 @@ def _load_model( ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None, resolution=(self.height, self.width), + builder_optimization_level=self.builder_optimization_level, ) vae_decoder_path = engine_manager.get_engine_path( EngineType.VAE_DECODER, @@ -1539,6 +1594,7 @@ def _load_model( ipadapter_tokens=ipadapter_tokens, is_faceid=is_faceid if use_ipadapter_trt else None, resolution=(self.height, self.width), + builder_optimization_level=self.builder_optimization_level, ) # Check if all required engines exist @@ -1586,12 +1642,7 @@ def _load_model( logger.info( f"compile_and_load_engine: Compiling UNet engine for image size: {self.width}x{self.height}" ) - try: - logger.debug( - f"compile_and_load_engine: use_ipadapter_trt={use_ipadapter_trt}, num_ip_layers={num_ip_layers}, tokens={num_tokens}" - ) - except Exception: - pass + logger.debug(f"compile_and_load_engine: use_ipadapter_trt={use_ipadapter_trt}, tokens={num_tokens}") # Note: LoRA weights have already been merged permanently during model loading @@ -1616,7 +1667,7 @@ def _load_model( # Snapshot processors before install — IPAdapter.set_ip_adapter() replaces them # before load_state_dict(), so a failure leaves the UNet in corrupted state - _saved_unet_processors = {name: proc for name, proc in stream.unet.attn_processors.items()} + _saved_unet_processors = dict(stream.unet.attn_processors) # Use first config if list provided cfg = ipadapter_config[0] if isinstance(ipadapter_config, list) else ipadapter_config @@ -1793,7 +1844,16 @@ def _load_model( "opt_image_width": self.width, "build_dynamic_shape": not self.static_shapes, 
"build_static_batch": self.static_shapes, - **({"min_image_resolution": 384, "max_image_resolution": 1024, "build_all_tactics": True} if not self.static_shapes else {}), + **( + {"min_image_resolution": 384, "max_image_resolution": 1024, "build_all_tactics": True} + if not self.static_shapes + else {} + ), + **( + {"builder_optimization_level": self.builder_optimization_level} + if self.builder_optimization_level is not None + else {} + ), }, ) @@ -1818,35 +1878,59 @@ def _load_model( "opt_image_width": self.width, "build_dynamic_shape": not self.static_shapes, "build_static_batch": self.static_shapes, - **({"min_image_resolution": 384, "max_image_resolution": 1024, "build_all_tactics": True} if not self.static_shapes else {}), + **( + {"min_image_resolution": 384, "max_image_resolution": 1024, "build_all_tactics": True} + if not self.static_shapes + else {} + ), + **( + {"builder_optimization_level": self.builder_optimization_level} + if self.builder_optimization_level is not None + else {} + ), }, ) + # Use polygraphy's default Blocking stream. A NonBlocking engine stream + # would skip the legacy/per-thread NULL-stream auto-sync that the rest of + # the pipeline relies on (PyTorch ops run on stream 0x0), creating a data + # race where the engine reads stale inputs and writes outputs that + # downstream PyTorch never observes — symptom is black/zero output frames. cuda_stream = cuda.Stream() vae_config = stream.vae.config vae_dtype = stream.vae.dtype try: - logger.info("Loading TensorRT UNet engine...") - # Build engine_build_options, adding FP8 calibration callback when enabled. + logger.warning( + f"[TRT] UNet engine: fp8={fp8}, static_shapes={self.static_shapes}, engine_path={unet_path}" + ) _unet_build_opts = { "opt_image_height": self.height, "opt_image_width": self.width, "build_dynamic_shape": False, "build_static_batch": True, } + if self.builder_optimization_level is not None: + _unet_build_opts["builder_optimization_level"] = self.builder_optimization_level if fp8: - from streamdiffusion.acceleration.tensorrt.fp8_quantize import ( - generate_unet_calibration_data, - ) - _captured_model = unet_model - _calib_batch = stream.trt_unet_batch_size - _calib_h, _calib_w = self.height, self.width + _is_turbo = getattr(self, "_is_turbo", False) _unet_build_opts["fp8"] = True - _unet_build_opts["onnx_opset"] = 19 # modelopt FP8 needs opset ≥19 for fp16 Q/DQ scales - _unet_build_opts["calibration_data_fn"] = lambda: generate_unet_calibration_data( - _captured_model, _calib_batch, _calib_h, _calib_w + _unet_build_opts["onnx_opset"] = 19 # FP8 Q/DQ scales require opset ≥19 + _unet_build_opts["pipe_ref"] = stream.pipe + # SDXL-Turbo: 4 steps, guidance_scale=0.0 (matches inference); + # SDXL base: 20 steps, guidance_scale=7.5. + # Calibration activations must match inference-time ranges. 
+ _unet_build_opts["calibration_steps"] = 4 if _is_turbo else 20 + _unet_build_opts["fp8_guidance_scale"] = 0.0 if _is_turbo else 7.5 + _unet_build_opts["fp8_allow_fp16_fallback"] = self.fp8_allow_fp16_fallback + _unet_build_opts["fp8_use_cached_attn"] = use_cached_attn + _unet_build_opts["fp8_use_controlnet"] = use_controlnet_trt + _unet_build_opts["fp8_num_ip_layers"] = num_ip_layers if use_ipadapter_trt else 0 + logger.warning( + f"[TRT] FP8 build opts: turbo={_is_turbo}, " + f"steps={_unet_build_opts['calibration_steps']}, " + f"guidance={_unet_build_opts['fp8_guidance_scale']}" ) # Compile and load UNet engine using EngineManager @@ -1885,7 +1969,7 @@ def _load_model( if hasattr(stream, "unet"): try: del stream.unet - except: + except Exception: pass self.cleanup_gpu_memory() @@ -1943,7 +2027,7 @@ def _load_model( if hasattr(stream, "vae"): try: del stream.vae - except: + except Exception: pass self.cleanup_gpu_memory() @@ -2073,6 +2157,7 @@ def _load_model( opt_image_width=self.width, load_engine=load_engine, conditioning_channels=cfg.get("conditioning_channels", 3), + builder_optimization_level=self.builder_optimization_level, ) try: setattr(engine, "model_id", cfg["model_id"]) @@ -2080,11 +2165,15 @@ def _load_model( pass compiled_cn_engines.append(engine) except Exception as e: - logger.warning(f"Failed to compile/load ControlNet engine for {cfg.get('model_id')}: {e}") + logger.warning( + f"Failed to compile/load ControlNet engine for {cfg.get('model_id')}: {e}" + ) if compiled_cn_engines: setattr(stream, "controlnet_engines", compiled_cn_engines) try: - logger.info(f"Compiled/loaded {len(compiled_cn_engines)} ControlNet TensorRT engine(s)") + logger.info( + f"Compiled/loaded {len(compiled_cn_engines)} ControlNet TensorRT engine(s)" + ) except Exception: pass except Exception: @@ -2131,7 +2220,7 @@ def _load_model( insightface_model_name=cfg.get("insightface_model_name"), ) ip_module = IPAdapterModule(ip_cfg) - _saved_unet_processors_post = {name: proc for name, proc in stream.unet.attn_processors.items()} + _saved_unet_processors_post = dict(stream.unet.attn_processors) ip_module.install(stream) # Expose for later updates stream._ipadapter_module = ip_module @@ -2407,7 +2496,7 @@ def cleanup_gpu_memory(self) -> None: try: self.stream._param_updater.clear_caches() logger.info(" Cleared prompt caches") - except: + except Exception: pass # Enhanced TensorRT engine cleanup @@ -2423,20 +2512,20 @@ def cleanup_gpu_memory(self) -> None: try: # Call the engine's destructor explicitly unet_engine.engine.__del__() - except: + except Exception: pass # Clear all engine-related attributes if hasattr(unet_engine, "context"): try: del unet_engine.context - except: + except Exception: pass if hasattr(unet_engine, "engine"): try: del unet_engine.engine.engine # TensorRT runtime engine del unet_engine.engine - except: + except Exception: pass del self.stream.unet @@ -2454,11 +2543,11 @@ def cleanup_gpu_memory(self) -> None: if hasattr(engine, "engine") and hasattr(engine.engine, "__del__"): try: engine.engine.__del__() - except: + except Exception: pass try: delattr(vae_engine, engine_name) - except: + except Exception: pass del self.stream.vae @@ -2471,7 +2560,7 @@ def cleanup_gpu_memory(self) -> None: self.stream.controlnet_engine_pool.cleanup() del self.stream.controlnet_engine_pool logger.info(" ControlNet engine pool cleanup completed") - except: + except Exception: pass except Exception as e: @@ -2482,10 +2571,16 @@ def cleanup_gpu_memory(self) -> None: try: del self.stream logger.info(" 
Cleared stream object") - except: + except Exception: pass self.stream = None + # Release wrapper-level frame buffers so the next model swap allocates fresh + # for the new output shape and pinned host memory is returned to the OS. + self._output_pin_buf = None + self._output_gpu_buf = None + self._d2h_event = None + # Force multiple garbage collection cycles for thorough cleanup for i in range(3): gc.collect() diff --git a/tools/summarize_audit.py b/tools/summarize_audit.py new file mode 100644 index 00000000..f865c144 --- /dev/null +++ b/tools/summarize_audit.py @@ -0,0 +1,1706 @@ +#!/usr/bin/env python3 +""" +Dependency Audit Summary Generator + +Parses pip-audit JSON output and generates comprehensive executive summary reports. + +Features: +- Security vulnerability analysis with severity breakdown +- ML stack health check (PyTorch, CUDA, GPU) +- Outdated packages with risk prioritization +- Dependency tree analysis with orphan detection +- Full markdown report saved to audit_reports/ + +Usage: + .venv/Scripts/pip-audit --format json | python tools/summarize_audit.py + # or: + python tools/summarize_audit.py audit.json + # or with custom output: + python tools/summarize_audit.py --no-save # stdout only + python tools/summarize_audit.py -o custom.md # custom output path +""" + +import argparse +import json +import os +import re +import subprocess +import sys +from collections import defaultdict +from datetime import datetime +from pathlib import Path + +import tomllib + + +# ============================================================================ +# PACKAGE CLASSIFICATION CONSTANTS +# ============================================================================ + +# Tier 1: Universal ML packages (present in most ML projects) +UNIVERSAL_ML_CORE = { + # PyTorch ecosystem + "torch", + "torchvision", + "torchaudio", + # HuggingFace ecosystem + "transformers", + "accelerate", + "safetensors", + "peft", + "huggingface-hub", + "tokenizers", + "datasets", + # Core ML utilities + "numpy", + "scipy", + "scikit-learn", + "pandas", + "einops", + "pillow", +} + +# PyTorch ecosystem packages (implicitly version-locked) +PYTORCH_ECOSYSTEM = { + "torch", + "torchvision", + "torchaudio", + "xformers", + "triton", + "triton-windows", + "torch-tensorrt", + "torch_tensorrt", + "torchao", + "nvidia-cudnn-cu12", + "nvidia-cudnn-cu11", + "nvidia-cublas-cu12", + "nvidia-cublas-cu11", +} + +# TensorRT/CUDA stack +TENSORRT_ECOSYSTEM = { + "tensorrt", + "tensorrt-cu12", + "tensorrt_cu12", + "tensorrt_cu12_bindings", + "tensorrt_cu12_libs", + "onnx", + "onnx-graphsurgeon", + "onnx_graphsurgeon", + "onnxruntime", + "polygraphy", +} + +# CUDA/NVIDIA packages +CUDA_STACK = { + "cuda-python", + "cuda-toolkit", + "cuda-pathfinder", + "nvidia-cuda-runtime", + "nvidia-cuda-runtime-cu12", + "nvidia-cuda-runtime-cu11", + "nvidia-cudnn-cu12", + "nvidia-cudnn-cu11", + "nvidia-cublas-cu12", + "nvidia-cublas-cu11", + "nvidia-ml-py", + "nvidia-pyindex", +} + +# Universal dev tools +DEV_TOOLS = { + # Testing + "pytest", + "pytest-cov", + "pytest-asyncio", + "pytest-mock", + "coverage", + # Linting/Formatting + "black", + "isort", + "ruff", + "mypy", + "pyrefly", + "flake8", + # Build/Audit + "pip", + "pip-audit", + "pip_audit", + "pipdeptree", + "pip-licenses", + "wheel", + "setuptools", + "uv", + "ninja", + # Development + "ipython", + "pyreadline3", + "rich", + "colorama", + "tqdm", +} + +# Tier 2: Domain-specific packages (auto-detected based on what's installed) +EMBEDDING_SEARCH = { + "sentence-transformers", + 
"faiss-cpu", + "faiss-gpu", + "FlagEmbedding", + "rank-bm25", + "ir_datasets", +} + +CODE_PARSING = { + "tree-sitter", + "tree-sitter-python", + "tree-sitter-javascript", + "tree-sitter-typescript", + "tree-sitter-rust", + "tree-sitter-go", + "tree-sitter-java", + "tree-sitter-c", + "tree-sitter-cpp", + "tree-sitter-c-sharp", + "tree-sitter-glsl", +} + +IMAGE_GENERATION = { + "diffusers", + "controlnet-aux", + "compel", + "xformers", + "sageattention", + "tomesd", + "opencv-python", + "opencv-contrib-python", + "mss", +} + +WEB_API = { + "fastapi", + "starlette", + "uvicorn", + "aiohttp", + "httpx", + "mcp", + "sse-starlette", + "python-multipart", +} + +NLP_PACKAGES = { + "nltk", + "tiktoken", + "sentencepiece", + "regex", +} + +# Category display configuration +CATEGORY_LABELS = { + "project_package": ("[PROJECT]", "Project package"), + "ml_core": ("[ML CORE]", "Core ML packages (required)"), + "pytorch_ecosystem": ("[PYTORCH]", "PyTorch ecosystem (required)"), + "cuda_stack": ("[CUDA]", "CUDA/TensorRT stack (required)"), + "dev_tools": ("[DEV]", "Development tools"), + "embedding_search": ("[EMBEDDING]", "Embedding/Search packages"), + "code_parsing": ("[PARSING]", "Code parsing (tree-sitter)"), + "image_generation": ("[IMAGE]", "Image generation packages"), + "web_api": ("[WEB]", "Web/API packages"), + "nlp": ("[NLP]", "NLP packages"), + "true_orphans": ("[?]", "Unknown packages (investigate)"), +} + + +def parse_audit_json(data: dict) -> dict: + """Extract vulnerability information from pip-audit JSON.""" + # Group vulnerabilities by package + vuln_packages = defaultdict(list) + severity_counts = defaultdict(int) + total_packages = 0 + + for dep in data.get("dependencies", []): + pkg_name = dep["name"] + + # Skip dependencies that couldn't be audited + if "skip_reason" in dep: + continue + + total_packages += 1 + pkg_version = dep["version"] + vulns = dep.get("vulns", []) + + if vulns: + for vuln in vulns: + vuln_packages[pkg_name].append( + { + "version": pkg_version, + "cve_id": vuln.get("id", "UNKNOWN"), + "fix_versions": vuln.get("fix_versions", []), + "aliases": vuln.get("aliases", []), + "description": ( + vuln.get("description", "")[:200] + "..." + if len(vuln.get("description", "")) > 200 + else vuln.get("description", "") + ), + } + ) + + # Try to extract severity from CVE ID or description + if any(alias.startswith("CVE-") for alias in vuln.get("aliases", [])): + severity_counts["high"] += 1 + else: + severity_counts["medium"] += 1 + + return { + "total_packages": total_packages, + "vulnerable_packages": len(vuln_packages), + "total_cves": sum(len(v) for v in vuln_packages.values()), + "vulnerabilities": dict(vuln_packages), + "severity_counts": dict(severity_counts), + } + + +def get_python_executable(cli_override: Path | None = None) -> str: + """Get Python executable with priority: CLI > env var > auto-detect > fallback. + + Priority order: + 1. CLI argument (--python) + 2. DEPS_AUDIT_PYTHON environment variable + 3. .venv/Scripts/python.exe or .venv/bin/python (auto-detect) + 4. venv/Scripts/python.exe or venv/bin/python + 5. 
sys.executable (last resort) + + Args: + cli_override: Optional Python path from CLI argument + + Returns: + Path to Python executable as string + """ + # Priority 1: CLI override + if cli_override and Path(cli_override).exists(): + return str(cli_override) + + # Priority 2: Environment variable + env_python = os.environ.get("DEPS_AUDIT_PYTHON") + if env_python and Path(env_python).exists(): + return env_python + + # Priority 3-5: Auto-detect (venv first, then sys.executable) + candidates = [ + Path.cwd() / ".venv" / "Scripts" / "python.exe", # Windows + Path.cwd() / ".venv" / "bin" / "python", # Linux/Mac + Path.cwd() / "venv" / "Scripts" / "python.exe", # Alt Windows + Path.cwd() / "venv" / "bin" / "python", # Alt Linux/Mac + ] + + for candidate in candidates: + if candidate.exists(): + return str(candidate) + + # Fallback to current Python + return sys.executable + + +def find_python_with_pipdeptree(python_override: Path | None = None) -> str | None: + """Find a Python executable that has pipdeptree installed. + + Args: + python_override: Optional Python path from CLI argument + + Returns: + Path to Python executable with pipdeptree, or None if not found + """ + python_exe = get_python_executable(python_override) + + # Verify pipdeptree is available + try: + result = subprocess.run( + [python_exe, "-m", "pipdeptree", "--version"], + capture_output=True, + timeout=5, + ) + if result.returncode == 0: + return python_exe + except (subprocess.SubprocessError, FileNotFoundError): + pass + + return None + + +def get_dependency_tree_json(python_override: Path | None = None) -> list | None: + """Run pipdeptree --json and return parsed output. + + Args: + python_override: Optional Python path from CLI argument + + Returns: + Parsed pipdeptree JSON output, or None on failure + """ + python_exe = find_python_with_pipdeptree(python_override) + if python_exe is None: + return None + + try: + result = subprocess.run( + [python_exe, "-m", "pipdeptree", "--json"], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + return json.loads(result.stdout) + except (subprocess.SubprocessError, json.JSONDecodeError, FileNotFoundError): + return None + return None + + +def get_outdated_packages(python_override: Path | None = None) -> list[dict] | None: + """Run pip list --outdated --format=json and return parsed output. + + Args: + python_override: Optional Python path from CLI argument + + Returns: + List of outdated packages, empty list if none, or None on failure + """ + python_exe = get_python_executable(python_override) + + try: + result = subprocess.run( + [python_exe, "-m", "pip", "list", "--outdated", "--format=json"], + capture_output=True, + text=True, + timeout=120, # Increased timeout for slow networks + ) + if result.returncode == 0 and result.stdout.strip(): + return json.loads(result.stdout) + # Return empty list instead of None if command succeeded but no outdated packages + if result.returncode == 0: + return [] + except subprocess.TimeoutExpired: + # Timeout is common on slow networks, return None to signal failure + return None + except (subprocess.SubprocessError, json.JSONDecodeError, FileNotFoundError): + return None + return None + + +def get_ml_stack_health(python_override: Path | None = None) -> dict: + """Get ML stack health information (PyTorch, CUDA, GPU). 
+ + Args: + python_override: Optional Python path from CLI argument + + Returns: + Dictionary with ML stack information + """ + python_exe = get_python_executable(python_override) + + ml_info = { + "pytorch_version": None, + "cuda_available": False, + "cuda_version": None, + "gpu_name": None, + "gpu_count": 0, + "transformers_version": None, + "faiss_version": None, + "sentence_transformers_version": None, + } + + # Get PyTorch/CUDA info + try: + result = subprocess.run( + [ + python_exe, + "-c", + """ +import json +info = {} +try: + import torch + info['pytorch_version'] = torch.__version__ + info['cuda_available'] = torch.cuda.is_available() + if torch.cuda.is_available(): + info['cuda_version'] = torch.version.cuda + info['gpu_count'] = torch.cuda.device_count() + info['gpu_name'] = torch.cuda.get_device_name(0) +except ImportError: + pass +print(json.dumps(info)) +""", + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + pytorch_info = json.loads(result.stdout.strip()) + ml_info.update(pytorch_info) + except (subprocess.SubprocessError, json.JSONDecodeError): + pass + + # Get other ML package versions + try: + result = subprocess.run( + [ + python_exe, + "-c", + """ +import json +info = {} +try: + import transformers + info['transformers_version'] = transformers.__version__ +except ImportError: + pass +try: + import faiss + info['faiss_version'] = faiss.__version__ if hasattr(faiss, '__version__') else 'installed' +except ImportError: + pass +try: + import sentence_transformers + info['sentence_transformers_version'] = sentence_transformers.__version__ +except ImportError: + pass +print(json.dumps(info)) +""", + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode == 0: + other_info = json.loads(result.stdout.strip()) + ml_info.update(other_info) + except (subprocess.SubprocessError, json.JSONDecodeError): + pass + + return ml_info + + +def _normalize_package_name(name: str) -> str: + """PEP 503 package name normalization. + + Normalizes package names by converting to lowercase and replacing + runs of hyphens, underscores, and dots with a single hyphen. + This ensures 'pre-commit' and 'pre_commit' are treated as the same package. + """ + return re.sub(r"[-_.]+", "-", name).lower() + + +def get_direct_dependencies(project_root: Path | None = None) -> set[str]: + """Extract direct dependency names from pyproject.toml.""" + if project_root is None: + project_root = Path.cwd() + + pyproject_path = project_root / "pyproject.toml" + if not pyproject_path.exists(): + return set() + + with open(pyproject_path, "rb") as f: + data = tomllib.load(f) + + deps = set() + # Main dependencies + for dep in data.get("project", {}).get("dependencies", []): + # Extract package name (before any version specifier) + name = re.split(r"[<>=!~\[]", dep)[0].strip() + deps.add(_normalize_package_name(name)) + + # Optional dependencies (dev, test, etc.) + for group_deps in data.get("project", {}).get("optional-dependencies", {}).values(): + for dep in group_deps: + name = re.split(r"[<>=!~\[]", dep)[0].strip() + deps.add(_normalize_package_name(name)) + + return deps + + +def get_dependency_constraints( + tree_data: list | None, +) -> dict[str, list[tuple[str, str]]]: + """Get dependency constraints for each package. + + Returns dict mapping package_name -> list of (dependent, constraint) tuples. 
+ Example: {"fsspec": [("datasets", "<=2025.10.0")]} + """ + if not tree_data: + return {} + + constraints = defaultdict(list) + + for pkg in tree_data: + pkg_name = pkg["package"]["package_name"] + for dep in pkg.get("dependencies", []): + dep_name = dep["package_name"] + required_version = dep.get("required_version", "Any") + + # Only track meaningful constraints (not "Any") + if required_version and required_version != "Any": + constraints[dep_name.lower()].append((pkg_name, required_version)) + + return dict(constraints) + + +def categorize_outdated_packages( + outdated: list[dict], + constraints: dict | None = None, + ecosystem_constraints: dict | None = None, +) -> dict: + """Categorize outdated packages by update risk. + + Args: + outdated: List of outdated packages from pip list --outdated + constraints: Explicit version constraints from dependency tree + ecosystem_constraints: Implicit constraints (PyTorch/CUDA ecosystem) + + Returns: + Dict with categorized packages (ml_core, blocked, ecosystem_blocked, major_jump, safe_update) + """ + ML_CORE = UNIVERSAL_ML_CORE # Use universal ML core packages + + if constraints is None: + constraints = {} + if ecosystem_constraints is None: + ecosystem_constraints = {} + + categories = { + "ml_core": [], # DO NOT auto-update + "major_jump": [], # Review breaking changes + "safe_update": [], # Safe to update + "blocked": [], # Cannot update due to explicit constraints + "ecosystem_blocked": [], # Cannot update due to PyTorch/CUDA ecosystem + } + + for pkg in outdated: + name = pkg["name"].lower() + current = pkg["version"] + latest = pkg["latest_version"] + + # Parse major versions + current_major = current.split(".")[0].lstrip("v") + latest_major = latest.split(".")[0].lstrip("v") + + # Get blocking constraints for this package + blocking = constraints.get(name, []) + + # Check for ecosystem constraints FIRST (implicit locking to PyTorch/CUDA) + ecosystem_reason = ecosystem_constraints.get(name) + + pkg_info = { + "name": pkg["name"], + "current": current, + "latest": latest, + "is_major_jump": current_major != latest_major, + "blocked_by": blocking if blocking else None, + } + + # Check if package has strict version constraints + has_strict_constraint = any( + "==" in constraint or "<=" in constraint or "<" in constraint for _, constraint in blocking + ) + + if ecosystem_reason: + # Package is implicitly locked to PyTorch/CUDA ecosystem + pkg_info["blocked_by"] = [(ecosystem_reason, "ecosystem")] + categories["ecosystem_blocked"].append(pkg_info) + elif has_strict_constraint and blocking: + # Package is blocked by explicit version constraints + categories["blocked"].append(pkg_info) + elif name in ML_CORE: + categories["ml_core"].append(pkg_info) + elif current_major != latest_major: + categories["major_jump"].append(pkg_info) + else: + categories["safe_update"].append(pkg_info) + + return categories + + +def build_package_trees(tree_data: list, direct_deps: set) -> dict: + """Build dependency trees for direct dependencies only.""" + trees = {} + for pkg in tree_data: + pkg_name = pkg["package"]["package_name"].lower() + if pkg_name in direct_deps: + trees[pkg_name] = { + "version": pkg["package"]["installed_version"], + "dependencies": pkg.get("dependencies", []), + } + return trees + + +def detect_project_domains(installed_packages: set[str]) -> dict[str, list]: + """Auto-detect which domain-specific categories apply to this project. + + Returns dict mapping category name to list of matching packages. 
+ """ + domains = {} + + # Check each domain category + domain_sets = { + "embedding_search": EMBEDDING_SEARCH, + "code_parsing": CODE_PARSING, + "image_generation": IMAGE_GENERATION, + "web_api": WEB_API, + "nlp": NLP_PACKAGES, + } + + for domain_name, domain_packages in domain_sets.items(): + matches = installed_packages & domain_packages + if len(matches) >= 2: # At least 2 packages to count as a domain + domains[domain_name] = sorted(matches) + + return domains + + +def get_ecosystem_constraints(ml_info: dict | None = None) -> dict[str, str]: + """Generate synthetic constraints for PyTorch ecosystem packages. + + Args: + ml_info: ML stack information from get_ml_stack_health() + + Returns: + Dict mapping package name -> constraint reason + """ + constraints = {} + if not ml_info: + return constraints + + pytorch_version = ml_info.get("pytorch_version") + cuda_version = ml_info.get("cuda_version") + + if pytorch_version: + for pkg in PYTORCH_ECOSYSTEM: + constraints[pkg.lower()] = f"Implicitly locked to torch=={pytorch_version}" + + if cuda_version: + for pkg in TENSORRT_ECOSYSTEM: + if pkg.lower() not in constraints: + constraints[pkg.lower()] = f"Implicitly locked to CUDA {cuda_version}" + + return constraints + + +def find_orphan_packages(tree_data: list, direct_deps: set, project_name: str | None = None) -> dict: + """Find packages not in direct deps and categorize them. + + Uses tiered classification: + 1. Universal ML packages (always recognized) + 2. Domain-specific packages (auto-detected) + 3. True orphans (unknown packages) + + Args: + tree_data: Output from pipdeptree --json + direct_deps: Set of direct dependencies from pyproject.toml + project_name: Optional project package name (auto-detected from pyproject.toml) + + Returns: + Dict mapping category -> list of packages + """ + all_installed = {} + all_required_by = defaultdict(set) + + for pkg in tree_data: + name = _normalize_package_name(pkg["package"]["package_name"]) + version = pkg["package"]["installed_version"] + all_installed[name] = version + + # Track reverse dependencies + for dep in pkg.get("dependencies", []): + dep_name = _normalize_package_name(dep["package_name"]) + all_required_by[dep_name].add(name) + + # Find packages with no dependents and not in direct deps + potential_orphans = [] + for name, version in all_installed.items(): + if name not in direct_deps and not all_required_by.get(name): + potential_orphans.append({"name": name, "version": version}) + + # Get all installed package names for domain detection + installed = {pkg["package"]["package_name"].lower() for pkg in tree_data} + detected_domains = detect_project_domains(installed) + + # Categorize orphans + categorized = { + "project_package": [], # The project itself + "ml_core": [], # Universal ML packages + "pytorch_ecosystem": [], # PyTorch-locked packages + "cuda_stack": [], # CUDA/TensorRT packages + "dev_tools": [], # Dev/build tools + "embedding_search": [], # Auto-detected: embedding/search + "code_parsing": [], # Auto-detected: tree-sitter + "image_generation": [], # Auto-detected: diffusion/image + "web_api": [], # Auto-detected: web frameworks + "nlp": [], # Auto-detected: NLP tools + "true_orphans": [], # Actually unknown + } + + for pkg in potential_orphans: + name = pkg["name"].lower() + + # Check project package first + if project_name and name == project_name.lower(): + categorized["project_package"].append(pkg) + # Universal categories + elif name in UNIVERSAL_ML_CORE: + categorized["ml_core"].append(pkg) + elif name in 
PYTORCH_ECOSYSTEM: + categorized["pytorch_ecosystem"].append(pkg) + elif name in CUDA_STACK or name in TENSORRT_ECOSYSTEM: + categorized["cuda_stack"].append(pkg) + elif name in DEV_TOOLS: + categorized["dev_tools"].append(pkg) + # Domain-specific (only if domain detected in project) + elif "embedding_search" in detected_domains and name in EMBEDDING_SEARCH: + categorized["embedding_search"].append(pkg) + elif "code_parsing" in detected_domains and name in CODE_PARSING: + categorized["code_parsing"].append(pkg) + elif "image_generation" in detected_domains and name in IMAGE_GENERATION: + categorized["image_generation"].append(pkg) + elif "web_api" in detected_domains and name in WEB_API: + categorized["web_api"].append(pkg) + elif "nlp" in detected_domains and name in NLP_PACKAGES: + categorized["nlp"].append(pkg) + else: + categorized["true_orphans"].append(pkg) + + # Remove empty categories + return {k: v for k, v in categorized.items() if v} + + +def safe_print(text: str) -> None: + """Print text with Windows-safe encoding.""" + try: + print(text) + except UnicodeEncodeError: + # Replace problematic Unicode characters for Windows console + safe_text = text.encode("ascii", "replace").decode("ascii") + print(safe_text) + + +def print_summary(summary: dict) -> None: + """Print formatted security summary.""" + safe_print("=" * 70) + safe_print("DEPENDENCY AUDIT SUMMARY".center(70)) + safe_print("=" * 70) + print(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + print(f"Total Packages: {summary['total_packages']}") + print(f"Vulnerable Packages: {summary['vulnerable_packages']}") + print(f"Total CVEs: {summary['total_cves']}") + print() + + if summary["total_cves"] == 0: + print("[OK] No known vulnerabilities found!") + print() + return + + # Print severity breakdown if available + if summary["severity_counts"]: + print("Severity Breakdown:") + for severity, count in summary["severity_counts"].items(): + print(f" - {severity.capitalize()}: {count}") + print() + + print("=" * 70) + print("VULNERABILITIES FOUND".center(70)) + print("=" * 70) + print() + + for pkg_name, vulns in sorted(summary["vulnerabilities"].items()): + print(f"[PACKAGE] {pkg_name} ({vulns[0]['version']})") + print("-" * 70) + + for vuln in vulns: + print(f" [VULN] {vuln['cve_id']}") + + if vuln["aliases"]: + aliases_str = ", ".join(vuln["aliases"]) + print(f" Aliases: {aliases_str}") + + if vuln["fix_versions"]: + fix_str = ", ".join(vuln["fix_versions"]) + print(f" Fix Available: {fix_str}") + else: + print(" Fix Available: No fix released yet") + + if vuln["description"]: + safe_print(f" Description: {vuln['description']}") + + print() + + print() + + print("=" * 70) + print("RECOMMENDED ACTIONS".center(70)) + print("=" * 70) + print() + + # Generate actionable recommendations + fixable = [(pkg, v) for pkg, vulns in summary["vulnerabilities"].items() for v in vulns if v["fix_versions"]] + + if fixable: + print("[FIXES] Packages with available fixes:") + for pkg, vuln in fixable: + fix_version = vuln["fix_versions"][0] if vuln["fix_versions"] else "latest" + print(f" pip install --upgrade {pkg}=={fix_version}") + print() + + unfixable = [(pkg, v) for pkg, vulns in summary["vulnerabilities"].items() for v in vulns if not v["fix_versions"]] + + if unfixable: + print("[MONITOR] Packages without fixes (monitor for updates):") + for pkg, vuln in unfixable: + print(f" {pkg}: {vuln['cve_id']}") + print() + + print("[NEXT STEPS] Actions to take:") + print(" 1. Review CVE details at https://osv.dev/") + print(" 2. 
Test updates in isolated environment") + print(" 3. Run full test suite before deploying") + print(" 4. Update pyproject.toml with new version constraints") + print() + + +def print_dependency_tree(pkg_name: str, pkg_data: dict, indent: int = 0, visited: set | None = None) -> None: + """Recursively print ASCII dependency tree.""" + if visited is None: + visited = set() + + prefix = " " * indent + if indent == 0: + safe_print(f"{pkg_name}=={pkg_data['version']}") + + deps = pkg_data.get("dependencies", []) + for i, dep in enumerate(deps): + dep_name = dep["package_name"] + dep_key = dep.get("key", dep_name.lower()) + required = dep.get("required_version", "Any") + installed = dep.get("installed_version", "?") + + is_last = i == len(deps) - 1 + branch = "+-- " if is_last else "|-- " # ASCII-safe for Windows console + + safe_print(f"{prefix}{branch}{dep_name} [required: {required}, installed: {installed}]") + + # Prevent infinite loops from circular dependencies + if dep_key not in visited: + visited.add(dep_key) + # Recurse for nested dependencies + nested_deps = dep.get("dependencies", []) + if nested_deps: + child_prefix = " " if is_last else "| " # ASCII-safe + for j, nested in enumerate(nested_deps): + nested_is_last = j == len(nested_deps) - 1 + nested_branch = "+-- " if nested_is_last else "|-- " # ASCII-safe + n_name = nested["package_name"] + n_req = nested.get("required_version", "Any") + n_inst = nested.get("installed_version", "?") + safe_print( + f"{prefix}{child_prefix}{nested_branch}{n_name} [required: {n_req}, installed: {n_inst}]" + ) + + +def print_dependency_analysis(tree_data: list, direct_deps: set) -> None: + """Print complete dependency analysis section.""" + if not tree_data: + safe_print("\n[WARN] pipdeptree not available - skipping dependency tree analysis") + safe_print(" Install with: pip install pipdeptree") + return + + # Calculate stats + all_installed = {pkg["package"]["package_name"].lower() for pkg in tree_data} + transitive = all_installed - direct_deps + + safe_print("\n" + "=" * 70) + safe_print("DEPENDENCY TREE ANALYSIS".center(70)) + safe_print("=" * 70) + safe_print(f"Direct Dependencies: {len(direct_deps)} (from pyproject.toml)") + safe_print(f"Transitive Dependencies: {len(transitive)} (pulled in automatically)") + safe_print(f"Total Installed: {len(all_installed)}") + safe_print("") + + # Build and print trees for direct deps + trees = build_package_trees(tree_data, direct_deps) + + for pkg_name in sorted(trees.keys()): + pkg_data = trees[pkg_name] + if pkg_data["dependencies"]: # Only show packages with dependencies + safe_print(f"[TREE] {pkg_name} ({pkg_data['version']})") + safe_print("-" * 70) + print_dependency_tree(pkg_name, pkg_data) + safe_print("") + + # Find and categorize orphans + project_info = get_project_info() + project_name = project_info.get("name") if project_info["name"] != "Unknown" else None + orphan_data = find_orphan_packages(tree_data, direct_deps, project_name) + total_orphans = sum(len(v) for v in orphan_data.values()) + + if total_orphans: + safe_print(f"[PACKAGES] {total_orphans} packages not tracked in pyproject.toml:") + safe_print("-" * 70) + + for category, packages in orphan_data.items(): + if packages: + tag, description = CATEGORY_LABELS.get(category, ("[?]", category)) + safe_print(f" {tag} {description}:") + for pkg in sorted(packages, key=lambda x: x["name"]): + safe_print(f" - {pkg['name']} ({pkg['version']})") + safe_print("") + + if orphan_data.get("true_orphans"): + safe_print(" Actions for unknown 
packages:") + safe_print(" - If needed: Add to pyproject.toml dependencies") + safe_print(" - If not needed: pip uninstall ") + safe_print("") + else: + safe_print("[OK] No packages outside direct dependencies detected") + safe_print("") + + +def print_outdated_analysis( + outdated: list[dict] | None, + tree_data: list | None = None, + ml_info: dict | None = None, +) -> None: + """Print outdated packages analysis to console.""" + if outdated is None: + safe_print("\n[WARN] Could not retrieve outdated packages") + return + + if not outdated: + safe_print("\n[OK] All packages are up to date!") + return + + # Get dependency constraints + constraints = get_dependency_constraints(tree_data) + ecosystem_constraints = get_ecosystem_constraints(ml_info) + categories = categorize_outdated_packages(outdated, constraints, ecosystem_constraints) + + safe_print("\n" + "=" * 70) + safe_print("OUTDATED PACKAGES ANALYSIS".center(70)) + safe_print("=" * 70) + safe_print(f"Total Outdated: {len(outdated)}") + safe_print("") + + # Blocked packages (cannot update due to explicit version constraints) + if categories["blocked"]: + safe_print("[BLOCKED] Cannot update due to version constraints:") + safe_print("-" * 70) + for pkg in categories["blocked"]: + safe_print(f" {pkg['name']}: {pkg['current']} -> {pkg['latest']}") + if pkg["blocked_by"]: + for dependent, constraint in pkg["blocked_by"]: + safe_print(f" Blocked by: {dependent} requires {constraint}") + safe_print("") + + # Ecosystem-blocked packages (PyTorch/CUDA version-locked) + if categories.get("ecosystem_blocked"): + safe_print("[ECOSYSTEM] Locked to current PyTorch/CUDA version:") + safe_print("-" * 70) + for pkg in categories["ecosystem_blocked"]: + safe_print(f" {pkg['name']}: {pkg['current']} -> {pkg['latest']}") + if pkg["blocked_by"]: + reason, _ = pkg["blocked_by"][0] + safe_print(f" Reason: {reason}") + safe_print("") + + # ML Core packages (DO NOT auto-update) + if categories["ml_core"]: + safe_print("[ML CORE] DO NOT auto-update - test thoroughly first:") + safe_print("-" * 70) + for pkg in categories["ml_core"]: + jump = " [MAJOR]" if pkg["is_major_jump"] else "" + safe_print(f" {pkg['name']}: {pkg['current']} -> {pkg['latest']}{jump}") + safe_print("") + + # Major version jumps (review breaking changes) + if categories["major_jump"]: + safe_print("[MAJOR VERSION] Review breaking changes before updating:") + safe_print("-" * 70) + for pkg in categories["major_jump"]: + safe_print(f" {pkg['name']}: {pkg['current']} -> {pkg['latest']}") + safe_print("") + + # Safe updates + if categories["safe_update"]: + safe_print("[SAFE] Minor/patch updates (generally safe):") + safe_print("-" * 70) + for pkg in categories["safe_update"]: + safe_print(f" {pkg['name']}: {pkg['current']} -> {pkg['latest']}") + safe_print("") + + +def get_project_info() -> dict: + """Get project name and version from pyproject.toml.""" + pyproject_path = Path.cwd() / "pyproject.toml" + if not pyproject_path.exists(): + return {"name": "Unknown", "version": "Unknown"} + + try: + with open(pyproject_path, "rb") as f: + data = tomllib.load(f) + return { + "name": data.get("project", {}).get("name", "Unknown"), + "version": data.get("project", {}).get("version", "Unknown"), + } + except tomllib.TOMLDecodeError: + return {"name": "Unknown", "version": "Unknown"} + + +def generate_markdown_report( + summary: dict, + tree_data: list | None, + direct_deps: set, + outdated: list[dict] | None = None, + ml_info: dict | None = None, +) -> str: + """Generate comprehensive executive 
summary report as markdown string.""" + lines = [] + now = datetime.now() + project_info = get_project_info() + + # Calculate dependency stats + total_installed = len(tree_data) if tree_data else summary["total_packages"] + transitive_count = total_installed - len(direct_deps) if tree_data else 0 + + # Title + lines.append("# Dependency Audit Executive Summary") + lines.append("") + lines.append(f"**Date**: {now.strftime('%Y-%m-%d %H:%M:%S')}") + lines.append(f"**Project**: {project_info['name']} (v{project_info['version']})") + lines.append( + f"**Total Dependencies**: {total_installed} packages " + f"({len(direct_deps)} direct + {transitive_count} transitive)" + ) + lines.append(f"**Audit Report**: `audit_reports/{now.strftime('%Y-%m-%d-%H%M')}-audit-summary.md`") + lines.append("") + + # --- Security Status --- + lines.append("---") + lines.append("") + if summary["total_cves"] == 0: + lines.append("## ✅ Security Status: **EXCELLENT**") + else: + lines.append("## ⚠️ Security Status: **ACTION REQUIRED**") + lines.append("") + lines.append( + f"- **Known Vulnerabilities**: " + f"{summary['severity_counts'].get('critical', 0)} critical, " + f"{summary['severity_counts'].get('high', 0)} high, " + f"{summary['severity_counts'].get('medium', 0)} medium, " + f"{summary['severity_counts'].get('low', 0)} low" + ) + lines.append(f"- **CVE Count**: {summary['total_cves']}") + lines.append(f"- **Last Scan**: {now.strftime('%Y-%m-%d %H:%M:%S')}") + lines.append("") + + if summary["total_cves"] == 0: + lines.append( + "**Finding**: No security vulnerabilities detected in any dependencies. " + "All packages are clean according to OSV database." + ) + lines.append("") + + # --- ML Stack Health --- + if ml_info: + lines.append("---") + lines.append("") + if ml_info.get("cuda_available"): + lines.append("## 🤖 ML Stack Health: **GOOD**") + elif ml_info.get("pytorch_version"): + lines.append("## 🤖 ML Stack Health: **CPU-ONLY**") + else: + lines.append("## 🤖 ML Stack Health: **NOT INSTALLED**") + lines.append("") + + lines.append("| Component | Version | Status | Notes |") + lines.append("|-----------|---------|--------|-------|") + + # PyTorch + if ml_info.get("pytorch_version"): + pytorch_status = "✅ Stable" + pytorch_notes = "" + # Check if outdated + if outdated: + for pkg in outdated: + if pkg["name"].lower() == "torch": + pytorch_notes = f"Latest: {pkg['latest_version']}" + break + lines.append(f"| PyTorch | {ml_info['pytorch_version']} | {pytorch_status} | {pytorch_notes} |") + else: + lines.append("| PyTorch | Not installed | ⚪ N/A | |") + + # CUDA + if ml_info.get("cuda_available"): + lines.append( + f"| CUDA | {ml_info.get('cuda_version', 'Unknown')} | ✅ Available | " + f"Compatible with {ml_info.get('gpu_name', 'GPU')} |" + ) + else: + lines.append("| CUDA | N/A | ⚪ Not available | CPU mode |") + + # GPU + if ml_info.get("gpu_name"): + lines.append(f"| GPU | {ml_info['gpu_name']} | ✅ Active | {ml_info.get('gpu_count', 1)} device(s) |") + else: + lines.append("| GPU | None | ⚪ N/A | |") + + # transformers + if ml_info.get("transformers_version"): + lines.append(f"| transformers | {ml_info['transformers_version']} | ✅ Current | |") + + # FAISS + if ml_info.get("faiss_version"): + lines.append(f"| FAISS | {ml_info['faiss_version']} | ✅ Current | CPU version |") + + # sentence-transformers + if ml_info.get("sentence_transformers_version"): + lines.append(f"| sentence-transformers | {ml_info['sentence_transformers_version']} | ✅ Current | |") + + lines.append("") + + if ml_info.get("cuda_available") 
and ml_info.get("pytorch_version"): + lines.append( + f"**CUDA/PyTorch Compatibility**: Excellent. " + f"PyTorch {ml_info['pytorch_version']} with CUDA {ml_info.get('cuda_version', 'Unknown')} " + f"support is working correctly" + f"{' with ' + ml_info['gpu_name'] if ml_info.get('gpu_name') else ''}." + ) + lines.append("") + + # --- Vulnerabilities Found --- + if summary["vulnerabilities"]: + lines.append("---") + lines.append("") + lines.append("## 🔴 Vulnerabilities Found") + lines.append("") + + for pkg_name, vulns in sorted(summary["vulnerabilities"].items()): + lines.append(f"### {pkg_name} ({vulns[0]['version']})") + lines.append("") + + for vuln in vulns: + lines.append(f"**{vuln['cve_id']}**") + lines.append("") + + if vuln["aliases"]: + aliases_str = ", ".join(vuln["aliases"]) + lines.append(f"- **Aliases**: {aliases_str}") + + if vuln["fix_versions"]: + fix_str = ", ".join(vuln["fix_versions"]) + lines.append(f"- **Fix Available**: {fix_str}") + else: + lines.append("- **Fix Available**: No fix released yet") + + if vuln["description"]: + desc = vuln["description"].replace("\n", " ") + lines.append(f"- **Description**: {desc}") + + lines.append("") + + # Recommended fix commands + lines.append("### Recommended Actions") + lines.append("") + + fixable = [(pkg, v) for pkg, vulns in summary["vulnerabilities"].items() for v in vulns if v["fix_versions"]] + + if fixable: + lines.append("**Packages with available fixes:**") + lines.append("") + lines.append("```bash") + for pkg, vuln in fixable: + fix_version = vuln["fix_versions"][0] if vuln["fix_versions"] else "latest" + lines.append(f"pip install --upgrade {pkg}=={fix_version}") + lines.append("```") + lines.append("") + + unfixable = [ + (pkg, v) for pkg, vulns in summary["vulnerabilities"].items() for v in vulns if not v["fix_versions"] + ] + + if unfixable: + lines.append("**Packages without fixes (monitor for updates):**") + lines.append("") + for pkg, vuln in unfixable: + lines.append(f"- **{pkg}**: {vuln['cve_id']}") + lines.append("") + + # --- Outdated Packages Analysis --- + if outdated: + constraints = get_dependency_constraints(tree_data) if tree_data else {} + ecosystem_constraints = get_ecosystem_constraints(ml_info) + categories = categorize_outdated_packages(outdated, constraints, ecosystem_constraints) + + lines.append("---") + lines.append("") + lines.append("## 📦 Outdated Packages Analysis") + lines.append("") + lines.append(f"**Total Outdated**: {len(outdated)} packages (prioritized by risk)") + lines.append("") + + # High Priority - ML Core + if categories["ml_core"]: + lines.append("### 🔴 High Priority (ML Core - Test Thoroughly)") + lines.append("") + lines.append("⚠️ **DO NOT auto-update these packages** - CUDA compatibility and model behavior may change.") + lines.append("") + lines.append("| Package | Current | Latest | Priority | Reason |") + lines.append("|---------|---------|--------|----------|--------|") + for pkg in categories["ml_core"]: + priority = "🔴 High" if pkg["is_major_jump"] else "🟡 Medium" + reason = "Major version change" if pkg["is_major_jump"] else "ML core package" + lines.append(f"| **{pkg['name']}** | {pkg['current']} | {pkg['latest']} | {priority} | {reason} |") + lines.append("") + + # Major Version Jumps + if categories["major_jump"]: + lines.append("### 🟡 Medium Priority (Major Version Changes)") + lines.append("") + lines.append("Review breaking changes before updating. 
Check release notes.") + lines.append("") + lines.append("| Package | Current | Latest | Type |") + lines.append("|---------|---------|--------|------|") + for pkg in categories["major_jump"]: + pkg_type = ( + "Dev tool" + if pkg["name"].lower() in {"pytest", "black", "isort", "ruff", "mypy", "pyrefly"} + else "Library" + ) + lines.append(f"| {pkg['name']} | {pkg['current']} | {pkg['latest']} | {pkg_type} |") + lines.append("") + + # Blocked packages + if categories.get("blocked"): + lines.append("### 🔒 Blocked (Cannot Update)") + lines.append("") + lines.append("These packages have newer versions but cannot be updated due to dependency constraints.") + lines.append("") + lines.append("| Package | Current | Latest | Blocked By |") + lines.append("|---------|---------|--------|------------|") + for pkg in categories["blocked"]: + # Format blocking dependencies + blockers = [] + for blocker, constraint in pkg.get("blocked_by", []): + blockers.append(f"`{blocker}` requires `{constraint}`") + blockers_str = "
".join(blockers) if blockers else "Unknown" + lines.append(f"| **{pkg['name']}** | {pkg['current']} | {pkg['latest']} | {blockers_str} |") + lines.append("") + + # Ecosystem-locked packages + if categories.get("ecosystem_blocked"): + lines.append("### 🔗 Ecosystem Locked (PyTorch/CUDA)") + lines.append("") + lines.append( + "These packages are implicitly locked to the current PyTorch/CUDA version. " + "Update only as part of a coordinated ecosystem upgrade." + ) + lines.append("") + lines.append("| Package | Current | Latest | Reason |") + lines.append("|---------|---------|--------|--------|") + for pkg in categories["ecosystem_blocked"]: + reason = pkg.get("blocked_by", [("Unknown", "")][0])[0] + lines.append(f"| {pkg['name']} | {pkg['current']} | {pkg['latest']} | {reason} |") + lines.append("") + + # Safe Updates + if categories["safe_update"]: + lines.append("### 🟢 Low Priority (Minor/Patch Updates)") + lines.append("") + lines.append("Generally safe to update. Run tests after updating.") + lines.append("") + lines.append("| Package | Current | Latest |") + lines.append("|---------|---------|--------|") + for pkg in categories["safe_update"]: + lines.append(f"| {pkg['name']} | {pkg['current']} | {pkg['latest']} |") + lines.append("") + + # --- Dependency Tree Analysis --- + if tree_data: + all_installed = {pkg["package"]["package_name"].lower() for pkg in tree_data} + transitive = all_installed - direct_deps + + lines.append("---") + lines.append("") + lines.append("## 🌳 Dependency Tree Analysis") + lines.append("") + lines.append(f"- **Direct Dependencies**: {len(direct_deps)} (from pyproject.toml)") + lines.append(f"- **Transitive Dependencies**: {len(transitive)} (pulled in automatically)") + lines.append(f"- **Total Installed**: {len(all_installed)}") + dep_ratio = len(all_installed) / len(direct_deps) if direct_deps else 0 + lines.append(f"- **Dependency Ratio**: {dep_ratio:.2f}:1 (each direct dep pulls ~{dep_ratio:.1f} transitive)") + lines.append("") + + # Build trees for direct deps + trees = build_package_trees(tree_data, direct_deps) + + if trees: + lines.append("### Key Dependency Trees") + lines.append("") + lines.append("
Click to expand dependency trees") + lines.append("") + + for pkg_name in sorted(trees.keys()): + pkg_data = trees[pkg_name] + if pkg_data["dependencies"]: + lines.append(f"**{pkg_name}** ({pkg_data['version']})") + lines.append("") + lines.append("```") + lines.append(f"{pkg_name}=={pkg_data['version']}") + for dep in pkg_data["dependencies"]: + dep_name = dep["package_name"] + required = dep.get("required_version", "Any") + installed = dep.get("installed_version", "?") + lines.append(f"|-- {dep_name} [required: {required}, installed: {installed}]") + # Add nested deps if present + for nested in dep.get("dependencies", [])[:3]: + n_name = nested["package_name"] + n_inst = nested.get("installed_version", "?") + lines.append(f" +-- {n_name} [{n_inst}]") + lines.append("```") + lines.append("") + + lines.append("
") + lines.append("") + + # Orphan packages (categorized) + project_name = project_info.get("name") if project_info["name"] != "Unknown" else None + orphan_data = find_orphan_packages(tree_data, direct_deps, project_name) + total_orphans = sum(len(v) for v in orphan_data.values()) + + if total_orphans: + lines.append("### 🧹 Orphan Packages") + lines.append("") + lines.append(f"Found **{total_orphans}** packages not in pyproject.toml with no dependents.") + lines.append("") + + # Display categories using CATEGORY_LABELS + # Safe to keep categories + safe_categories = [ + "project_package", + "ml_core", + "pytorch_ecosystem", + "cuda_stack", + "dev_tools", + ] + # Domain-specific categories + domain_categories = [ + "embedding_search", + "code_parsing", + "image_generation", + "web_api", + "nlp", + ] + + # Safe to Keep (infrastructure) + has_safe = any(orphan_data.get(cat) for cat in safe_categories) + if has_safe: + lines.append("**Safe to Keep (Development Tools)**:") + for category in safe_categories: + if category in orphan_data and orphan_data[category]: + tag, description = CATEGORY_LABELS.get(category, ("[?]", category)) + for pkg in sorted(orphan_data[category], key=lambda x: x["name"]): + lines.append(f"- `{pkg['name']}` ({pkg['version']})") + lines.append("") + + # Domain-specific packages + has_domain = any(orphan_data.get(cat) for cat in domain_categories) + if has_domain: + lines.append("**Domain-Specific Packages (Auto-Detected)**:") + for category in domain_categories: + if category in orphan_data and orphan_data[category]: + tag, description = CATEGORY_LABELS.get(category, ("[?]", category)) + lines.append(f"- {description}:") + for pkg in sorted(orphan_data[category], key=lambda x: x["name"]): + lines.append(f" - `{pkg['name']}` ({pkg['version']})") + lines.append("") + + # Investigate Before Removing + if orphan_data.get("true_orphans"): + lines.append("**Investigate Before Removing**:") + investigate_list = orphan_data["true_orphans"] + for orphan in investigate_list[:10]: # Limit to first 10 + lines.append(f"- `{orphan['name']}` ({orphan['version']})") + if len(investigate_list) > 10: + lines.append(f"- ... and {len(investigate_list) - 10} more") + lines.append("") + + lines.append( + "**Recommendation**: Don't remove orphans yet. Many are transitive dependencies " + "that pipdeptree may not detect correctly (especially for compiled packages)." 
+ ) + lines.append("") + + # --- Health Metrics --- + lines.append("---") + lines.append("") + lines.append("## 📊 Dependency Health Metrics") + lines.append("") + lines.append("| Metric | Value | Status |") + lines.append("|--------|-------|--------|") + + total_pkgs = total_installed + lines.append(f"| Total Packages | {total_pkgs} | ✅ Reasonable |") + lines.append(f"| Direct Dependencies | {len(direct_deps)} | ✅ Manageable |") + lines.append(f"| Transitive Dependencies | {transitive_count} | ✅ Expected for ML project |") + + if direct_deps: + dep_ratio = total_pkgs / len(direct_deps) + lines.append(f"| Dependency Ratio | {dep_ratio:.2f}:1 | ✅ Normal |") + + lines.append( + f"| Security Vulnerabilities | {summary['total_cves']} | {'✅ Excellent' if summary['total_cves'] == 0 else '⚠️ Action needed'} |" + ) + + if outdated: + outdated_pct = (len(outdated) / total_pkgs) * 100 if total_pkgs else 0 + lines.append( + f"| Outdated Packages | {len(outdated)} ({outdated_pct:.1f}%) | " + f"{'✅ Acceptable' if outdated_pct < 20 else '⚠️ Review needed'} |" + ) + + if tree_data: + orphan_data_count = find_orphan_packages(tree_data, direct_deps, project_name) + orphan_count = sum(len(v) for v in orphan_data_count.values()) + orphan_pct = (orphan_count / total_pkgs) * 100 if total_pkgs else 0 + lines.append( + f"| Orphan Packages | {orphan_count} ({orphan_pct:.1f}%) | " + f"{'✅ Normal' if orphan_pct < 20 else '⚠️ Review periodically'} |" + ) + + lines.append("") + + # --- Recommended Actions Summary --- + lines.append("---") + lines.append("") + lines.append("## 🎯 Recommended Actions") + lines.append("") + + if summary["total_cves"] == 0: + lines.append("### Immediate Actions (This Week)") + lines.append("") + lines.append("✅ **No immediate security updates required** - All packages are vulnerability-free.") + lines.append("") + else: + lines.append("### 🔴 Immediate Actions (This Week)") + lines.append("") + lines.append("Security vulnerabilities detected. Apply fixes listed above.") + lines.append("") + + if outdated and categories.get("ml_core"): + lines.append("### Short-Term Actions (Next Sprint)") + lines.append("") + lines.append("Consider updating ML core packages after thorough testing:") + lines.append("") + for pkg in categories["ml_core"][:3]: + lines.append(f"- `{pkg['name']}`: {pkg['current']} → {pkg['latest']}") + lines.append("") + lines.append( + "**Before upgrading**: Check release notes, verify CUDA compatibility, test in isolated environment." + ) + lines.append("") + + # Ecosystem upgrade recommendations + if outdated and categories.get("ecosystem_blocked"): + ecosystem_pkgs = categories["ecosystem_blocked"] + lines.append("### Ecosystem Upgrade (When Ready)") + lines.append("") + + pytorch_version = ml_info.get("pytorch_version") if ml_info else None + if pytorch_version: + lines.append( + f"When upgrading PyTorch from {pytorch_version} to a new version, " + f"also update these {len(ecosystem_pkgs)} ecosystem packages:" + ) + else: + lines.append(f"When upgrading PyTorch/CUDA, also update these {len(ecosystem_pkgs)} ecosystem packages:") + lines.append("") + + for pkg in ecosystem_pkgs[:5]: # Show first 5 + lines.append(f"- `{pkg['name']}`: {pkg['current']} → {pkg['latest']}") + if len(ecosystem_pkgs) > 5: + lines.append(f"- ... and {len(ecosystem_pkgs) - 5} more packages") + lines.append("") + + lines.append( + "**Coordination required**: These packages must be updated together to maintain PyTorch/CUDA compatibility." 
+ ) + lines.append("") + + lines.append("### Quarterly Review Actions") + lines.append("") + # Calculate next quarter - set day=1 first to avoid invalid dates (e.g., Feb 31) + next_quarter = now.replace(day=1, month=((now.month - 1) // 3 + 1) * 3 % 12 + 1) + if next_quarter.month <= now.month: + next_quarter = next_quarter.replace(year=now.year + 1) + lines.append(f"1. Re-run security audit: {next_quarter.strftime('%B %Y')}") + lines.append("2. Review outdated packages for updates") + lines.append("3. Check PyTorch ecosystem for new releases") + lines.append("4. Clean up orphan packages (if still unused)") + lines.append("") + + # --- Summary --- + lines.append("---") + lines.append("") + lines.append("## 💡 Summary") + lines.append("") + + overall_status = "EXCELLENT" if summary["total_cves"] == 0 else "ACTION REQUIRED" + lines.append(f"**Overall Health**: {'✅' if summary['total_cves'] == 0 else '⚠️'} **{overall_status}**") + lines.append("") + + checklist = [] + if summary["total_cves"] == 0: + checklist.append("✅ **Zero security vulnerabilities** - All packages clean") + else: + checklist.append(f"⚠️ **{summary['total_cves']} security vulnerabilities** - Action required") + + if ml_info and ml_info.get("cuda_available"): + checklist.append( + f"✅ **ML stack stable** - PyTorch {ml_info.get('pytorch_version', 'Unknown')} + " + f"CUDA {ml_info.get('cuda_version', 'Unknown')} working" + ) + elif ml_info and ml_info.get("pytorch_version"): + checklist.append("⚠️ **ML stack CPU-only** - No CUDA available") + + checklist.append(f"✅ **Dependencies manageable** - {total_pkgs} packages with good organization") + + if outdated: + checklist.append(f"🟡 **{len(outdated)} packages outdated** - Review when convenient") + + for item in checklist: + lines.append(f"- {item}") + + lines.append("") + lines.append( + f"**Risk Level**: **{'LOW' if summary['total_cves'] == 0 else 'MEDIUM'}** - " + f"{'No immediate action required.' if summary['total_cves'] == 0 else 'Apply security fixes first.'}" + ) + lines.append("") + + return "\n".join(lines) + + +def save_report(content: str, output_path: Path) -> None: + """Save report to file.""" + try: + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(content, encoding="utf-8") + print(f"\n[SAVED] Report saved to: {output_path}") + except (OSError, PermissionError) as e: + print(f"\n[ERROR] Failed to save report: {e}", file=sys.stderr) + + +def main() -> None: + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Parse pip-audit JSON and generate human-readable summary", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Default: displays AND saves automatically + pip-audit --format json | python tools/summarize_audit.py + + # Disable auto-save (stdout only) + pip-audit --format json | python tools/summarize_audit.py --no-save + + # Custom output path + pip-audit --format json | python tools/summarize_audit.py -o audit_reports/before-fixes.md + + # Read from file + python tools/summarize_audit.py audit_reports/2025-12-18-audit.json + """, + ) + parser.add_argument( + "input_file", + nargs="?", + help="Input JSON file (or pipe from stdin)", + ) + parser.add_argument( + "--no-save", + action="store_true", + help="Disable auto-save (stdout only)", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + help="Custom output path (overrides auto-save location)", + ) + parser.add_argument( + "-p", + "--python", + type=Path, + help="Path to Python executable. 
Overrides DEPS_AUDIT_PYTHON env var and auto-detection.",
+    )
+    args = parser.parse_args()
+
+    # Read input data
+    if args.input_file:
+        # Read from file
+        json_file = Path(args.input_file)
+        if not json_file.exists():
+            print(f"Error: File not found: {json_file}", file=sys.stderr)
+            sys.exit(1)
+
+        with open(json_file, "r", encoding="utf-8") as f:
+            content = f.read()
+
+        # Handle pip-audit header line (e.g., "No known vulnerabilities found")
+        # Find the start of JSON content
+        json_start = content.find("{")
+        if json_start == -1:
+            print("Error: No JSON object found in file", file=sys.stderr)
+            sys.exit(1)
+
+        try:
+            data = json.loads(content[json_start:])
+        except json.JSONDecodeError as e:
+            print(f"Error: Invalid JSON: {e}", file=sys.stderr)
+            sys.exit(1)
+    else:
+        # Read from stdin - also handle header line
+        try:
+            content = sys.stdin.read()
+            json_start = content.find("{")
+            if json_start == -1:
+                print("Error: No JSON object found in input", file=sys.stderr)
+                sys.exit(1)
+            data = json.loads(content[json_start:])
+        except json.JSONDecodeError as e:
+            print(f"Error: Invalid JSON input: {e}", file=sys.stderr)
+            print("\nUsage:", file=sys.stderr)
+            print(
+                "  pip-audit --format json | python tools/summarize_audit.py",
+                file=sys.stderr,
+            )
+            print("  python tools/summarize_audit.py audit.json", file=sys.stderr)
+            sys.exit(1)
+
+    # Parse data
+    summary = parse_audit_json(data)
+    tree_data = get_dependency_tree_json(args.python)
+    direct_deps = get_direct_dependencies()
+    outdated_data = get_outdated_packages(args.python)
+    ml_info = get_ml_stack_health(args.python)
+
+    # Always print to stdout (for chat display)
+    print_summary(summary)
+    print_dependency_analysis(tree_data, direct_deps)
+    # Pass ml_info so ecosystem-locked packages are flagged on the console too,
+    # matching the markdown report
+    print_outdated_analysis(outdated_data, tree_data, ml_info)
+
+    # Save to file unless --no-save
+    if not args.no_save:
+        report = generate_markdown_report(summary, tree_data, direct_deps, outdated_data, ml_info)
+        if args.output:
+            output_path = args.output
+        else:
+            filename = f"{datetime.now().strftime('%Y-%m-%d-%H%M')}-audit-summary.md"
+            output_path = Path("audit_reports") / filename
+        save_report(report, output_path)
+
+
+if __name__ == "__main__":
+    main()
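
Note: for a quick sanity check of the parser above, a minimal sketch that feeds
parse_audit_json a hand-built, pip-audit-shaped payload. The ids, package pins,
and sys.path tweak are illustrative assumptions, not part of the diff; requires
Python 3.11+ (the module imports tomllib) and the repo root as working directory:

    import sys

    sys.path.insert(0, "tools")  # assumed layout: tools/summarize_audit.py
    from summarize_audit import parse_audit_json

    payload = {
        "dependencies": [
            {
                "name": "requests",
                "version": "2.31.0",
                "vulns": [
                    {
                        "id": "GHSA-xxxx-xxxx-xxxx",    # placeholder advisory id
                        "aliases": ["CVE-2024-00000"],  # placeholder alias
                        "fix_versions": ["2.32.0"],
                        "description": "Example entry.",
                    }
                ],
            },
            {"name": "numpy", "version": "1.26.4", "vulns": []},
        ]
    }

    summary = parse_audit_json(payload)
    assert summary["total_packages"] == 2
    assert summary["vulnerable_packages"] == 1
    assert summary["total_cves"] == 1
    assert summary["severity_counts"] == {"high": 1}  # CVE alias hits the "high" bucket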