From 11d4b5e95e43d2507ff2de9df041db7db9ce6ea2 Mon Sep 17 00:00:00 2001 From: "Ivan R." Date: Thu, 22 Jan 2026 18:57:45 +0500 Subject: [PATCH] fix CUDA device index issue --- __init__.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/__init__.py b/__init__.py index 2f5efd4..a45e247 100644 --- a/__init__.py +++ b/__init__.py @@ -216,6 +216,40 @@ def unet_offload_device_patched(): logger.debug(f"[MultiGPU Core Patching] unet_offload_device_patched returning device: {device} (current_unet_offload_device={current_unet_offload_device})") return device +# --- cu130 / PyTorch DLPack device-guard patch --- +# comfy_kitchen's CUDA backend uses tensor.__dlpack__ which requires the current CUDA device +# to match the tensor's device index. With multi-GPU setups the global current device can differ, +# causing: BufferError: Can't export tensors on a different CUDA device index. +def _patch_comfy_kitchen_dlpack_device_guard(): + try: + import comfy_kitchen.backends.cuda as ck_cuda # type: ignore + except Exception: + return # comfy_kitchen not installed or no CUDA backend + + try: + orig_wrap = getattr(ck_cuda, "_wrap_for_dlpack", None) + if orig_wrap is None or getattr(orig_wrap, "_multigpu_patched", False): + return + + def _wrap_for_dlpack_guarded(tensor): + try: + if torch.cuda.is_available() and hasattr(tensor, "is_cuda") and tensor.is_cuda: + idx = tensor.device.index + if idx is not None: + # Ensure current CUDA device matches the tensor before exporting via DLPack. + torch.cuda.set_device(idx) + except Exception: + pass + return orig_wrap(tensor) + + _wrap_for_dlpack_guarded._multigpu_patched = True # type: ignore + setattr(ck_cuda, "_wrap_for_dlpack", _wrap_for_dlpack_guarded) + logger.info("[MultiGPU] Patched comfy_kitchen CUDA DLPack wrapper with device guard.") + except Exception as e: + logger.debug(f"[MultiGPU] Failed to patch comfy_kitchen DLPack wrapper: {e}") + +_patch_comfy_kitchen_dlpack_device_guard() + logger.info(f"[MultiGPU Core Patching] Patching mm.get_torch_device, mm.text_encoder_device, mm.unet_offload_device") logger.info(f"[MultiGPU DEBUG] Initial current_device: {current_device}") logger.info(f"[MultiGPU DEBUG] Initial current_text_encoder_device: {current_text_encoder_device}")