Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 283 additions & 70 deletions __init__.py

Large diffs are not rendered by default.

112 changes: 67 additions & 45 deletions checkpoint_multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,39 @@
def patch_load_state_dict_guess_config():
    """Monkey patch comfy.sd.load_state_dict_guess_config with MultiGPU-aware checkpoint loading."""
    global original_load_state_dict_guess_config

    # Idempotent install: a saved original means the patch is already in place.
    if original_load_state_dict_guess_config is None:
        logger.info("[MultiGPU Core Patching] Patching comfy.sd.load_state_dict_guess_config for advanced MultiGPU loading.")
        # Keep a handle to the stock loader so the patched wrapper can delegate to it.
        original_load_state_dict_guess_config = comfy.sd.load_state_dict_guess_config
        comfy.sd.load_state_dict_guess_config = patched_load_state_dict_guess_config
    else:
        logger.debug("[MultiGPU Checkpoint] load_state_dict_guess_config is already patched.")

def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False,
embedding_directory=None, output_model=True, model_options={},
te_model_options={}, metadata=None):
te_model_options={}, metadata=None, disable_dynamic=False):
Comment on lines 33 to +35
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

patched_load_state_dict_guess_config uses mutable default arguments (model_options={}, te_model_options={}), which can leak state across invocations if ComfyUI mutates these dicts during loading. Use None defaults and create new dicts inside the function (e.g., model_options = model_options or {}) to avoid cross-call contamination.

Copilot uses AI. Check for mistakes.
"""Patched checkpoint loader with MultiGPU and DisTorch2 device placement support."""
from . import set_current_device, set_current_text_encoder_device, get_current_device, get_current_text_encoder_device

sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
config_hash = str(sd_size)
device_config = checkpoint_device_config.get(config_hash)
distorch_config = checkpoint_distorch_config.get(config_hash)

if not device_config and not distorch_config:
return original_load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options, metadata)
return original_load_state_dict_guess_config(
sd,
output_vae=output_vae,
output_clip=output_clip,
output_clipvision=output_clipvision,
embedding_directory=embedding_directory,
output_model=output_model,
model_options=model_options,
te_model_options=te_model_options,
metadata=metadata,
disable_dynamic=disable_dynamic,
)

logger.debug("[MultiGPU Checkpoint] ENTERING Patched Checkpoint Loader")
logger.debug(f"[MultiGPU Checkpoint] Received Device Config: {device_config}")
Expand All @@ -53,7 +64,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
vae = None
model = None
model_patcher = None

# Capture the current devices at runtime so we can restore them after loading
original_main_device = get_current_device()
original_clip_device = get_current_text_encoder_device()
Expand All @@ -68,12 +79,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)

model_config = comfy.model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)

if model_config is None:
logger.warning("[MultiGPU] Warning: Not a standard checkpoint file. Trying to load as diffusion model only.")
# Simplified fallback for non-checkpoints
set_current_device(device_config.get('unet_device', original_main_device))
diffusion_model = comfy.sd.load_diffusion_model_state_dict(sd, model_options={})
diffusion_model = comfy.sd.load_diffusion_model_state_dict(
sd,
model_options={},
metadata=metadata,
disable_dynamic=disable_dynamic,
)
if diffusion_model is None:
return None
return (diffusion_model, None, VAE(sd={}), None)
Expand All @@ -83,18 +99,18 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
weight_dtype = None

if custom_operations is not None:
model_config.custom_operations = custom_operations
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
unet_compute_device = device_config.get('unet_device', original_main_device)

unet_compute_device = torch.device(device_config.get('unet_device', original_main_device))
if model_config.scaled_fp8 is not None:
manual_cast_dtype = mm.unet_manual_cast(None, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
manual_cast_dtype = mm.unet_manual_cast(None, unet_compute_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, torch.device(unet_compute_device), model_config.supported_inference_dtypes)
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, unet_compute_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
logger.info(f"UNet DType: {unet_dtype}, Manual Cast: {manual_cast_dtype}")

Expand All @@ -103,19 +119,20 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
clipvision = comfy.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

if output_model:
unet_compute_device = device_config.get('unet_device', original_main_device)
set_current_device(unet_compute_device)
unet_compute_device = torch.device(device_config.get('unet_device', original_main_device))
set_current_device(unet_compute_device)
inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)

multigpu_memory_log(f"unet:{config_hash[:8]}", "pre-load")

model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
model_patcher_class = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = model_patcher_class(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-weights")

logger.mgpu_mm_log("Invoking soft_empty_cache_multigpu before UNet ModelPatcher setup")
soft_empty_cache_multigpu()
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=unet_compute_device, offload_device=mm.unet_offload_device())
multigpu_memory_log(f"unet:{config_hash[:8]}", "post-model")

if distorch_config and 'unet_allocation' in distorch_config:
Expand All @@ -131,7 +148,7 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
vae_target_device = torch.device(device_config.get('vae_device', original_main_device))
set_current_device(vae_target_device) # Use main device context for VAE
multigpu_memory_log(f"vae:{config_hash[:8]}", "pre-load")

vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd, metadata=metadata)
Expand All @@ -151,17 +168,17 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
out_sd[k] = sd[k]

for pref in scaled_fp8_list:
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd

clip_target_device = device_config.get('clip_device', original_clip_device)
clip_target_device = torch.device(device_config.get('clip_device', original_clip_device))
set_current_text_encoder_device(clip_target_device)

clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
Expand All @@ -170,7 +187,15 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
multigpu_memory_log(f"clip:{config_hash[:8]}", "pre-load")
soft_empty_cache_multigpu()
clip_params = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=clip_params, model_options=te_model_options)
clip = CLIP(
clip_target,
embedding_directory=embedding_directory,
tokenizer_data=clip_sd,
parameters=clip_params,
state_dict=clip_sd,
model_options=te_model_options,
disable_dynamic=disable_dynamic,
)

if distorch_config and 'clip_allocation' in distorch_config:
clip_alloc = distorch_config['clip_allocation']
Expand All @@ -181,16 +206,13 @@ def patched_load_state_dict_guess_config(sd, output_vae=True, output_clip=True,
logger.info(f"[CHECKPOINT_META] CLIP inner_model id=0x{id(inner_clip):x}")
clip.patcher.model._distorch_high_precision_loras = distorch_config.get('high_precision_loras', True)

m, u = clip.load_sd(clip_sd, full_model=True) # This respects the patched text_encoder_device
if len(m) > 0: logger.warning(f"CLIP missing keys: {m}")
if len(u) > 0: logger.debug(f"CLIP unexpected keys: {u}")
logger.info("CLIP Loaded.")
multigpu_memory_log(f"clip:{config_hash[:8]}", "post-load")
else:
logger.warning("No CLIP/text encoder weights in checkpoint.")
else:
logger.warning("CLIP target not found in model config.")

finally:
set_current_device(original_main_device)
set_current_text_encoder_device(original_clip_device)
Expand All @@ -206,7 +228,7 @@ def INPUT_TYPES(s):
import folder_paths
devices = get_device_list()
default_device = devices[1] if len(devices) > 1 else devices[0]

return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
Expand All @@ -215,27 +237,27 @@ def INPUT_TYPES(s):
"vae_device": (devices, {"default": default_device}),
}
}

RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "multigpu"
TITLE = "Checkpoint Loader Advanced (MultiGPU)"

def load_checkpoint(self, ckpt_name, unet_device, clip_device, vae_device):
patch_load_state_dict_guess_config()

import folder_paths
import comfy.utils

ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
sd = comfy.utils.load_torch_file(ckpt_path)
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
config_hash = str(sd_size)

checkpoint_device_config[config_hash] = {
'unet_device': unet_device, 'clip_device': clip_device, 'vae_device': vae_device
}

# Load using standard loader, our patch will intercept
from nodes import CheckpointLoaderSimple
return CheckpointLoaderSimple().load_checkpoint(ckpt_name)
Expand All @@ -247,7 +269,7 @@ def INPUT_TYPES(s):
import folder_paths
devices = get_device_list()
compute_device = devices[1] if len(devices) > 1 else devices[0]

return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
Expand All @@ -265,18 +287,18 @@ def INPUT_TYPES(s):
"eject_models": ("BOOLEAN", {"default": True}),
}
}

RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "multigpu/distorch_2"
TITLE = "Checkpoint Loader Advanced (DisTorch2)"

def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb, unet_donor_device,
clip_compute_device, clip_virtual_vram_gb, clip_donor_device, vae_device,
unet_expert_mode_allocations="", clip_expert_mode_allocations="", high_precision_loras=True, eject_models=True):

if eject_models:
logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
logger.mgpu_mm_log("[EJECT_MODELS_SETUP] eject_models=True - marking all loaded models for eviction")
ejection_count = 0
for i, lm in enumerate(mm.current_loaded_models):
model_name = type(getattr(lm.model, 'model', lm.model)).__name__ if lm.model else 'Unknown'
Expand All @@ -289,17 +311,17 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
logger.mgpu_mm_log(f"[EJECT_MARKED] Model {i}: {model_name} (direct patcher) → marked for eviction")
ejection_count += 1
logger.mgpu_mm_log(f"[EJECT_MODELS_SETUP_COMPLETE] Marked {ejection_count} models for Comfy Core eviction during load_models_gpu")
patch_load_state_dict_guess_config()

patch_load_state_dict_guess_config()

import folder_paths
import comfy.utils

ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
sd = comfy.utils.load_torch_file(ckpt_path)
sd_size = sum(p.numel() for p in sd.values() if hasattr(p, 'numel'))
config_hash = str(sd_size)

checkpoint_device_config[config_hash] = {
'unet_device': unet_compute_device,
'clip_device': clip_compute_device,
Expand All @@ -312,7 +334,7 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
elif unet_expert_mode_allocations:
unet_vram_str = unet_compute_device
unet_alloc = f"{unet_expert_mode_allocations}#{unet_vram_str}" if unet_expert_mode_allocations or unet_vram_str else ""

clip_vram_str = ""
if clip_virtual_vram_gb > 0:
clip_vram_str = f"{clip_compute_device};{clip_virtual_vram_gb};{clip_donor_device}"
Expand All @@ -327,6 +349,6 @@ def load_checkpoint(self, ckpt_name, unet_compute_device, unet_virtual_vram_gb,
'unet_settings': hashlib.sha256(f"{unet_alloc}{high_precision_loras}".encode()).hexdigest(),
'clip_settings': hashlib.sha256(f"{clip_alloc}{high_precision_loras}".encode()).hexdigest(),
}

from nodes import CheckpointLoaderSimple
return CheckpointLoaderSimple().load_checkpoint(ckpt_name)
16 changes: 11 additions & 5 deletions ci/extract_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@

import argparse
import json
import sys
from pathlib import Path
from typing import Iterable, Iterator, Dict, Any


def _write_stdout(message: str = "") -> None:
sys.stdout.write(f"{message}\n")
sys.stdout.flush()


def load_json_lines(path: Path) -> Iterator[Dict[str, Any]]:
with path.open("r", encoding="utf-8") as handle:
for line in handle:
Expand Down Expand Up @@ -37,12 +43,12 @@ def main() -> int:

entries = list(load_json_lines(args.logfile))
if not entries:
print("No entries found in log file.")
_write_stdout("No entries found in log file.")
return 0

matched = [entry for entry in entries if is_allocation_event(entry, args.keywords)]
if not matched:
print("No allocation events matched provided keywords.")
_write_stdout("No allocation events matched provided keywords.")
return 0

for entry in matched:
Expand All @@ -51,9 +57,9 @@ def main() -> int:
component = entry.get("component", "")
header_bits = [bit for bit in (timestamp, category, component) if bit]
header = " | ".join(header_bits) if header_bits else "allocation"
print(f"## {header}")
print(entry.get("message", ""))
print()
_write_stdout(f"## {header}")
_write_stdout(entry.get("message", ""))
_write_stdout()

return 0

Expand Down
Loading