diff --git a/__init__.py b/__init__.py index 2f5efd4..dbd15d8 100644 --- a/__init__.py +++ b/__init__.py @@ -1,25 +1,28 @@ import torch import logging -import weakref import os -import copy import json +import importlib +from contextlib import contextmanager from datetime import datetime from pathlib import Path +from types import MethodType import folder_paths import comfy.model_management as mm +import comfy.memory_management import comfy.model_patcher +import comfy.sample as comfy_sample from nodes import NODE_CLASS_MAPPINGS as GLOBAL_NODE_CLASS_MAPPINGS from .device_utils import ( get_device_list, is_accelerator_available, - soft_empty_cache_multigpu, + soft_empty_cache_multigpu as soft_empty_cache_multigpu, ) from .model_management_mgpu import ( - trigger_executor_cache_reset, - check_cpu_memory_threshold, - multigpu_memory_log, - force_full_system_cleanup, + trigger_executor_cache_reset as trigger_executor_cache_reset, + check_cpu_memory_threshold as check_cpu_memory_threshold, + multigpu_memory_log as multigpu_memory_log, + force_full_system_cleanup as force_full_system_cleanup, ) WEB_DIRECTORY = "./web" @@ -134,21 +137,45 @@ def mgpu_mm_log_method(self, msg): f"[MultiGPU Model Management] {msg}", extra={"mgpu_context": {"component": "model_management"}}, ) -logger.mgpu_mm_log = mgpu_mm_log_method.__get__(logger, type(logger)) +logger.mgpu_mm_log = MethodType(mgpu_mm_log_method, logger) + +def _normalize_module_name(module_name): + """Normalize a custom node directory name for tolerant matching.""" + return "".join(char for char in os.path.basename(module_name).lower() if char.isalnum()) def check_module_exists(module_path): """Check if a custom node module exists in ComfyUI custom_nodes directory.""" - full_path = os.path.join(folder_paths.get_folder_paths("custom_nodes")[0], module_path) - logger.debug(f"[MultiGPU] Checking for module at {full_path}") - if not os.path.exists(full_path): - logger.debug(f"[MultiGPU] Module {module_path} not found - skipping") - return False - logger.debug(f"[MultiGPU] Found {module_path}, creating compatible MultiGPU nodes") - return True + custom_nodes_paths = folder_paths.get_folder_paths("custom_nodes") + normalized_module_path = _normalize_module_name(module_path) + + for custom_nodes_path in custom_nodes_paths: + full_path = os.path.join(custom_nodes_path, module_path) + logger.debug(f"[MultiGPU] Checking for module at {full_path}") + if os.path.isdir(full_path): + logger.debug(f"[MultiGPU] Found exact module match for {module_path} at {full_path}") + return True + + for custom_nodes_path in custom_nodes_paths: + try: + with os.scandir(custom_nodes_path) as entries: + for entry in entries: + if not entry.is_dir(): + continue + if _normalize_module_name(entry.name) == normalized_module_path: + logger.debug(f"[MultiGPU] Found normalized module match for {module_path} at {entry.path}") + return True + except OSError: + continue + + logger.debug(f"[MultiGPU] Module {module_path} not found - skipping") + return False current_device = mm.get_torch_device() current_text_encoder_device = mm.text_encoder_device() current_unet_offload_device = mm.unet_offload_device() +_aimdo_initialized_devices = set() +if isinstance(current_device, torch.device) and current_device.type == "cuda" and current_device.index is not None: + _aimdo_initialized_devices.add(current_device.index) def set_current_device(device): """Set the current device context for MultiGPU operations.""" @@ -183,6 +210,73 @@ def get_current_unet_offload_device(): """Get the current UNet offload device context at runtime.""" return current_unet_offload_device +def _coerce_torch_device(device): + """Best-effort conversion to torch.device for guard and patch helpers.""" + if device is None: + return None + if isinstance(device, torch.device): + return device + try: + return torch.device(device) + except (TypeError, RuntimeError, ValueError): + return None + +@contextmanager +def cuda_device_guard(device, reason="runtime"): + """Temporarily switch the real CUDA current device for non-primary execution paths.""" + target_device = _coerce_torch_device(device) + previous_device_index = None + switched_device = False + + if ( + target_device is not None + and target_device.type == "cuda" + and target_device.index is not None + and torch.cuda.is_available() + ): + previous_device_index = torch.cuda.current_device() + if previous_device_index != target_device.index: + logger.info( + f"[MultiGPU CUDA Guard] Switching CUDA current device {previous_device_index} -> {target_device.index} ({reason})" + ) + torch.cuda.set_device(target_device.index) + switched_device = True + + try: + yield target_device + finally: + if switched_device and previous_device_index is not None: + torch.cuda.set_device(previous_device_index) + logger.info( + f"[MultiGPU CUDA Guard] Restored CUDA current device {target_device.index} -> {previous_device_index} ({reason})" + ) + +def _get_runtime_device_from_model(model): + """Resolve the actual execution device from a model or patcher wrapper.""" + if hasattr(model, "load_device"): + return getattr(model, "load_device") + patcher = getattr(model, "patcher", None) + if patcher is not None and hasattr(patcher, "load_device"): + return patcher.load_device + inner_model = getattr(model, "model", None) + if inner_model is not None and hasattr(inner_model, "load_device"): + return inner_model.load_device + return None + +@contextmanager +def multigpu_runtime_device_guard(device, reason="runtime"): + """Align MultiGPU logical device state with the real runtime device for inference.""" + original_device = get_current_device() + target_device = _coerce_torch_device(device) or device + if target_device is not None: + set_current_device(target_device) + logger.info(f"[MultiGPU Runtime] Using runtime device {target_device} ({reason})") + try: + with cuda_device_guard(target_device, reason=reason): + yield _coerce_torch_device(target_device) + finally: + set_current_device(original_device) + def get_torch_device_patched(): """Return MultiGPU-aware device selection for patched mm.get_torch_device.""" device = None @@ -216,7 +310,112 @@ def unet_offload_device_patched(): logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})") return device -logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device") +def _patch_model_management_current_stream(): + """Make ComfyUI stream lookup honor the requested CUDA device.""" + current_stream = getattr(mm, "current_stream", None) + if current_stream is None: + return False + if getattr(current_stream, "_multigpu_cuda_device_aware", False): + return True + + def current_stream_device_aware(device): + target_device = _coerce_torch_device(device) + if target_device is not None and target_device.type == "cuda": + return torch.cuda.current_stream(device=target_device) + return current_stream(device) + + current_stream_device_aware._multigpu_cuda_device_aware = True + current_stream_device_aware._multigpu_original = current_stream + mm.current_stream = current_stream_device_aware + logger.info("[MultiGPU] Patched comfy.model_management.current_stream to honor CUDA device arguments") + return True + +def _initialize_aimdo_visible_cuda_devices(): + """Ensure DynamicVRAM initializes every visible CUDA device once when enabled.""" + if not getattr(comfy.memory_management, "aimdo_enabled", False): + logger.info("[MultiGPU] DynamicVRAM not enabled; skipping multi-device aimdo initialization") + return False + if not torch.cuda.is_available(): + logger.info("[MultiGPU] CUDA unavailable; skipping multi-device aimdo initialization") + return False + + try: + from comfy_aimdo import control as aimdo_control + except ImportError: + logger.warning("[MultiGPU] comfy_aimdo unavailable during multi-device initialization") + return False + + init_device = getattr(aimdo_control, "init_device", None) + if not callable(init_device): + logger.warning("[MultiGPU] comfy_aimdo.control.init_device missing; skipping multi-device initialization") + return False + + initialized_any = False + for device_index in range(torch.cuda.device_count()): + if device_index in _aimdo_initialized_devices: + continue + logger.info(f"[MultiGPU] Initializing comfy_aimdo for CUDA device {device_index}") + initialized = bool(init_device(device_index)) + logger.info(f"[MultiGPU] comfy_aimdo init_device({device_index}) -> {initialized}") + if initialized: + _aimdo_initialized_devices.add(device_index) + initialized_any = True + + return initialized_any + +def _patch_comfy_sample_runtime_device(): + """Wrap Comfy sampling entrypoints so runtime device state matches the model load device.""" + sample_fn = getattr(comfy_sample, "sample", None) + if callable(sample_fn) and not getattr(sample_fn, "_multigpu_runtime_device_guard", False): + def sample_with_runtime_device(model, *args, **kwargs): + runtime_device = _get_runtime_device_from_model(model) + with multigpu_runtime_device_guard(runtime_device, reason=f"comfy.sample.sample:{type(model).__name__}"): + return sample_fn(model, *args, **kwargs) + + sample_with_runtime_device._multigpu_runtime_device_guard = True + sample_with_runtime_device._multigpu_original = sample_fn + comfy_sample.sample = sample_with_runtime_device + logger.info("[MultiGPU] Patched comfy.sample.sample with runtime device guard") + + sample_custom_fn = getattr(comfy_sample, "sample_custom", None) + if callable(sample_custom_fn) and not getattr(sample_custom_fn, "_multigpu_runtime_device_guard", False): + def sample_custom_with_runtime_device(model, *args, **kwargs): + runtime_device = _get_runtime_device_from_model(model) + with multigpu_runtime_device_guard(runtime_device, reason=f"comfy.sample.sample_custom:{type(model).__name__}"): + return sample_custom_fn(model, *args, **kwargs) + + sample_custom_with_runtime_device._multigpu_runtime_device_guard = True + sample_custom_with_runtime_device._multigpu_original = sample_custom_fn + comfy_sample.sample_custom = sample_custom_with_runtime_device + logger.info("[MultiGPU] Patched comfy.sample.sample_custom with runtime device guard") + +def _patch_comfy_kitchen_dlpack_device_guard(): + """Guard comfy_kitchen DLPack export by switching to the tensor's CUDA device.""" + try: + comfy_kitchen_cuda = importlib.import_module("comfy_kitchen.backends.cuda") + except ImportError: + logger.debug("[MultiGPU] comfy_kitchen not found - skipping CUDA DLPack compat patch") + return False + + wrap_for_dlpack = getattr(comfy_kitchen_cuda, "_wrap_for_dlpack", None) + if wrap_for_dlpack is None: + logger.debug("[MultiGPU] comfy_kitchen.backends.cuda._wrap_for_dlpack not found - skipping compat patch") + return False + + if getattr(wrap_for_dlpack, "_multigpu_cuda_device_guard", False): + return True + + def wrap_for_dlpack_with_device_guard(*args, **kwargs): + tensor = args[0] if args else kwargs.get("tensor") + with cuda_device_guard(getattr(tensor, "device", None), reason="comfy_kitchen._wrap_for_dlpack"): + return wrap_for_dlpack(*args, **kwargs) + + wrap_for_dlpack_with_device_guard._multigpu_cuda_device_guard = True + comfy_kitchen_cuda._wrap_for_dlpack = wrap_for_dlpack_with_device_guard + logger.info("[MultiGPU] Applied comfy_kitchen CUDA DLPack device guard patch") + return True + +logger.info("[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device") logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}") logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}") logger.info(f"[MultiGPU DEBUG] Initial current_unet_offload_device: {current_unet_offload_device}") @@ -224,8 +423,13 @@ def unet_offload_device_patched(): mm.get_torch_device = get_torch_device_patched mm.text_encoder_device = text_encoder_device_patched mm.unet_offload_device = unet_offload_device_patched +_patch_model_management_current_stream() +_patch_comfy_sample_runtime_device() +_patch_comfy_kitchen_dlpack_device_guard() +_initialize_aimdo_visible_cuda_devices() from .nodes import ( + DeviceSelectorMultiGPU, UnetLoaderGGUF, UnetLoaderGGUFAdvanced, CLIPLoaderGGUF, @@ -246,47 +450,24 @@ def unet_offload_device_patched(): UNetLoaderLP, ) -from .wanvideo import ( - LoadWanVideoT5TextEncoder, - WanVideoTextEncode, - WanVideoTextEncodeCached, - WanVideoTextEncodeSingle, - WanVideoVAELoader, - WanVideoTinyVAELoader, - WanVideoBlockSwap, - WanVideoImageToVideoEncode, - WanVideoDecode, - WanVideoModelLoader, - WanVideoSampler, - WanVideoVACEEncode, - WanVideoEncode, - LoadWanVideoClipTextEncoder, - WanVideoClipVisionEncode, - WanVideoControlnetLoader, - FantasyTalkingModelLoader, - Wav2VecModelLoader, - WanVideoUni3C_ControlnetLoader, - DownloadAndLoadWav2VecModel, -) - from .wrappers import ( override_class, override_class_offload, override_class_clip, override_class_clip_no_device, override_class_with_distorch_gguf, - override_class_with_distorch_gguf_v2, + override_class_with_distorch_gguf_v2 as override_class_with_distorch_gguf_v2, override_class_with_distorch_clip, override_class_with_distorch_clip_no_device, - override_class_with_distorch, + override_class_with_distorch as override_class_with_distorch, override_class_with_distorch_safetensor_v2, override_class_with_distorch_safetensor_v2_clip, override_class_with_distorch_safetensor_v2_clip_no_device, ) from .distorch_2 import ( - register_patched_safetensor_modelpatcher, - analyze_safetensor_loading, - calculate_safetensor_vvram_allocation, + register_patched_safetensor_modelpatcher as register_patched_safetensor_modelpatcher, + analyze_safetensor_loading as analyze_safetensor_loading, + calculate_safetensor_vvram_allocation as calculate_safetensor_vvram_allocation, ) from .checkpoint_multigpu import ( @@ -294,9 +475,57 @@ def unet_offload_device_patched(): CheckpointLoaderAdvancedDisTorch2MultiGPU ) +def _load_wanvideo_nodes(): + from .wanvideo import ( + LoadWanVideoT5TextEncoder, + WanVideoTextEncode, + WanVideoTextEncodeCached, + WanVideoTextEncodeSingle, + WanVideoVAELoader, + WanVideoTinyVAELoader, + WanVideoBlockSwap, + WanVideoImageToVideoEncode, + WanVideoDecode, + WanVideoModelLoader, + WanVideoSampler, + WanVideoVACEEncode, + WanVideoEncode, + LoadWanVideoClipTextEncoder, + WanVideoClipVisionEncode, + WanVideoControlnetLoader, + FantasyTalkingModelLoader, + Wav2VecModelLoader, + WanVideoUni3C_ControlnetLoader, + DownloadAndLoadWav2VecModel, + ) + + return { + "LoadWanVideoT5TextEncoderMultiGPU": LoadWanVideoT5TextEncoder, + "WanVideoTextEncodeMultiGPU": WanVideoTextEncode, + "WanVideoTextEncodeCachedMultiGPU": WanVideoTextEncodeCached, + "WanVideoTextEncodeSingleMultiGPU": WanVideoTextEncodeSingle, + "WanVideoVAELoaderMultiGPU": WanVideoVAELoader, + "WanVideoTinyVAELoaderMultiGPU": WanVideoTinyVAELoader, + "WanVideoBlockSwapMultiGPU": WanVideoBlockSwap, + "WanVideoImageToVideoEncodeMultiGPU": WanVideoImageToVideoEncode, + "WanVideoDecodeMultiGPU": WanVideoDecode, + "WanVideoModelLoaderMultiGPU": WanVideoModelLoader, + "WanVideoSamplerMultiGPU": WanVideoSampler, + "WanVideoVACEEncodeMultiGPU": WanVideoVACEEncode, + "WanVideoEncodeMultiGPU": WanVideoEncode, + "LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder, + "WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode, + "WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoader, + "FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoader, + "Wav2VecModelLoaderMultiGPU": Wav2VecModelLoader, + "WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoader, + "DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModel, + } + NODE_CLASS_MAPPINGS = { "CheckpointLoaderAdvancedMultiGPU": CheckpointLoaderAdvancedMultiGPU, "CheckpointLoaderAdvancedDisTorch2MultiGPU": CheckpointLoaderAdvancedDisTorch2MultiGPU, + "DeviceSelectorMultiGPU": DeviceSelectorMultiGPU, "UNetLoaderLP": UNetLoaderLP, } @@ -339,14 +568,20 @@ def register_and_count(module_names, node_map): if check_module_exists(name): found = True break - + count = 0 if found: + try: + resolved_node_map = node_map() if callable(node_map) else node_map + except Exception as exc: + logger.warning(f"[MultiGPU] Failed to register nodes for {module_names[0]}: {exc}") + resolved_node_map = {} + initial_len = len(NODE_CLASS_MAPPINGS) - for key, value in node_map.items(): + for key, value in resolved_node_map.items(): NODE_CLASS_MAPPINGS[key] = value count = len(NODE_CLASS_MAPPINGS) - initial_len - + registration_data.append({"name": module_names[0], "found": "Y" if found else "N", "count": count}) return found @@ -401,29 +636,7 @@ def register_and_count(module_names, node_map): } register_and_count(["PuLID_ComfyUI", "pulid_comfyui"], pulid_nodes) -wanvideo_nodes = { - "LoadWanVideoT5TextEncoderMultiGPU": LoadWanVideoT5TextEncoder, - "WanVideoTextEncodeMultiGPU": WanVideoTextEncode, - "WanVideoTextEncodeCachedMultiGPU": WanVideoTextEncodeCached, - "WanVideoTextEncodeSingleMultiGPU": WanVideoTextEncodeSingle, - "WanVideoVAELoaderMultiGPU": WanVideoVAELoader, - "WanVideoTinyVAELoaderMultiGPU": WanVideoTinyVAELoader, - "WanVideoBlockSwapMultiGPU": WanVideoBlockSwap, - "WanVideoImageToVideoEncodeMultiGPU": WanVideoImageToVideoEncode, - "WanVideoDecodeMultiGPU": WanVideoDecode, - "WanVideoModelLoaderMultiGPU": WanVideoModelLoader, - "WanVideoSamplerMultiGPU": WanVideoSampler, - "WanVideoVACEEncodeMultiGPU": WanVideoVACEEncode, - "WanVideoEncodeMultiGPU": WanVideoEncode, - "LoadWanVideoClipTextEncoderMultiGPU": LoadWanVideoClipTextEncoder, - "WanVideoClipVisionEncodeMultiGPU": WanVideoClipVisionEncode, - "WanVideoControlnetLoaderMultiGPU": WanVideoControlnetLoader, - "FantasyTalkingModelLoaderMultiGPU": FantasyTalkingModelLoader, - "Wav2VecModelLoaderMultiGPU": Wav2VecModelLoader, - "WanVideoUni3C_ControlnetLoaderMultiGPU": WanVideoUni3C_ControlnetLoader, - "DownloadAndLoadWav2VecModelMultiGPU": DownloadAndLoadWav2VecModel, -} -register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], wanvideo_nodes) +register_and_count(["ComfyUI-WanVideoWrapper", "comfyui-wanvideowrapper"], _load_wanvideo_nodes) for item in registration_data: logger.info(fmt_reg.format(item['name'], item['found'], str(item['count']))) diff --git a/checkpoint_multigpu.py b/checkpoint_multigpu.py index 40aba3f..b42b453 100644 --- a/checkpoint_multigpu.py +++ b/checkpoint_multigpu.py @@ -21,28 +21,39 @@ def patch_load_state_dict_guess_config(): """Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading.""" global original_load_state_dict_guess_config - + if original_load_state_dict_guess_config is not None: logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.") return - + logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.") original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, - te_model_options={}, metadata=None): + te_model_options={}, metadata=None, disable_dynamic=False): """Patched checkpoint loader with MultiGPU and DisTorch2 device placement support.""" from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device - + sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel')) config_hash = str(sd_size) device_config = checkpoint_device_config.get(config_hash) distorch_config = checkpoint_distorch_config.get(config_hash) if not device_config and not distorch_config: - return original_load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options, metadata) + return original_load_state_dict_guess_config( + sd, + output_vae=output_vae, + output_clip=output_clip, + output_clipvision=output_clipvision, + embedding_directory=embedding_directory, + output_model=output_model, + model_options=model_options, + te_model_options=te_model_options, + metadata=metadata, + disable_dynamic=disable_dynamic, + ) logger.debug("[MultiGPU Checkpoint] ENTERING Patched Checkpoint Loader") logger.debug(f"[MultiGPU Checkpoint] Received Device Config: {device_config}") @@ -53,7 +64,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, vae = None model = None model_patcher = None - + # Capture the current devices at runtime so we can restore them after loading original_main_device = get_current_device() original_clip_device = get_current_text_encoder_device() @@ -68,12 +79,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) model_config = comfy.model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) - + if model_config is None: logger.warning("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only.") # Simplified fallback for non-checkpoints set_current_device(device_config.get('unet_device', original_main_device)) - diffusion_model = comfy.sd.load_diffusion_model_state_dict(sd, model_options={}) + diffusion_model = comfy.sd.load_diffusion_model_state_dict( + sd, + model_options={}, + metadata=metadata, + disable_dynamic=disable_dynamic, + ) if diffusion_model is None: return None return (diffusion_model, None, VAE(sd={}), None) @@ -83,18 +99,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: weight_dtype = None - + if custom_operations is not None: model_config.custom_operations = custom_operations unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - - unet_compute_device = device_config.get('unet_device', original_main_device) + + unet_compute_device = torch.device(device_config.get('unet_device', original_main_device)) if model_config.scaled_fp8 is not None: - manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes) + manual_cast_dtype = mm.unet_manual_cast(None, unet_compute_device, model_config.supported_inference_dtypes) else: - manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes) + manual_cast_dtype = mm.unet_manual_cast(unet_dtype, unet_compute_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) logger.info(f"UNet DType: {unet_dtype}, Manual Cast: {manual_cast_dtype}") @@ -103,19 +119,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, clipvision = comfy.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) if output_model: - unet_compute_device = device_config.get('unet_device', original_main_device) - set_current_device(unet_compute_device) + unet_compute_device = torch.device(device_config.get('unet_device', original_main_device)) + set_current_device(unet_compute_device) inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype) multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load") model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher_class = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + model_patcher = model_patcher_class(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights") logger.mgpu_mm_log("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup") soft_empty_cache_multigpu() - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device()) multigpu_memory_log(f"unet:{config_hash[:8]}", "post-model") if distorch_config and 'unet_allocation' in distorch_config: @@ -131,7 +148,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, vae_target_device = torch.device(device_config.get('vae_device', original_main_device)) set_current_device(vae_target_device) # Use main device context for VAE multigpu_memory_log(f"vae:{config_hash[:8]}", "pre-load") - + vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) vae = VAE(sd=vae_sd, metadata=metadata) @@ -151,7 +168,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, for pref in scaled_fp8_list: skip = skip or k.startswith(pref) if not skip: - out_sd[k] = sd[k] + out_sd[k] = sd[k] for pref in scaled_fp8_list: quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) @@ -159,9 +176,9 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, out_sd[k] = quant_sd[k] sd = out_sd - clip_target_device = device_config.get('clip_device', original_clip_device) + clip_target_device = torch.device(device_config.get('clip_device', original_clip_device)) set_current_text_encoder_device(clip_target_device) - + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) @@ -170,7 +187,15 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, multigpu_memory_log(f"clip:{config_hash[:8]}", "pre-load") soft_empty_cache_multigpu() clip_params = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_params, model_options=te_model_options) + clip = CLIP( + clip_target, + embedding_directory=embedding_directory, + tokenizer_data=clip_sd, + parameters=clip_params, + state_dict=clip_sd, + model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) if distorch_config and 'clip_allocation' in distorch_config: clip_alloc = distorch_config['clip_allocation'] @@ -181,16 +206,13 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, logger.info(f"[CHECKPOINT_META] CLIP inner_model id=0x{id(inner_clip):x}") clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True) - m, u = clip.load_sd(clip_sd, full_model=True) # This respects the patched text_encoder_device - if len(m) > 0: logger.warning(f"CLIP missing keys: {m}") - if len(u) > 0: logger.debug(f"CLIP unexpected keys: {u}") logger.info("CLIP Loaded.") multigpu_memory_log(f"clip:{config_hash[:8]}", "post-load") else: logger.warning("No CLIP/text encoder weights in checkpoint.") else: logger.warning("CLIP target not found in model config.") - + finally: set_current_device(original_main_device) set_current_text_encoder_device(original_clip_device) @@ -206,7 +228,7 @@ def INPUT_TYPES(s): import folder_paths devices = get_device_list() default_device = devices[1] if len(devices) > 1 else devices[0] - + return { "required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), @@ -215,27 +237,27 @@ def INPUT_TYPES(s): "vae_device": (devices, {"default": default_device}), } } - + RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "multigpu" TITLE = "Checkpoint Loader Advanced (MultiGPU)" - + def load_checkpoint(self, ckpt_name, unet_device, clip_device, vae_device): patch_load_state_dict_guess_config() - + import folder_paths import comfy.utils - + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) sd = comfy.utils.load_torch_file(ckpt_path) sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel')) config_hash = str(sd_size) - + checkpoint_device_config[config_hash] = { 'unet_device': unet_device, 'clip_device': clip_device, 'vae_device': vae_device } - + # Load using standard loader, our patch will intercept from nodes import CheckpointLoaderSimple return CheckpointLoaderSimple().load_checkpoint(ckpt_name) @@ -247,7 +269,7 @@ def INPUT_TYPES(s): import folder_paths devices = get_device_list() compute_device = devices[1] if len(devices) > 1 else devices[0] - + return { "required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), @@ -265,18 +287,18 @@ def INPUT_TYPES(s): "eject_models": ("BOOLEAN", {"default": True}), } } - + RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "multigpu/distorch_2" TITLE = "Checkpoint Loader Advanced (DisTorch2)" - + def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, unet_donor_device, clip_compute_device, clip_virtual_vram_gb, clip_donor_device, vae_device, unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True, eject_models=True): - + if eject_models: - logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction") + logger.mgpu_mm_log("[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction") ejection_count = 0 for i, lm in enumerate(mm.current_loaded_models): model_name = type(getattr(lm.model, 'model', lm.model)).__name__ if lm.model else 'Unknown' @@ -289,17 +311,17 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (direct patcher) → marked for eviction") ejection_count += 1 logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu") - - patch_load_state_dict_guess_config() - + + patch_load_state_dict_guess_config() + import folder_paths import comfy.utils - + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) sd = comfy.utils.load_torch_file(ckpt_path) sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel')) config_hash = str(sd_size) - + checkpoint_device_config[config_hash] = { 'unet_device': unet_compute_device, 'clip_device': clip_compute_device, @@ -312,7 +334,7 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, elif unet_expert_mode_allocations: unet_vram_str = unet_compute_device unet_alloc = f"{unet_expert_mode_allocations}#{unet_vram_str}" if unet_expert_mode_allocations or unet_vram_str else "" - + clip_vram_str = "" if clip_virtual_vram_gb > 0: clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}" @@ -327,6 +349,6 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, 'unet_settings': hashlib.sha256(f"{unet_alloc}{high_precision_loras}".encode()).hexdigest(), 'clip_settings': hashlib.sha256(f"{clip_alloc}{high_precision_loras}".encode()).hexdigest(), } - + from nodes import CheckpointLoaderSimple return CheckpointLoaderSimple().load_checkpoint(ckpt_name) diff --git a/ci/extract_allocation.py b/ci/extract_allocation.py index a11c05e..734ee23 100644 --- a/ci/extract_allocation.py +++ b/ci/extract_allocation.py @@ -3,10 +3,16 @@ import argparse import json +import sys from pathlib import Path from typing import Iterable, Iterator, Dict, Any +def _write_stdout(message: str = "") -> None: + sys.stdout.write(f"{message}\n") + sys.stdout.flush() + + def load_json_lines(path: Path) -> Iterator[Dict[str, Any]]: with path.open("r", encoding="utf-8") as handle: for line in handle: @@ -37,12 +43,12 @@ def main() -> int: entries = list(load_json_lines(args.logfile)) if not entries: - print("No entries found in log file.") + _write_stdout("No entries found in log file.") return 0 matched = [entry for entry in entries if is_allocation_event(entry, args.keywords)] if not matched: - print("No allocation events matched provided keywords.") + _write_stdout("No allocation events matched provided keywords.") return 0 for entry in matched: @@ -51,9 +57,9 @@ def main() -> int: component = entry.get("component", "") header_bits = [bit for bit in (timestamp, category, component) if bit] header = " | ".join(header_bits) if header_bits else "allocation" - print(f"## {header}") - print(entry.get("message", "")) - print() + _write_stdout(f"## {header}") + _write_stdout(entry.get("message", "")) + _write_stdout() return 0 diff --git a/ci/run_workflows.py b/ci/run_workflows.py index 8ad7e3d..a2b9879 100644 --- a/ci/run_workflows.py +++ b/ci/run_workflows.py @@ -19,6 +19,16 @@ DEFAULT_WORKFLOW_TIMEOUT = int(os.environ.get("COMFYUI_WORKFLOW_TIMEOUT", "900")) +def _write_stdout(message: str) -> None: + sys.stdout.write(f"{message}\n") + sys.stdout.flush() + + +def _write_stderr(message: str) -> None: + sys.stderr.write(f"{message}\n") + sys.stderr.flush() + + class ComfyWorkflowRunner: def __init__(self, host: str, port: int, connect_timeout: int, workflow_timeout: int, secure: bool = False) -> None: self.host = host @@ -78,7 +88,7 @@ def wait_for_completion(self, prompt_id: str) -> bool: except websocket.WebSocketTimeoutException: continue except Exception as exc: # noqa: BLE001 - print(f"WebSocket error: {exc}", file=sys.stderr, flush=True) + _write_stderr(f"WebSocket error: {exc}") return False if isinstance(message, bytes): @@ -94,16 +104,16 @@ def wait_for_completion(self, prompt_id: str) -> bool: if message_type == "execution_error": if data.get("prompt_id") == prompt_id: - print(f"Execution error: {payload}", file=sys.stderr, flush=True) + _write_stderr(f"Execution error: {payload}") return False elif message_type == "status" and data.get("status") == "error": if data.get("prompt_id") == prompt_id: - print(f"Status error: {payload}", file=sys.stderr, flush=True) + _write_stderr(f"Status error: {payload}") return False elif message_type == "executing": if data.get("prompt_id") == prompt_id and data.get("node") is None: return True - print("Workflow timed out", file=sys.stderr, flush=True) + _write_stderr("Workflow timed out") return False def run_workflow(self, workflow_path: Path) -> bool: @@ -126,25 +136,25 @@ def restore_env() -> None: with workflow_path.open("r", encoding="utf-8") as handle: workflow = json.load(handle) except (OSError, json.JSONDecodeError) as exc: - print(f"Failed to load workflow {workflow_path}: {exc}", file=sys.stderr, flush=True) + _write_stderr(f"Failed to load workflow {workflow_path}: {exc}") restore_env() return False - print(f"Running workflow {workflow_path}", flush=True) + _write_stdout(f"Running workflow {workflow_path}") start = time.monotonic() try: prompt_id = self.queue_prompt(workflow) os.environ["MGPU_JSON_PROMPT"] = prompt_id except requests.HTTPError as exc: - print(f"HTTP error while queueing workflow: {exc}", file=sys.stderr, flush=True) + _write_stderr(f"HTTP error while queueing workflow: {exc}") restore_env() return False except requests.RequestException as exc: - print(f"Request error while queueing workflow: {exc}", file=sys.stderr, flush=True) + _write_stderr(f"Request error while queueing workflow: {exc}") restore_env() return False except RuntimeError as exc: - print(str(exc), file=sys.stderr, flush=True) + _write_stderr(str(exc)) restore_env() return False @@ -152,7 +162,7 @@ def restore_env() -> None: if not self.wait_for_completion(prompt_id): return False duration = time.monotonic() - start - print(f"Workflow {workflow_path} completed in {duration:.2f}s", flush=True) + _write_stdout(f"Workflow {workflow_path} completed in {duration:.2f}s") return True finally: restore_env() diff --git a/ci/summarize_log.py b/ci/summarize_log.py index f1d3b4b..1e6fd94 100644 --- a/ci/summarize_log.py +++ b/ci/summarize_log.py @@ -3,10 +3,16 @@ import argparse import json +import sys from pathlib import Path from typing import Iterator, Dict, Any +def _write_stdout(message: str = "") -> None: + sys.stdout.write(f"{message}\n") + sys.stdout.flush() + + def load_json_lines(path: Path) -> Iterator[Dict[str, Any]]: with path.open("r", encoding="utf-8") as handle: for line in handle: @@ -32,11 +38,11 @@ def main() -> int: entries = list(load_json_lines(args.logfile)) if not entries: - print("No entries found in log file.") + _write_stdout("No entries found in log file.") return 0 - print("| Timestamp | Level | Component | Message |") - print("| --- | --- | --- | --- |") + _write_stdout("| Timestamp | Level | Component | Message |") + _write_stdout("| --- | --- | --- | --- |") for entry in entries: level = entry.get("level", "") if args.severity and level not in args.severity: @@ -50,7 +56,7 @@ def main() -> int: continue timestamp = entry.get("timestamp", "") message = entry.get("message", "").replace("|", "\u2502") - print(f"| {timestamp} | {level} | {component} | {message} |") + _write_stdout(f"| {timestamp} | {level} | {component} | {message} |") return 0 diff --git a/device_utils.py b/device_utils.py index 6647572..a29b83b 100644 --- a/device_utils.py +++ b/device_utils.py @@ -1,9 +1,9 @@ import torch import logging -import hashlib import psutil import comfy.model_management as mm import gc +import importlib logger = logging.getLogger("MultiGPU") @@ -14,7 +14,7 @@ def get_device_list(): Enumerate ALL physically available devices that can store torch tensors. This includes all device types supported by ComfyUI core. Results are cached after first call since devices don't change during runtime. - + Returns a comprehensive list of all available devices across all types: - CPU (always available) - CUDA devices (NVIDIA GPUs + AMD w/ ROCm GPUs) @@ -26,51 +26,51 @@ def get_device_list(): - CoreX/IXUCA devices """ global _DEVICE_LIST_CACHE - + if _DEVICE_LIST_CACHE is not None: return _DEVICE_LIST_CACHE - + devs = [] - + devs.append("cpu") - + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available(): device_count = torch.cuda.device_count() devs += [f"cuda:{i}" for i in range(device_count)] logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} CUDA device(s)") - + try: - import intel_extension_for_pytorch as ipex + importlib.import_module("intel_extension_for_pytorch") except ImportError: pass - + if hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available(): device_count = torch.xpu.device_count() devs += [f"xpu:{i}" for i in range(device_count)] logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} XPU device(s)") - + try: - import torch_npu + importlib.import_module("torch_npu") if hasattr(torch, "npu") and hasattr(torch.npu, "is_available") and torch.npu.is_available(): device_count = torch.npu.device_count() devs += [f"npu:{i}" for i in range(device_count)] logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} NPU device(s)") except ImportError: pass - + try: - import torch_mlu + importlib.import_module("torch_mlu") if hasattr(torch, "mlu") and hasattr(torch.mlu, "is_available") and torch.mlu.is_available(): device_count = torch.mlu.device_count() devs += [f"mlu:{i}" for i in range(device_count)] logger.debug(f"[MultiGPU_Device_Utils] Found {device_count} MLU device(s)") except ImportError: pass - + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): devs.append("mps") logger.debug("[MultiGPU_Device_Utils] Found MPS device") - + try: import torch_directml adapter_count = torch_directml.device_count() @@ -79,7 +79,7 @@ def get_device_list(): logger.debug(f"[MultiGPU_Device_Utils] Found {adapter_count} DirectML adapter(s)") except ImportError: pass - + try: if hasattr(torch, "corex"): if hasattr(torch.corex, "device_count"): @@ -91,30 +91,30 @@ def get_device_list(): logger.debug("[MultiGPU_Device_Utils] Found CoreX device") except ImportError: pass - + _DEVICE_LIST_CACHE = devs - + logger.debug(f"[MultiGPU_Device_Utils] Device list initialized: {devs}") - + return devs def is_accelerator_available(): """Check if any GPU or accelerator device is available including CUDA, XPU, NPU, MLU, MPS, DirectML, or CoreX.""" if hasattr(torch, "cuda") and torch.cuda.is_available(): return True - + if hasattr(torch, "xpu") and hasattr(torch.xpu, "is_available") and torch.xpu.is_available(): return True - + try: - import torch_npu + importlib.import_module("torch_npu") if hasattr(torch, "npu") and hasattr(torch.npu, "is_available") and torch.npu.is_available(): return True except ImportError: pass try: - import torch_mlu + importlib.import_module("torch_mlu") if hasattr(torch, "mlu") and hasattr(torch.mlu, "is_available") and torch.mlu.is_available(): return True except ImportError: @@ -132,7 +132,7 @@ def is_accelerator_available(): if hasattr(torch, "corex"): return True - + return False def is_device_compatible(device_string): @@ -156,7 +156,7 @@ def parse_device_string(device_string): def soft_empty_cache_multigpu(): """Clear allocator caches across all devices using context managers to preserve calling thread device context.""" from .model_management_mgpu import multigpu_memory_log - + logger.mgpu_mm_log("soft_empty_cache_multigpu: starting GC and multi-device cache clear") gc.collect() @@ -164,7 +164,7 @@ def soft_empty_cache_multigpu(): # Clear cache for ALL devices (not just ComfyUI's single device) all_devices = get_device_list() logger.mgpu_mm_log(f"soft_empty_cache_multigpu: devices to clear = {all_devices}") - + # Check global availability first to avoid unnecessary iteration if backend is missing is_cuda_available = hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available() @@ -175,6 +175,7 @@ def soft_empty_cache_multigpu(): logger.mgpu_mm_log(f"Clearing CUDA cache on {device_str} (idx={device_idx})") multigpu_memory_log("general", f"pre-empty:{device_str}") with torch.cuda.device(device_idx): + torch.cuda.synchronize() torch.cuda.empty_cache() if hasattr(torch.cuda, "ipc_collect"): torch.cuda.ipc_collect() @@ -234,15 +235,15 @@ def soft_empty_cache_multigpu(): def soft_empty_cache_distorch2_patched(force=False): """Patched mm.soft_empty_cache managing VRAM across all devices, CPU RAM with adaptive thresholding, and DisTorch store pruning.""" - from .model_management_mgpu import multigpu_memory_log, check_cpu_memory_threshold, trigger_executor_cache_reset - + from .model_management_mgpu import check_cpu_memory_threshold, trigger_executor_cache_reset + is_distorch_active = False for i, lm in enumerate(mm.current_loaded_models): mp = lm.model if mp is not None: inner_model = mp.model - + if hasattr(inner_model, '_distorch_v2_meta'): is_distorch_active = True break diff --git a/distorch_2.py b/distorch_2.py index 94d2f53..093f9fe 100644 --- a/distorch_2.py +++ b/distorch_2.py @@ -3,21 +3,16 @@ Contains all safetensor related code for distributed memory management """ -import sys import torch import logging -import hashlib import re -import gc +from collections import defaultdict logger = logging.getLogger("MultiGPU") -import copy -import inspect -from collections import defaultdict import comfy.model_management as mm import comfy.model_patcher -from .device_utils import get_device_list, soft_empty_cache_multigpu -from .model_management_mgpu import multigpu_memory_log, force_full_system_cleanup +from .device_utils import get_device_list +from .model_management_mgpu import multigpu_memory_log @@ -36,19 +31,16 @@ def unpack_load_item(item): def register_patched_safetensor_modelpatcher(): """Register and patch the ModelPatcher for distributed safetensor loading""" - from comfy.model_patcher import wipe_lowvram_weight, move_weight_functions # Patch ComfyUI's ModelPatcher if not hasattr(comfy.model_patcher.ModelPatcher, '_distorch_patched'): # PATCH load_models_gpu with correct memory calculations per model flags - original_load_models_gpu = mm.load_models_gpu - def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): from comfy.model_management import cleanup_models_gc, get_free_memory, free_memory, current_loaded_models from comfy.model_management import VRAMState, vram_state, lowvram_available, MIN_WEIGHT_MEMORY_RATIO from comfy.model_management import minimum_inference_memory, extra_reserved_memory, is_device_cpu - + multigpu_memory_log("load_models_gpu_top_level", "start") cleanup_models_gc() @@ -78,7 +70,7 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False logger.debug(f"[MultiGPU DisTorch V2] Registering mm_patch: {type(mm_patch).__name__}") models_temp.add(mm_patch) continue - + for mm_patch in m.model_patches_models(): models_temp.add(mm_patch) patches = m.model_patches_to(m.load_device) @@ -94,7 +86,7 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False loaded_model = mm.LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) - except: + except ValueError: loaded_model_index = None if loaded_model_index is not None: @@ -108,8 +100,8 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False for loaded_model in models_to_load: to_unload = [] - for i in range(len(current_loaded_models)): - if loaded_model.model.is_clone(current_loaded_models[i].model): + for i, current_loaded_model in enumerate(current_loaded_models): + if loaded_model.model.is_clone(current_loaded_model.model): to_unload = [i] + to_unload for i in to_unload: model_to_unload = current_loaded_models.pop(i) @@ -125,16 +117,16 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False base_memory = loaded_model.model_memory_required(device) inner_model = loaded_model.model.model - + if hasattr(inner_model, '_distorch_v2_meta'): meta = inner_model._distorch_v2_meta allocation_str = meta['full_allocation'] - + # Parse allocation string: "expert#compute_device;virtual_vram_gb;donors" parts = allocation_str.split('#') virtual_vram_gb = 0.0 has_eject = False - + if len(parts) > 1: virtual_vram_str = parts[1] virtual_info = virtual_vram_str.split(';') @@ -142,11 +134,11 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False virtual_vram_gb = float(virtual_info[1]) if len(virtual_info) > 2 and virtual_info[2]: has_eject = True - + if has_eject: eject_device = device logger.mgpu_mm_log("DisTorch eject_models detected - MAX memory eviction") - + virtual_vram_bytes = virtual_vram_gb * (1024**3) adjusted_memory = max(0, base_memory - virtual_vram_bytes) total_memory_required[device] = total_memory_required.get(device, 0) + adjusted_memory @@ -156,24 +148,24 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False total_memory_required[device] = total_memory_required.get(device, 0) + base_memory logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] Standard model {(base_memory)/(1024**3):.2f}GB for device {device}") - for device in total_memory_required: + for device, device_memory in total_memory_required.items(): if device != torch.device("cpu"): - requested_mem = total_memory_required[device] * 1.1 + extra_mem - logger.mgpu_mm_log(f"[FREE_MEMORY_CALL] Device {device}: requesting {requested_mem/(1024**3):.2f}GB = {total_memory_required[device]/(1024**3):.2f}GB * 1.1 + {extra_mem/(1024**3):.2f}GB inference") - - + requested_mem = device_memory * 1.1 + extra_mem + logger.mgpu_mm_log(f"[FREE_MEMORY_CALL] Device {device}: requesting {requested_mem/(1024**3):.2f}GB = {device_memory/(1024**3):.2f}GB * 1.1 + {extra_mem/(1024**3):.2f}GB inference") + + multigpu_memory_log("free_memory", "pre") - for device in total_memory_required: + for device, device_memory in total_memory_required.items(): if device != torch.device("cpu"): if device == eject_device: total_device_memory = mm.get_total_memory(device) logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] eject_models=1, is_distorch=1 → using MAX memory ({total_device_memory/(1024**3):.2f}GB) for eviction") free_memory(total_device_memory,device) else: - logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] eject_models=0, using Comfy Core Computed memory ({(total_memory_required[device] * 1.1 + extra_mem)/(1024**3):.2f}GB) for eviction") - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) - + logger.mgpu_mm_log(f"[LOAD_MODELS_GPU] eject_models=0, using Comfy Core Computed memory ({(device_memory * 1.1 + extra_mem)/(1024**3):.2f}GB) for eviction") + free_memory(device_memory * 1.1 + extra_mem, device) + multigpu_memory_log("free_memory/minimum_memory_required", "post/pre") for device in total_memory_required: @@ -186,7 +178,7 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False if free_mem < minimum_memory_required: models_l = free_memory(minimum_memory_required, device) logger.mgpu_mm_log(f"[EVICTION] Device {device}: unloaded {len(models_l)} models due to insufficient memory") - logging.info("{} models unloaded.".format(len(models_l))) + logging.info(f"{len(models_l)} models unloaded.") multigpu_memory_log("minimum_memory_required", "post") @@ -198,7 +190,7 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False else: vram_set_state = vram_state lowvram_model_memory = 0 - if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load: + if lowvram_available and vram_set_state in (VRAMState.LOW_VRAM, VRAMState.NORMAL_VRAM) and not force_full_load: loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory @@ -218,21 +210,20 @@ def patched_load_models_gpu(models, memory_required=0, force_patch_weights=False def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_patch_weights=False, **kwargs): """Override to use direct model annotation for allocation""" - + mp_id = id(self) - mp_patches_uuid = self.patches_uuid inner_model = self.model inner_model_id = id(inner_model) - + if not hasattr(inner_model, "_distorch_v2_meta"): logger.debug(f"[DISTORCH_SKIP] ModelPatcher=0x{mp_id:x} inner_model=0x{inner_model_id:x} type={type(inner_model).__name__} - no metadata, using standard loading") result = original_partially_load(self, device_to, extra_memory, force_patch_weights) if hasattr(self, '_distorch_block_assignments'): del self._distorch_block_assignments return result - + allocations = inner_model._distorch_v2_meta['full_allocation'] - + if not hasattr(self.model, '_distorch_high_precision_loras'): self.model._distorch_high_precision_loras = True @@ -242,7 +233,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) if unpatch_weights: - logger.debug(f"[MultiGPU DisTorch V2] Patches changed or forced. Unpatching model.") + logger.debug("[MultiGPU DisTorch V2] Patches changed or forced. Unpatching model.") self.unpatch_model(self.offload_device, unpatch_weights=True) self.patch_model(load_weights=False) @@ -254,7 +245,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p # Check for valid cache allocations_match = hasattr(self, '_distorch_last_allocations') and self._distorch_last_allocations == allocations cache_exists = hasattr(self, '_distorch_cached_assignments') - + if cache_exists and allocations_match and not unpatch_weights and not force_patch_weights: device_assignments = self._distorch_cached_assignments logger.debug(f"[MultiGPU DisTorch V2] Reusing cached analysis for {type(inner_model).__name__}") @@ -262,7 +253,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p device_assignments = analyze_safetensor_loading(self, allocations, is_clip=is_clip_model) ## This should be the only required line - that is how it worked previous release so if it doesn't it is Comfy changes self._distorch_cached_assignments = device_assignments self._distorch_last_allocations = allocations - + model_original_dtype = comfy.utils.weight_dtype(self.model.state_dict()) high_precision_loras = getattr(self.model, "_distorch_high_precision_loras", True) # Use standard ComfyUI load list - the device comparison fix ensures we don't crash @@ -270,12 +261,12 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p loading.sort(reverse=True) for item in loading: module_size, module_name, module_object, params = unpack_load_item(item) - if not unpatch_weights and hasattr(module_object, "comfy_patched_weights") and module_object.comfy_patched_weights == True: + if not unpatch_weights and hasattr(module_object, "comfy_patched_weights") and module_object.comfy_patched_weights is True: block_target_device = device_assignments['block_assignments'].get(module_name, device_to) current_module_device = None try: if any(p.numel() > 0 for p in module_object.parameters(recurse=False)): - current_module_device = next(module_object.parameters(recurse=False)).device + current_module_device = next(module_object.parameters(recurse=False)).device except StopIteration: pass @@ -290,8 +281,8 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p module_object.to(device_to) # Step 2: Apply LoRa patches while on compute device - weight_key = "{}.weight".format(module_name) - bias_key = "{}.bias".format(module_name) + weight_key = f"{module_name}.weight" + bias_key = f"{module_name}.bias" if weight_key in self.patches: self.patch_weight_to_device(weight_key, device_to=device_to) @@ -335,7 +326,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p return 0 - + comfy.model_patcher.ModelPatcher.partially_load = new_partially_load comfy.model_patcher.ModelPatcher._distorch_patched = True logger.info("[MultiGPU Core Patching] Successfully patched ModelPatcher.partially_load") @@ -347,9 +338,9 @@ def _extract_clip_head_blocks(raw_block_list, compute_device): distributable_blocks = [] head_memory = 0 block_assignments = {} - + block_assignments = {} - + for item in raw_block_list: module_size, module_name, module_object, params = unpack_load_item(item) if any(kw in module_name.lower() for kw in head_keywords): @@ -358,7 +349,7 @@ def _extract_clip_head_blocks(raw_block_list, compute_device): head_memory += module_size else: distributable_blocks.append((module_size, module_name, module_object, params)) - + return head_blocks, distributable_blocks, block_assignments, head_memory def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False): @@ -370,8 +361,6 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) device_table = {} distorch_alloc = "" virtual_vram_str = "" - virtual_vram_gb = 0.0 - if '#' in allocations_string: distorch_alloc, virtual_vram_str = allocations_string.split('#', 1) else: @@ -381,16 +370,13 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) logger.debug(f"[MultiGPU DisTorch V2] Compute Device: {compute_device}") if not distorch_alloc: - mode = "fraction" distorch_alloc = calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str) elif any(c in distorch_alloc.lower() for c in ['g', 'm', 'k', 'b']): - mode = "byte" distorch_alloc = calculate_fraction_from_byte_expert_string(model_patcher, distorch_alloc) elif "%" in distorch_alloc: - mode = "ratio" distorch_alloc = calculate_fraction_from_ratio_expert_string(model_patcher, distorch_alloc) - + all_devices = get_device_list() present_devices = {item.split(',')[0] for item in distorch_alloc.split(';') if ',' in item} for device in all_devices: @@ -421,20 +407,20 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) logger.info(eq_line) logger.info(" DisTorch2 Model Device Allocations") logger.info(eq_line) - + fmt_rosetta = "{:<8}{:>9}{:>9}{:>11}{:>10}" logger.info(fmt_rosetta.format("Device", "VRAM GB", "Dev %", "Model GB", "Dist %")) logger.info(dash_line) sorted_devices = sorted(device_table.keys(), key=lambda d: (d == "cpu", d)) - + total_allocated_model_bytes = sum(d["alloc_gb"] * (1024**3) for d in device_table.values()) for dev in sorted_devices: total_dev_gb = device_table[dev]["total_gb"] alloc_fraction = device_table[dev]["fraction"] alloc_gb = device_table[dev]["alloc_gb"] - + dist_ratio_percent = (alloc_gb * (1024**3) / total_allocated_model_bytes) * 100 if total_allocated_model_bytes > 0 else 0 logger.info(fmt_rosetta.format( @@ -444,7 +430,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) f"{alloc_gb:.2f}", f"{dist_ratio_percent:.1f}%" )) - + logger.info(dash_line) block_summary = {} @@ -485,9 +471,9 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) module_size, module_name, module_object, params = unpack_load_item(item) distributable_all_blocks.append((module_name, module_object, type(module_object).__name__, module_size)) - block_list = [b for b in distributable_all_blocks if (b[3] >= MIN_BLOCK_THRESHOLD and hasattr(b[1], "bias"))] + block_list = [b for b in distributable_all_blocks if (b[3] >= MIN_BLOCK_THRESHOLD and hasattr(b[1], "bias"))] tiny_block_list = [b for b in distributable_all_blocks if b not in block_list] - + logger.debug(f"[MultiGPU DisTorch V2] Total blocks: {len(all_blocks)}") logger.debug(f"[MultiGPU DisTorch V2] Distributable blocks: {len(block_list)}") logger.debug(f"[MultiGPU DisTorch V2] Tiny blocks (<0.01%): {len(tiny_block_list)}") @@ -497,19 +483,19 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) fmt_layer = "{:<18}{:>7}{:>14}{:>10}" logger.info(fmt_layer.format("Layer Type", "Layers", "Memory (MB)", "% Total")) logger.info(dash_line) - + for layer_type, count in block_summary.items(): mem_mb = memory_by_type[layer_type] / (1024 * 1024) mem_percent = (memory_by_type[layer_type] / total_memory) * 100 if total_memory > 0 else 0 logger.info(fmt_layer.format(layer_type[:18], str(count), f"{mem_mb:.2f}", f"{mem_percent:.1f}%")) - + logger.info(dash_line) # Distribute blocks sequentially from the tail of the model - device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH.keys()} + device_assignments = {device: [] for device in DEVICE_RATIOS_DISTORCH} # Create a memory quota for each donor device based on its calculated allocation. - donor_devices = [d for d in sorted_devices] + donor_devices = list(sorted_devices) donor_quotas = { dev: device_table[dev]["alloc_gb"] * (1024**3) for dev in donor_devices @@ -529,7 +515,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) donor_quotas[donor] -= block_memory assigned_to_donor = True break # Move to the next block - + if not assigned_to_donor: #Note - small rounding errors and tensor-fitting on devices make a block occasionally an orphan. We treat orphans the same as tiny_block_list as they are generally small rounding errors block_assignments[block_name] = compute_device @@ -549,7 +535,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) logger.info(dash_line) logger.info(fmt_assign.format("Device", "Layers", "Memory (MB)", "% Total")) logger.info(dash_line) - + if tiny_block_list: tiny_block_memory = sum(b[3] for b in tiny_block_list) tiny_mem_mb = tiny_block_memory / (1024 * 1024) @@ -560,7 +546,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) total_assigned_memory = 0 device_memories = {} - + for device, blocks in device_assignments.items(): dist_blocks = [b for b in blocks if b[3] >= MIN_BLOCK_THRESHOLD] if not dist_blocks: @@ -577,11 +563,11 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) dist_blocks = [b for b in device_assignments[dev] if b[3] >= MIN_BLOCK_THRESHOLD] if not dist_blocks: continue - + mem_mb = device_memories[dev] / (1024 * 1024) mem_percent = (device_memories[dev] / total_memory) * 100 if total_memory > 0 else 0 logger.info(fmt_assign.format(dev, str(len(dist_blocks)), f"{mem_mb:.2f}", f"{mem_percent:.1f}%")) - + logger.info(dash_line) return { @@ -595,10 +581,10 @@ def parse_memory_string(mem_str): match = re.match(r'(\d+\.?\d*)\s*([gmkb]?)', mem_str) if not match: raise ValueError(f"Invalid memory string format: {mem_str}") - + val, unit = match.groups() val = float(val) - + if unit == 'g': return val * (1024**3) elif unit == 'm': @@ -623,7 +609,7 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str): continue dev_name, val_str = allocation.split(',', 1) is_wildcard = '*' in val_str - + if is_wildcard: wildcard_device = dev_name # Don't add wildcard to the priority list yet @@ -640,12 +626,12 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str): # Determine the actual bytes to allocate to this device bytes_to_assign = min(requested_bytes, remaining_model_bytes) - + if bytes_to_assign > 0: final_byte_allocations[dev] = bytes_to_assign remaining_model_bytes -= bytes_to_assign logger.info(f"[MultiGPU DisTorch V2] Assigning {bytes_to_assign / (1024**2):.2f}MB of model to {dev} (requested {requested_bytes / (1024**2):.2f}MB).") - + if remaining_model_bytes <= 0: logger.info("[MultiGPU DisTorch V2] All model blocks have been allocated. Subsequent devices in the string will receive no assignment.") break @@ -662,7 +648,7 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str): if total_device_vram > 0: fraction = bytes_alloc / total_device_vram allocation_parts.append(f"{dev},{fraction:.4f}") - + allocations_string = ";".join(allocation_parts) return allocations_string @@ -674,7 +660,8 @@ def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str): raw_ratios = {} for allocation in ratio_str.split(';'): - if ',' not in allocation: continue + if ',' not in allocation: + continue dev_name, val_str = allocation.split(',', 1) # Assumes the value is a unitless ratio number, ignores '%' for simplicity. value = float(val_str.replace('%','').strip()) @@ -696,7 +683,7 @@ def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str): ratio_string = ":".join(ratio_values) normalized_pcts = [(v / total_ratio_parts) * 100 for v in raw_ratios.values()] - + put_parts = [] for i, dev_name in enumerate(raw_ratios.keys()): put_parts.append(f"{int(normalized_pcts[i])}% on {dev_name}") @@ -707,7 +694,7 @@ def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str): put_part = f"{put_parts[0]} and {put_parts[1]}" else: put_part = ", ".join(put_parts[:-1]) + f", and {put_parts[-1]}" - + logger.info(f"[MultiGPU DisTorch V2] Ratio(%) Mode - {ratio_str} -> {ratio_string} ratio, put {put_part}") allocations_string = ";".join(allocation_parts) @@ -736,31 +723,31 @@ def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str): logger.info(fmt_assign.format(recipient_device, 'recip', f"{recipient_vram:.2f}GB",f"{recipient_virtual:.2f}GB", f"+{virtual_vram_gb:.2f}GB")) # Handle donor devices - ram_donors = [d for d in donors.split(',')] + ram_donors = list(donors.split(',')) remaining_vram_needed = virtual_vram_gb - + donor_device_info = {} donor_allocations = {} - + for donor in ram_donors: donor_vram = mm.get_total_memory(torch.device(donor)) / (1024**3) max_donor_capacity = donor_vram - + donation = min(remaining_vram_needed, max_donor_capacity) donor_virtual = donor_vram - donation remaining_vram_needed -= donation donor_allocations[donor] = donation - + donor_device_info[donor] = (donor_vram, donor_virtual) logger.info(fmt_assign.format(donor, 'donor', f"{donor_vram:.2f}GB", f"{donor_virtual:.2f}GB", f"-{donation:.2f}GB")) - - + + logger.info(dash_line) # Calculate model size model = model_patcher.model if hasattr(model_patcher, 'model') else model_patcher total_memory = 0 - + for name, module in model.named_modules(): if hasattr(module, "weight"): if module.weight is not None: @@ -790,6 +777,6 @@ def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str): donor_vram = donor_device_info[donor][0] donor_percent = donor_allocations[donor] / donor_vram allocation_parts.append(f"{donor},{donor_percent:.4f}") - + allocations_string = ";".join(allocation_parts) return allocations_string diff --git a/example_workflows/ComfyUI-starter_multigpu.jpg b/example_workflows/ComfyUI-starter_multigpu.jpg new file mode 100644 index 0000000..24d95a6 Binary files /dev/null and b/example_workflows/ComfyUI-starter_multigpu.jpg differ diff --git a/example_workflows/ComfyUI-starter_multigpu.json b/example_workflows/ComfyUI-starter_multigpu.json new file mode 100644 index 0000000..278c0a2 --- /dev/null +++ b/example_workflows/ComfyUI-starter_multigpu.json @@ -0,0 +1,743 @@ +{ + "id": "73611c02-cdf1-4ae8-a41a-d9c057e201ed", + "revision": 0, + "last_node_id": 116, + "last_link_id": 156, + "nodes": [ + { + "id": 74, + "type": "MarkdownNote", + "pos": [ + 10906.666666666668, + -1101.6666666666667 + ], + "size": [ + 210, + 34 + ], + "flags": { + "collapsed": true + }, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [], + "title": "Hit run šŸ‘†", + "properties": {}, + "widgets_values": [ + "" + ], + "color": "#222", + "bgcolor": "#000" + }, + { + "id": 84, + "type": "MarkdownNote", + "pos": [ + 10898.333333333334, + -960 + ], + "size": [ + 229.36197916666669, + 92.65625 + ], + "flags": { + "collapsed": false + }, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [], + "title": "Step 2 - Download image", + "properties": {}, + "widgets_values": [ + "1. The result is here\nšŸ‘ˆāœØ\n\n2. Right-click and download the image." + ], + "color": "#222", + "bgcolor": "#000" + }, + { + "id": 81, + "type": "MarkdownNote", + "pos": [ + 7882.315398447294, + -1132.6809840806743 + ], + "size": [ + 210, + 90 + ], + "flags": { + "collapsed": false + }, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [], + "title": "Step 1", + "properties": {}, + "widgets_values": [ + "1. Download all relevant models and place them inside your ComfyUI/models folder šŸ‘‡" + ], + "color": "#222", + "bgcolor": "#000" + }, + { + "id": 83, + "type": "MarkdownNote", + "pos": [ + 10004.246413879386, + -1279.1407138616241 + ], + "size": [ + 220.32552083333334, + 88 + ], + "flags": { + "collapsed": false + }, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [], + "title": "Step 1 - Connect nodes", + "properties": {}, + "widgets_values": [ + "Try to connect these 2 nodes šŸ‘‡" + ], + "color": "#222", + "bgcolor": "#000" + }, + { + "id": 107, + "type": "ModelSamplingAuraFlow", + "pos": [ + 9213.292447420406, + -979.6763617965224 + ], + "size": [ + 310, + 85 + ], + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 147 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "slot_index": 0, + "links": [ + 152 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "ModelSamplingAuraFlow", + "enableTabs": false, + "hasSecondTab": false, + "secondTabOffset": 80, + "secondTabText": "Send Back", + "secondTabWidth": 65, + "tabWidth": 65, + "tabXOffset": 10 + }, + "widgets_values": [ + 3 + ] + }, + { + "id": 108, + "type": "ConditioningZeroOut", + "pos": [ + 8963.292447420406, + -599.6763617965224 + ], + "size": [ + 204.134765625, + 51.00000000000001 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 148 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 154 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "ConditioningZeroOut", + "enableTabs": false, + "hasSecondTab": false, + "secondTabOffset": 80, + "secondTabText": "Send Back", + "secondTabWidth": 65, + "tabWidth": 65, + "tabXOffset": 10 + }, + "widgets_values": [] + }, + { + "id": 109, + "type": "VAELoader", + "pos": [ + 8433.556257604858, + -704.4218567543277 + ], + "size": [ + 270, + 83 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 150 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "VAELoader", + "enableTabs": false, + "hasSecondTab": false, + "models": [ + { + "name": "ae.safetensors", + "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors", + "directory": "vae" + } + ], + "secondTabOffset": 80, + "secondTabText": "Send Back", + "secondTabWidth": 65, + "tabWidth": 65, + "tabXOffset": 10 + }, + "widgets_values": [ + "ae.safetensors" + ] + }, + { + "id": 110, + "type": "EmptySD3LatentImage", + "pos": [ + 8433.292447420406, + -539.6763617965224 + ], + "size": [ + 269.4929183224277, + 133.85035641015304 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "slot_index": 0, + "links": [ + 155 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "EmptySD3LatentImage", + "enableTabs": false, + "hasSecondTab": false, + "secondTabOffset": 80, + "secondTabText": "Send Back", + "secondTabWidth": 65, + "tabWidth": 65, + "tabXOffset": 10 + }, + "widgets_values": [ + 1280, + 720, + 1 + ] + }, + { + "id": 113, + "type": "MarkdownNote", + "pos": [ + 7659.877374560984, + -967.4803967921797 + ], + "size": [ + 533.3333333333334, + 558.3333333333334 + ], + "flags": { + "collapsed": false + }, + "order": 6, + "mode": 0, + "inputs": [], + "outputs": [], + "title": "For local users", + "properties": {}, + "widgets_values": [ + "## Report workflow issue\n\nIf you found any issues when running this workflow, [report template issue here](https://github.com/Comfy-Org/workflow_templates/issues).\n\n\n## Model links\n\n**diffusion_models**\n\n- [z_image_turbo_bf16.safetensors](https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors)\n\n\n**text_encoders**\n\n- [qwen_3_4b.safetensors](https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors)\n\n\n**vae**\n\n- [ae.safetensors](https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors)\n\n\n## Model Storage Location\n\n```\nšŸ“‚ ComfyUI/\nā”œā”€ā”€ šŸ“‚ models/\n│ ā”œā”€ā”€ šŸ“‚ diffusion_models/\n│ │ └── z_image_turbo_bf16.safetensors\n│ ā”œā”€ā”€ šŸ“‚ text_encoders/\n│ │ └── qwen_3_4b.safetensors\n│ └── šŸ“‚ vae/\n│ └── ae.safetensors\n```\n" + ], + "color": "#222", + "bgcolor": "#000" + }, + { + "id": 116, + "type": "KSampler", + "pos": [ + 9214.293768995101, + -873.9826233663099 + ], + "size": [ + 312.33082247057655, + 499 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 152 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 153 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 154 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 155 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "slot_index": 0, + "links": [ + 149 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "KSampler", + "enableTabs": false, + "hasSecondTab": false, + "secondTabOffset": 80, + "secondTabText": "Send Back", + "secondTabWidth": 65, + "tabWidth": 65, + "tabXOffset": 10 + }, + "widgets_values": [ + 25220804433832, + "randomize", + 8, + 1, + "res_multistep", + "simple", + 1 + ] + }, + { + "id": 111, + "type": "VAEDecode", + "pos": [ + 9543.292447420406, + -979.6763617965224 + ], + "size": [ + 210, + 71 + ], + "flags": { + "collapsed": false + }, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 149 + }, + { + "name": "vae", + "type": "VAE", + "link": 150 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "slot_index": 0, + "links": [ + 156 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.64", + "Node name for S&R": "VAEDecode", + "enableTabs": false, + "hasSecondTab": false, + "secondTabOffset": 80, + "secondTabText": "Send Back", + "secondTabWidth": 65, + "tabWidth": 65, + "tabXOffset": 10 + }, + "widgets_values": [] + }, + { + "id": 115, + "type": "CLIPTextEncode", + "pos": [ + 8753.292447420406, + -979.6763617965224 + ], + "size": [ + 408.4818178377949, + 359.99707452069583 + ], + "flags": {}, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 151 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 148, + 153 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.73", + "Node name for S&R": "CLIPTextEncode", + "enableTabs": false, + "hasSecondTab": false, + "secondTabOffset": 80, + "secondTabText": "Send Back", + "secondTabWidth": 65, + "tabWidth": 65, + "tabXOffset": 10 + }, + "widgets_values": [ + "A towering technological monolith in a cyberpunk cityscape at night, with \"Multi-GPU\" emblazoned across its surface in massive neon blue-green mixed with purple letters that illuminate the surrounding buildings. The text occupies the central third of the frame, crafted from glowing plasma tubes and crackling energy. Rain-slicked streets below reflect the brilliant signage, while holographic advertisements and flying vehicles populate the background. Moody atmospheric lighting, heavy contrast, photorealistic textures, cinematic color grading. " + ], + "color": "#232", + "bgcolor": "#353" + }, + { + "id": 76, + "type": "SaveImage", + "pos": [ + 9973.333671371092, + -959.9988781468267 + ], + "size": [ + 783.3333333333334, + 575 + ], + "flags": { + "collapsed": false + }, + "order": 14, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 156 + } + ], + "outputs": [], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.71" + }, + "widgets_values": [ + "starter_multigpu" + ] + }, + { + "id": 114, + "type": "UNETLoaderMultiGPU", + "pos": [ + 8440.866118819535, + -978.949997949714 + ], + "size": [ + 270, + 106 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 147 + ] + } + ], + "properties": { + "aux_id": "pollockjj/ComfyUI-MultiGPU", + "ver": "7af256ab6ea86b105777631fce116c4174547362", + "Node name for S&R": "UNETLoaderMultiGPU" + }, + "widgets_values": [ + "z_image_turbo_bf16.safetensors", + "default", + "cuda:0" + ] + }, + { + "id": 112, + "type": "CLIPLoaderMultiGPU", + "pos": [ + 8437.79119373746, + -840.6484188479288 + ], + "size": [ + 270, + 106 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 151 + ] + } + ], + "properties": { + "aux_id": "pollockjj/ComfyUI-MultiGPU", + "ver": "7af256ab6ea86b105777631fce116c4174547362", + "Node name for S&R": "CLIPLoaderMultiGPU" + }, + "widgets_values": [ + "qwen_3_4b.safetensors", + "lumina2", + "cuda:1" + ] + } + ], + "links": [ + [ + 147, + 114, + 0, + 107, + 0, + "MODEL" + ], + [ + 148, + 115, + 0, + 108, + 0, + "CONDITIONING" + ], + [ + 149, + 116, + 0, + 111, + 0, + "LATENT" + ], + [ + 150, + 109, + 0, + 111, + 1, + "VAE" + ], + [ + 151, + 112, + 0, + 115, + 0, + "CLIP" + ], + [ + 152, + 107, + 0, + 116, + 0, + "MODEL" + ], + [ + 153, + 115, + 0, + 116, + 1, + "CONDITIONING" + ], + [ + 154, + 108, + 0, + 116, + 2, + "CONDITIONING" + ], + [ + 155, + 110, + 0, + 116, + 3, + "LATENT" + ], + [ + 156, + 111, + 0, + 76, + 0, + "IMAGE" + ] + ], + "groups": [ + { + "id": 2, + "title": "Step2 - Image size", + "bounding": [ + 8423.292447420406, + -609.6763617965224, + 290, + 220 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 3, + "title": "Step3 - Prompt", + "bounding": [ + 8733.292447420406, + -1049.6763617965225, + 450, + 660 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 4, + "title": "Step1 - Load models", + "bounding": [ + 8423.292447420406, + -1049.6763617965225, + 290, + 420 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + }, + { + "id": 6, + "title": "Step4 - Sampling", + "bounding": [ + 9203.292447420406, + -1049.6763617965225, + 570, + 660 + ], + "color": "#3f789e", + "font_size": 24, + "flags": {} + } + ], + "config": {}, + "extra": { + "VHS_KeepIntermediate": true, + "VHS_MetadataImage": true, + "VHS_latentpreview": false, + "VHS_latentpreviewrate": 0, + "frontendVersion": "1.39.19", + "workflowRendererVersion": "LG", + "ds": { + "scale": 0.8374407695202094, + "offset": [ + -8150.250707826723, + 1637.4975835478049 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/model_management_mgpu.py b/model_management_mgpu.py index bd3be8f..1e5a239 100644 --- a/model_management_mgpu.py +++ b/model_management_mgpu.py @@ -8,10 +8,8 @@ import hashlib import psutil import comfy.model_management as mm -import gc from datetime import datetime, timezone import server -from collections import defaultdict @@ -46,13 +44,13 @@ def _capture_memory_snapshot(): """Capture memory snapshot for CPU and all devices""" # Import here to avoid circular dependency from .device_utils import get_device_list - + snapshot = {} - + # CPU vm = psutil.virtual_memory() snapshot["cpu"] = (vm.used, vm.total) - + # GPU devices devices = [d for d in get_device_list() if d != "cpu"] for dev_str in devices: @@ -85,26 +83,26 @@ def multigpu_memory_log(identifier, tag): ts = datetime.now(timezone.utc) curr = _capture_memory_snapshot() - + # Store in series if identifier not in _MEM_SNAPSHOT_SERIES: _MEM_SNAPSHOT_SERIES[identifier] = [] _MEM_SNAPSHOT_SERIES[identifier].append((ts, tag, curr)) - + # Clean aligned format: timestamp + padded tag + memory values ts_str = ts.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" tag_padded = f"{identifier}_{tag}".ljust(35) - + parts = [] cpu_used, _ = curr.get("cpu", (0, 0)) parts.append(f"cpu|{cpu_used/(1024**3):.2f}") - - for dev in sorted([k for k in curr.keys() if k != "cpu"]): + + for dev in sorted(k for k in curr if k != "cpu"): used, _ = curr[dev] parts.append(f"{dev}|{used/(1024**3):.2f}") - + logger.mgpu_mm_log(f"{ts_str} {tag_padded} {' '.join(parts)}") - + _MEM_SNAPSHOT_LAST[identifier] = (tag, curr) diff --git a/nodes.py b/nodes.py index 4a75b78..0d1d178 100644 --- a/nodes.py +++ b/nodes.py @@ -1,14 +1,28 @@ -import torch import folder_paths from pathlib import Path from nodes import NODE_CLASS_MAPPINGS from .device_utils import get_device_list -from .model_management_mgpu import force_full_system_cleanup + +class DeviceSelectorMultiGPU: + @classmethod + def INPUT_TYPES(s): + devices = get_device_list() + return {"required": {"device": (devices,)}} + + RETURN_TYPES = ("MULTIGPUDEVICE",) + RETURN_NAMES = ("device",) + FUNCTION = "select_device" + CATEGORY = "multigpu" + TITLE = "Device Selector (MultiGPU)" + + def select_device(self, device): + """Return the selected device label without side effects.""" + return (device,) class UnetLoaderGGUF: @classmethod def INPUT_TYPES(s): - unet_names = [x for x in folder_paths.get_filename_list("unet_gguf")] + unet_names = list(folder_paths.get_filename_list("unet_gguf")) return { "required": { "unet_name": (unet_names,), @@ -28,7 +42,7 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de class UnetLoaderGGUFAdvanced(UnetLoaderGGUF): @classmethod def INPUT_TYPES(s): - unet_names = [x for x in folder_paths.get_filename_list("unet_gguf")] + unet_names = list(folder_paths.get_filename_list("unet_gguf")) return { "required": { "unet_name": (unet_names,), @@ -208,7 +222,7 @@ class DownloadAndLoadFlorence2Model: def INPUT_TYPES(s): return {"required": { "model": ( - [ + [ 'microsoft/Florence-2-base', 'microsoft/Florence-2-base-ft', 'microsoft/Florence-2-large', @@ -428,12 +442,14 @@ def INPUT_TYPES(s): def load_unet(self, unet_name): """Load UNet with low-precision LoRA flag for CPU storage optimization.""" original_loader = NODE_CLASS_MAPPINGS["UNETLoader"]() - out = original_loader.load_unet(unet_name) - + out = original_loader.load_unet(unet_name, "default") + # Set the low-precision LoRA flag on the loaded model if hasattr(out[0], 'model'): out[0].model._distorch_high_precision_loras = False - elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'): - out[0].patcher.model._distorch_high_precision_loras = False - + else: + patcher = getattr(out[0], "patcher", None) + if patcher is not None and hasattr(patcher, "model"): + patcher.model._distorch_high_precision_loras = False + return out diff --git a/pyproject.toml b/pyproject.toml index 789db4b..6de1f20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-multigpu" description = "Provides a suite of custom nodes to manage multiple GPUs for ComfyUI, including advanced model offloading for both GGUF and Safetensor formats with DisTorch, and bespoke MultiGPU support for WanVideoWrapper and other custom nodes." -version = "2.5.11" +version = "2.6.0" license = {file = "LICENSE"} [project.urls] @@ -11,4 +11,80 @@ Repository = "https://github.com/pollockjj/ComfyUI-MultiGPU" [tool.comfy] PublisherId = "pollockjj" DisplayName = "ComfyUI-MultiGPU" -Icon = "https://raw.githubusercontent.com/pollockjj/ComfyUI-MultiGPU/main/assets/multigpu_icon.png" \ No newline at end of file +Icon = "https://raw.githubusercontent.com/pollockjj/ComfyUI-MultiGPU/main/assets/multigpu_icon.png" + +[tool.ruff] +lint.select = [ + "N805", # invalid-first-argument-name-for-method + "S307", # suspicious-eval-usage + "S102", # exec + "E", + "T", # print-usage + "W", + # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. + # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f + "F", +] + +lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"] + +exclude = ["*.ipynb", "**/generated/*.pyi"] + +[tool.pylint] +master.py-version = "3.10" +master.init-hook = "import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../..')))" +master.extension-pkg-allow-list = [ + "pydantic", +] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "line-too-long", + "too-few-public-methods", + "too-many-public-methods", + "too-many-instance-attributes", + "too-many-positional-arguments", + "broad-exception-raised", + "too-many-lines", + "invalid-name", + "unused-argument", + "broad-exception-caught", + "consider-using-with", + "fixme", + "too-many-statements", + "too-many-branches", + "too-many-locals", + "too-many-arguments", + "too-many-return-statements", + "too-many-nested-blocks", + "duplicate-code", + "abstract-method", + "superfluous-parens", + "arguments-differ", + "redefined-builtin", + "unnecessary-lambda", + "dangerous-default-value", + "invalid-overridden-method", + # next warnings should be fixed in future + "bad-classmethod-argument", # Class method should have 'cls' as first argument + "wrong-import-order", # Standard imports should be placed before third party imports + "ungrouped-imports", + "unnecessary-pass", + "unnecessary-lambda-assignment", + "no-else-return", + "unused-variable", + "arguments-renamed", + "cyclic-import", + "global-statement", + "import-outside-toplevel", + "logging-format-interpolation", + "logging-fstring-interpolation", + "protected-access", + "redefined-outer-name", + "reimported", + "useless-import-alias", + "wrong-import-position", +] diff --git a/wanvideo.py b/wanvideo.py index 4ebecf9..0f3be1a 100644 --- a/wanvideo.py +++ b/wanvideo.py @@ -1,19 +1,11 @@ import logging import torch -import sys import inspect import copy import folder_paths import comfy.model_management as mm from nodes import NODE_CLASS_MAPPINGS from .device_utils import get_device_list -from .model_management_mgpu import multigpu_memory_log -from comfy.utils import load_torch_file, ProgressBar -import gc -import numpy as np -from accelerate import init_empty_weights -import os -import importlib.util logger = logging.getLogger("MultiGPU") @@ -45,7 +37,7 @@ def INPUT_TYPES(s): "model": (folder_paths.get_filename_list("unet_gguf") + folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), "base_precision": (["fp32", "bf16", "fp16", "fp16_fast"], {"default": "bf16"}), - "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e4m3fn_scaled", "fp8_e4m3fn_scaled_fast", "fp8_e5m2", "fp8_e5m2_fast", "fp8_e5m2_scaled", "fp8_e5m2_scaled_fast"], {"default": "disabled", + "quantization": (["disabled", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e4m3fn_scaled", "fp8_e4m3fn_scaled_fast", "fp8_e5m2", "fp8_e5m2_fast", "fp8_e5m2_scaled", "fp8_e5m2_scaled_fast"], {"default": "disabled", "tooltip": "Optional quantization method, 'disabled' acts as autoselect based by weights. Scaled modes only work with matching weights, _fast modes (fp8 matmul) require CUDA compute capability >= 8.9 (NVIDIA 4000 series and up), e4m3fn generally can not be torch.compiled on compute capability < 8.9 (3000 series and under)"}), "load_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}), "compute_device": (devices, {"default": default_device}), @@ -64,6 +56,7 @@ def INPUT_TYPES(s): "lora": ("WANVIDLORA", {"default": None}), "vram_management_args": ("VRAM_MANAGEMENTARGS", {"default": None, "tooltip": "Alternative offloading method from DiffSynth-Studio, more aggressive in reducing memory use than block swapping, but can be slower"}), "extra_model": ("VACEPATH", {"default": None, "tooltip": "Extra model to add to the main model, ie. VACE or MTV Crafter"}), + "vace_model": ("VACEPATH", {"default": None, "tooltip": "Backward-compatible alias for extra_model"}), "fantasytalking_model": ("FANTASYTALKINGMODEL", {"default": None, "tooltip": "FantasyTalking model https://github.com/Fantasy-AMAP"}), "multitalk_model": ("MULTITALKMODEL", {"default": None, "tooltip": "Multitalk model"}), "fantasyportrait_model": ("FANTASYPORTRAITMODEL", {"default": None, "tooltip": "FantasyPortrait model"}), @@ -83,7 +76,11 @@ def loadmodel(self, model, base_precision, compute_device, quantization, load_de loader_module = inspect.getmodule(original_loader) original_module_device = loader_module.device - set_current_device(compute_device) + vace_model = kwargs.pop("vace_model", None) + if kwargs.get("extra_model") is None and vace_model is not None: + kwargs["extra_model"] = vace_model + + set_current_device(compute_device) compute_device_to_be_patched = mm.get_torch_device() loader_module.device = compute_device_to_be_patched @@ -138,16 +135,16 @@ def INPUT_TYPES(s): "add_noise_to_samples": ("BOOLEAN", {"default": False, "tooltip": "Add noise to the samples before sampling, needed for video2video sampling when starting from clean video"}), } } - + RETURN_TYPES = ("LATENT", "LATENT",) RETURN_NAMES = ("samples", "denoised_samples",) FUNCTION = "process" CATEGORY = "multigpu/WanVideoWrapper" DESCRIPTION = "MultiGPU-aware sampler that ensures correct device for each model" - + def process(self, model, compute_device, **kwargs): from . import set_current_device - + original_sampler = NODE_CLASS_MAPPINGS["WanVideoSampler"]() sampler_module = inspect.getmodule(original_sampler) @@ -250,7 +247,7 @@ def loadmodel(self, model_name, precision, device=None, quantization="disabled") from . import set_current_device set_current_device(device) - + if device == "cpu": load_device = "offload_device" else: @@ -288,7 +285,7 @@ def loadmodel(self, model_name, precision, device=None): from . import set_current_device set_current_device(device) - + if device == "cpu": load_device = "offload_device" else: @@ -425,7 +422,7 @@ def INPUT_TYPES(s): }, "optional": { "load_device": (devices, {"default": default_device}), - "precision": (["fp16", "fp32", "bf16"], {"default": "fp16"}), + "precision": (["fp16", "fp32", "bf16"], {"default": "fp16"}), "parallel": ("BOOLEAN", {"default": False, "tooltip": "uses more memory but is faster"}), } } @@ -505,8 +502,8 @@ def INPUT_TYPES(s): FUNCTION = "process" CATEGORY = "multigpu/WanVideoWrapper" - def process(self, width, height, num_frames, force_offload, noise_aug_strength, - start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, + def process(self, width, height, num_frames, force_offload, noise_aug_strength, + start_latent_strength, end_latent_strength, start_image=None, end_image=None, control_embeds=None, fun_or_fl2v_model=False, temporal_mask=None, extra_latents=None, clip_embeds=None, tiled_vae=False, add_cond_latents=None, vae=None, load_device=None): from . import set_current_device @@ -576,7 +573,7 @@ def decode(self, vae, load_device, samples, enable_vae_tiling, tile_x, tile_y, t decode_module = inspect.getmodule(original_decode) original_module_device = decode_module.device - set_current_device(load_device) + set_current_device(load_device) compute_device_to_be_patched = mm.get_torch_device() decode_module.device = compute_device_to_be_patched @@ -678,7 +675,7 @@ def INPUT_TYPES(s): "clip_vision": ("CLIP_VISION",), "load_device": ("MULTIGPUDEVICE",), "image_1": ("IMAGE", {"tooltip": "Image to encode"}), - "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), + "strength_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), "strength_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001, "tooltip": "Additional clip embed multiplier"}), "crop": (["center", "disabled"], {"default": "center", "tooltip": "Crop image to 224x224 before encoding"}), "combine_embeds": (["average", "sum", "concat", "batch"], {"default": "average", "tooltip": "Method to combine multiple clip embeds"}), @@ -738,7 +735,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, device): from . import set_current_device set_current_device(device) - + original_loader = NODE_CLASS_MAPPINGS["WanVideoControlnetLoader"]() return original_loader.loadmodel(model, base_precision, load_device, quantization) @@ -765,7 +762,7 @@ def loadmodel(self, model, base_precision, device): from . import set_current_device set_current_device(device) - + original_loader = NODE_CLASS_MAPPINGS["FantasyTalkingModelLoader"]() return original_loader.loadmodel(model, base_precision) @@ -793,7 +790,7 @@ def loadmodel(self, model, base_precision, load_device, device): from . import set_current_device set_current_device(device) - + original_loader = NODE_CLASS_MAPPINGS["Wav2VecModelLoader"]() return original_loader.loadmodel(model, base_precision, load_device) @@ -826,7 +823,7 @@ def loadmodel(self, model, base_precision, load_device, device): from . import set_current_device set_current_device(device) - + original_loader = NODE_CLASS_MAPPINGS["DownloadAndLoadWav2VecModel"]() return original_loader.loadmodel(model, base_precision, load_device) @@ -861,6 +858,6 @@ def loadmodel(self, model, base_precision, load_device, device, quantization, at from . import set_current_device set_current_device(device) - + original_loader = NODE_CLASS_MAPPINGS["WanVideoUni3C_ControlnetLoader"]() - return original_loader.loadmodel(model, base_precision, load_device, quantization, attention_mode, compile_args) \ No newline at end of file + return original_loader.loadmodel(model, base_precision, load_device, quantization, attention_mode, compile_args) diff --git a/wrappers.py b/wrappers.py index a132792..f3d36f4 100644 --- a/wrappers.py +++ b/wrappers.py @@ -18,8 +18,7 @@ def _create_distorch_safetensor_v2_override(cls, device_param_name, device_setter_func, apply_device_kwarg_workaround, eject_models_default=True): """Internal factory function creating DisTorch2 override class with parameterized device selection behavior.""" from .distorch_2 import register_patched_safetensor_modelpatcher - from .model_management_mgpu import force_full_system_cleanup - + class NodeOverrideDisTorchSafetensorV2(cls): @classmethod def INPUT_TYPES(s): @@ -83,7 +82,7 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu", logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu") else: - logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP] eject_models=False - loading without eviction") + logger.mgpu_mm_log("[EJECT_MODELS_SETUP] eject_models=False - loading without eviction") if device_value is not None: device_setter_func(device_value) @@ -93,7 +92,7 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu", if k not in [device_param_name, 'virtual_vram_gb', 'donor_device', 'expert_mode_allocations', 'eject_models']} - + if apply_device_kwarg_workaround: clean_kwargs['device'] = 'default' @@ -106,10 +105,10 @@ def override(self, *args, virtual_vram_gb=4.0, donor_device="cpu", vram_string = device_value full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else "" - + fn = getattr(super(), cls.FUNCTION) out = fn(*args, **clean_kwargs) - + model_to_check = None if hasattr(out[0], 'model'): model_to_check = out[0] @@ -174,7 +173,7 @@ def override_class_with_distorch_gguf(cls): """DisTorch V1 Legacy wrapper - maintains V1 UI but calls V2 backend""" from . import set_current_device, get_current_device from .distorch_2 import register_patched_safetensor_modelpatcher - + class NodeOverrideDisTorchGGUFLegacy(cls): @classmethod def INPUT_TYPES(s): @@ -197,14 +196,14 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra original_device = get_current_device() if device is not None: set_current_device(device) - + # Strip MultiGPU-specific parameters before calling original function - clean_kwargs = {k: v for k, v in kwargs.items() - if k not in ['device', 'virtual_vram_gb', 'use_other_vram', + clean_kwargs = {k: v for k, v in kwargs.items() + if k not in ['device', 'virtual_vram_gb', 'use_other_vram', 'expert_mode_allocations']} - + register_patched_safetensor_modelpatcher() - + vram_string = "" if virtual_vram_gb > 0: if use_other_vram: @@ -217,7 +216,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra vram_string = f"{device};{virtual_vram_gb};cpu" full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else "" - + fn = getattr(super(), cls.FUNCTION) out = fn(*args, **clean_kwargs) @@ -226,7 +225,7 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra model_to_check = out[0] elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'): model_to_check = out[0].patcher - + if model_to_check and full_allocation: inner_model = model_to_check.model inner_model._distorch_v2_meta = {"full_allocation": full_allocation} @@ -242,14 +241,14 @@ def override_class_with_distorch_gguf_v2(cls): """DisTorch V2 wrapper for GGUF models""" from . import set_current_device, get_current_device from .distorch_2 import register_patched_safetensor_modelpatcher - + class NodeOverrideDisTorchGGUFv2(cls): @classmethod def INPUT_TYPES(s): inputs = copy.deepcopy(cls.INPUT_TYPES()) devices = get_device_list() compute_device = devices[1] if len(devices) > 1 else devices[0] - + inputs["optional"] = inputs.get("optional", {}) inputs["optional"]["compute_device"] = (devices, {"default": compute_device}) inputs["optional"]["virtual_vram_gb"] = ("FLOAT", {"default": 4.0, "min": 0.0, "max": 128.0, "step": 0.1}) @@ -265,14 +264,14 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device original_device = get_current_device() if compute_device is not None: set_current_device(compute_device) - + # Strip MultiGPU-specific parameters before calling original function - clean_kwargs = {k: v for k, v in kwargs.items() - if k not in ['compute_device', 'virtual_vram_gb', + clean_kwargs = {k: v for k, v in kwargs.items() + if k not in ['compute_device', 'virtual_vram_gb', 'donor_device', 'expert_mode_allocations']} - + register_patched_safetensor_modelpatcher() - + vram_string = "" if virtual_vram_gb > 0: vram_string = f"{compute_device};{virtual_vram_gb};{donor_device}" @@ -280,18 +279,18 @@ def override(self, *args, compute_device=None, virtual_vram_gb=4.0, donor_device vram_string = compute_device full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else "" - + logger.info(f"[MultiGPU DisTorch V2] Full allocation string: {full_allocation}") - + fn = getattr(super(), cls.FUNCTION) out = fn(*args, **clean_kwargs) - + model_to_check = None if hasattr(out[0], 'model'): model_to_check = out[0] elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'): model_to_check = out[0].patcher - + if model_to_check and full_allocation: inner_model = model_to_check.model inner_model._distorch_v2_meta = {"full_allocation": full_allocation} @@ -307,7 +306,7 @@ def override_class_with_distorch_clip(cls): """DisTorch V1 wrapper for CLIP models - calls V2 backend""" from . import set_current_text_encoder_device, get_current_text_encoder_device from .distorch_2 import register_patched_safetensor_modelpatcher - + class NodeOverrideDisTorchClip(cls): @classmethod def INPUT_TYPES(s): @@ -329,14 +328,14 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra original_text_device = get_current_text_encoder_device() if device is not None: set_current_text_encoder_device(device) - + # Strip MultiGPU-specific parameters before calling original function - clean_kwargs = {k: v for k, v in kwargs.items() - if k not in ['device', 'virtual_vram_gb', 'use_other_vram', + clean_kwargs = {k: v for k, v in kwargs.items() + if k not in ['device', 'virtual_vram_gb', 'use_other_vram', 'expert_mode_allocations']} - + register_patched_safetensor_modelpatcher() - + vram_string = "" if virtual_vram_gb > 0: if use_other_vram: @@ -349,16 +348,16 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra vram_string = f"{device};{virtual_vram_gb};cpu" full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else "" - + fn = getattr(super(), cls.FUNCTION) out = fn(*args, **clean_kwargs) - + model_to_check = None if hasattr(out[0], 'model'): model_to_check = out[0] elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'): model_to_check = out[0].patcher - + if model_to_check and full_allocation: inner_model = model_to_check.model inner_model._distorch_v2_meta = {"full_allocation": full_allocation} @@ -374,7 +373,7 @@ def override_class_with_distorch_clip_no_device(cls): """DisTorch V1 wrapper for Triple/Quad CLIP models - calls V2 backend""" from . import set_current_text_encoder_device, get_current_text_encoder_device from .distorch_2 import register_patched_safetensor_modelpatcher - + class NodeOverrideDisTorchClipNoDevice(cls): @classmethod def INPUT_TYPES(s): @@ -396,14 +395,14 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra original_text_device = get_current_text_encoder_device() if device is not None: set_current_text_encoder_device(device) - + # Strip MultiGPU-specific parameters before calling original function - clean_kwargs = {k: v for k, v in kwargs.items() - if k not in ['device', 'virtual_vram_gb', 'use_other_vram', + clean_kwargs = {k: v for k, v in kwargs.items() + if k not in ['device', 'virtual_vram_gb', 'use_other_vram', 'expert_mode_allocations']} - + register_patched_safetensor_modelpatcher() - + vram_string = "" if virtual_vram_gb > 0: if use_other_vram: @@ -416,16 +415,16 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra vram_string = f"{device};{virtual_vram_gb};cpu" full_allocation = f"{expert_mode_allocations}#{vram_string}" if expert_mode_allocations or vram_string else "" - + fn = getattr(super(), cls.FUNCTION) out = fn(*args, **clean_kwargs) - + model_to_check = None if hasattr(out[0], 'model'): model_to_check = out[0] elif hasattr(out[0], 'patcher') and hasattr(out[0].patcher, 'model'): model_to_check = out[0].patcher - + if model_to_check and full_allocation: inner_model = model_to_check.model inner_model._distorch_v2_meta = {"full_allocation": full_allocation} @@ -447,8 +446,8 @@ def override(self, *args, device=None, expert_mode_allocations="", use_other_vra def override_class(cls): """Standard MultiGPU device override for UNet/VAE models""" - from . import set_current_device, get_current_device - + from . import set_current_device, get_current_device, cuda_device_guard + class NodeOverride(cls): @classmethod def INPUT_TYPES(s): @@ -466,10 +465,11 @@ def override(self, *args, device=None, **kwargs): original_device = get_current_device() if device is not None: set_current_device(device) + target_device = device if device is not None else get_current_device() fn = getattr(super(), cls.FUNCTION) - out = fn(*args, **kwargs) try: - return out + with cuda_device_guard(target_device, reason=f"{type(self).__name__}.{cls.FUNCTION}"): + return fn(*args, **kwargs) finally: set_current_device(original_device) @@ -477,8 +477,14 @@ def override(self, *args, device=None, **kwargs): def override_class_offload(cls): """Standard MultiGPU device override for UNet/VAE models""" - from . import set_current_device, set_current_unet_offload_device, get_current_device, get_current_unet_offload_device - + from . import ( + set_current_device, + set_current_unet_offload_device, + get_current_device, + get_current_unet_offload_device, + cuda_device_guard, + ) + class NodeOverride(cls): @classmethod def INPUT_TYPES(s): @@ -500,10 +506,11 @@ def override(self, *args, device=None, offload_device=None, **kwargs): set_current_device(device) if offload_device is not None: set_current_unet_offload_device(offload_device) + target_device = device if device is not None else get_current_device() fn = getattr(super(), cls.FUNCTION) - out = fn(*args, **kwargs) try: - return out + with cuda_device_guard(target_device, reason=f"{type(self).__name__}.{cls.FUNCTION}"): + return fn(*args, **kwargs) finally: set_current_device(original_device) set_current_unet_offload_device(original_offload_device) @@ -515,7 +522,7 @@ def override(self, *args, device=None, offload_device=None, **kwargs): def override_class_clip(cls): """Standard MultiGPU device override for CLIP models (with device kwarg workaround)""" from . import set_current_text_encoder_device, get_current_text_encoder_device - + class NodeOverride(cls): @classmethod def INPUT_TYPES(s): @@ -547,7 +554,7 @@ def override(self, *args, device=None, **kwargs): def override_class_clip_no_device(cls): """Standard MultiGPU device override for Triple/Quad CLIP models (no device kwarg workaround)""" from . import set_current_text_encoder_device, get_current_text_encoder_device - + class NodeOverride(cls): @classmethod def INPUT_TYPES(s):