Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,40 @@ def unet_offload_device_patched():
logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})")
return device

# --- cu130 / PyTorch DLPack device-guard patch ---
# comfy_kitchen's CUDA backend exports tensors through tensor.__dlpack__, which requires the
# current CUDA device to match the tensor's device index. In multi-GPU setups the global current
# device can differ, causing: BufferError: Can't export tensors on a different CUDA device index.
def _patch_comfy_kitchen_dlpack_device_guard():
    """Wrap comfy_kitchen's CUDA ``_wrap_for_dlpack`` with a current-device guard.

    The replacement switches the current CUDA device to the tensor's own device
    (when they differ) before delegating to the original function, so the DLPack
    export no longer fails on multi-GPU setups.

    The patch is best-effort and idempotent:
      * it is a no-op when ``comfy_kitchen.backends.cuda`` cannot be imported or
        does not expose ``_wrap_for_dlpack``;
      * an already-patched wrapper (marked with ``_multigpu_patched``) is left
        untouched;
      * any failure while installing the patch is logged at debug level and
        never propagates to the caller.

    Returns:
        None.
    """
    import functools

    try:
        import comfy_kitchen.backends.cuda as ck_cuda  # type: ignore
    except Exception:
        # Deliberately broad: an optional dependency may fail to import for reasons other
        # than ImportError, and none of them should break this module.
        return  # comfy_kitchen not installed or no CUDA backend

    try:
        orig_wrap = getattr(ck_cuda, "_wrap_for_dlpack", None)
        if orig_wrap is None or getattr(orig_wrap, "_multigpu_patched", False):
            return

        @functools.wraps(orig_wrap)
        def _wrap_for_dlpack_guarded(tensor, *args, **kwargs):
            # Extra positional/keyword arguments are forwarded untouched so the guard keeps
            # working if the wrapped function's signature ever grows.
            try:
                # Check the tensor first: non-CUDA inputs never need to touch torch.cuda.
                if getattr(tensor, "is_cuda", False) and torch.cuda.is_available():
                    idx = tensor.device.index
                    if idx is not None and torch.cuda.current_device() != idx:
                        # Ensure current CUDA device matches the tensor before exporting via DLPack.
                        # NOTE(review): this changes the process-wide current device and does not
                        # restore it. Restoring is only safe if the export happens inside
                        # orig_wrap rather than later by the consumer -- confirm before changing.
                        torch.cuda.set_device(idx)
            except Exception:
                # Best-effort guard: never block the export itself, but leave a trace so a
                # subsequent BufferError can be diagnosed.
                logger.debug(
                    "[MultiGPU] DLPack device guard could not switch CUDA device",
                    exc_info=True,
                )
            return orig_wrap(tensor, *args, **kwargs)

        # Marker consulted above to keep repeated calls from stacking wrappers.
        _wrap_for_dlpack_guarded._multigpu_patched = True  # type: ignore
        setattr(ck_cuda, "_wrap_for_dlpack", _wrap_for_dlpack_guarded)
        logger.info("[MultiGPU] Patched comfy_kitchen CUDA DLPack wrapper with device guard.")
    except Exception as e:
        logger.debug(f"[MultiGPU] Failed to patch comfy_kitchen DLPack wrapper: {e}", exc_info=True)

# Install the comfy_kitchen DLPack device guard as soon as this module is executed.
_patch_comfy_kitchen_dlpack_device_guard()

# Announce the core device-function patches and record the device selections in effect at this
# point. NOTE(review): the patching announced here is presumably applied further down in the
# file (not visible from this hunk) -- confirm.
logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device")
logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}")
logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")
Expand Down