Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .ops import GGMLTensor
from .dequant import is_quantized, dequantize_tensor
from .quant_ops import make_quantized

IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}
Expand Down Expand Up @@ -67,7 +68,7 @@ def get_gguf_metadata(reader):
continue
return metadata

def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False):
def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False, dynamic=False):
"""
Read state dict as fake tensors
"""
Expand Down Expand Up @@ -134,9 +135,15 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F
shape = shape[:-1]

# add to state dict
if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
torch_tensor = torch_tensor.view(*shape)
state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
if dynamic:
if tensor.tensor_type not in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
state_dict[sd_key] = make_quantized(torch_tensor, tensor.tensor_type, shape)
else:
state_dict[sd_key] = torch_tensor.view(*shape)
else:
if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
torch_tensor = torch_tensor.view(*shape)
state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)

# 1D tensors shouldn't be quantized, this is a fix for BF16
if len(shape) <= 1 and tensor.tensor_type == gguf.GGMLQuantizationType.BF16:
Expand Down
170 changes: 117 additions & 53 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import comfy.utils
import comfy.model_patcher
import comfy.model_management
import comfy.memory_management
import folder_paths

from .ops import GGMLOps, move_patch_to_device
Expand All @@ -32,6 +33,19 @@ def update_folder_names_and_paths(key, targets=[]):
update_folder_names_and_paths("unet_gguf", ["diffusion_models", "unet"])
update_folder_names_and_paths("clip_gguf", ["text_encoders", "clip"])

def _clone_as_gguf_model_patcher(self, *args, model_override=None, **kwargs):
    """Clone *self* so the result is a plain (non-dynamic) GGUFModelPatcher.

    Temporarily rebinds ``self.__class__`` to ``GGUFModelPatcher`` around a
    direct call to the base ``ModelPatcher.clone`` (bypassing any subclass
    ``clone`` override), then restores the original class and copies the
    GGUF-specific attributes onto the clone.
    """
    if model_override is None:
        model_override = self.get_clone_model_override()
    # GGUFModelPatcher.get_clone_model_override appends mmap_released as the
    # third tuple element; older/base overrides may be shorter, so default to
    # False when absent. (Assumes the base tuple has exactly two elements —
    # TODO confirm against comfy.model_patcher.)
    mmap_released = model_override[2] if len(model_override) > 2 else False
    src_cls = self.__class__
    self.__class__ = GGUFModelPatcher
    n = comfy.model_patcher.ModelPatcher.clone(self, *args, model_override=model_override, **kwargs)
    n.__class__ = GGUFModelPatcher
    self.__class__ = src_cls
    n.patch_on_device = getattr(self, "patch_on_device", False)
    n.mmap_released = mmap_released
    return n

class GGUFModelPatcher(comfy.model_patcher.ModelPatcher):
patch_on_device = False

Expand Down Expand Up @@ -89,6 +103,9 @@ def pin_weight_to_device(self, key):
mmap_released = False
named_modules_to_munmap = {}

    def get_clone_model_override(self):
        # Extend the base override tuple with mmap_released so clones made via
        # _clone_as_gguf_model_patcher can recover it (it reads index 2).
        return (*super().get_clone_model_override(), self.mmap_released)

def load(self, *args, force_patch_weights=False, **kwargs):
if not self.mmap_released:
self.named_modules_to_munmap = dict(self.model.named_modules())
Expand Down Expand Up @@ -120,18 +137,104 @@ def load(self, *args, force_patch_weights=False, **kwargs):
self.named_modules_to_munmap = {}

def clone(self, *args, **kwargs):
src_cls = self.__class__
self.__class__ = GGUFModelPatcher
n = super().clone(*args, **kwargs)
n.__class__ = GGUFModelPatcher
self.__class__ = src_cls
# GGUF specific clone values below
n.patch_on_device = getattr(self, "patch_on_device", False)
n.mmap_released = getattr(self, "mmap_released", False)
if src_cls != GGUFModelPatcher:
n = _clone_as_gguf_model_patcher(self, *args, **kwargs)
if self.__class__ != GGUFModelPatcher:
n.size = 0 # force recalc
return n

class GGUFModelPatcherDynamic(comfy.model_patcher.ModelPatcherDynamic):
    """Dynamic-VRAM variant of the GGUF model patcher.

    Built on comfy's ModelPatcherDynamic; after the base load it rewires the
    per-module patch hooks so GGML weights are never asked to requantize.
    """

    # When True, LoRA patches are moved to the load device (see load()).
    patch_on_device = False

    def load(self, *args, **kwargs):
        super().load(*args, **kwargs)
        # GGML can't requantize after LoRA - demote lowvram_function to weight_function
        # i.e. clear each module's <param>_lowvram_function and append the
        # callable to its <param>_function list instead. (Assumes
        # <param>_function is a list when present — TODO confirm against
        # comfy's module attributes.)
        for n, m in self.model.named_modules():
            for param_key in ("weight", "bias"):
                attr = param_key + "_lowvram_function"
                fn = getattr(m, attr, None)
                if fn is not None:
                    setattr(m, attr, None)
                    fns = getattr(m, param_key + "_function", [])
                    fns.append(fn)
                    setattr(m, param_key + "_function", fns)
        if self.patch_on_device:
            # Pre-move all patch tensors to the load device so patching does
            # not round-trip through the offload device.
            for key in self.patches:
                self.patches[key] = move_patch_to_device(self.patches[key], self.load_device)

    def clone(self, disable_dynamic=False, model_override=None):
        """Clone; with disable_dynamic=True the clone is a plain GGUFModelPatcher."""
        if disable_dynamic:
            if model_override is None:
                # cached_patcher_init is (factory, args) stored by
                # _load_gguf_unet / _load_gguf_clip; rebuild a non-dynamic
                # patcher from scratch to obtain a model override.
                temp = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
                model_override = temp.get_clone_model_override()
            n = _clone_as_gguf_model_patcher(self, model_override=model_override)
            return n
        n = super().clone(disable_dynamic=disable_dynamic, model_override=model_override)
        n.patch_on_device = self.patch_on_device
        return n

def _clone_patcher_to_gguf(model_patcher):
    """Clone *model_patcher* as the matching GGUF patcher subclass.

    Non-dynamic patchers are cloned via GGUFModelPatcher.clone; dynamic ones
    are cloned with their class temporarily rebound to GGUFModelPatcherDynamic
    so the clone picks up the GGUF-specific behavior.
    """
    if not model_patcher.is_dynamic():
        return GGUFModelPatcher.clone(model_patcher)
    original_cls = model_patcher.__class__
    model_patcher.__class__ = GGUFModelPatcherDynamic
    cloned = model_patcher.clone()
    model_patcher.__class__ = original_cls
    return cloned

def _load_gguf_unet(unet_path, ops, disable_dynamic=False):
    """Load a GGUF diffusion model from *unet_path* and wrap it in a GGUF patcher.

    When dynamic VRAM loading (aimdo) is enabled and not explicitly disabled,
    the state dict is loaded as QuantizedTensors and no custom ops are passed;
    otherwise the GGML custom operations *ops* are used.
    """
    dynamic = not disable_dynamic and comfy.memory_management.aimdo_enabled
    sd, extra = gguf_sd_loader(unet_path, dynamic=dynamic)

    # Only pass metadata when the installed comfy version accepts it.
    extra_kwargs = {}
    loader_sig = inspect.signature(comfy.sd.load_diffusion_model_state_dict)
    if "metadata" in loader_sig.parameters:
        extra_kwargs["metadata"] = extra.get("metadata", {})

    if dynamic:
        model_options = {}
    else:
        model_options = {"custom_operations": ops}

    model = comfy.sd.load_diffusion_model_state_dict(
        sd,
        model_options=model_options,
        disable_dynamic=disable_dynamic,
        **extra_kwargs,
    )
    if model is None:
        logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
        raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))

    model = _clone_patcher_to_gguf(model)
    # Remember how to rebuild this patcher from scratch; dynamic clones use
    # this when they need a non-dynamic model override.
    model.cached_patcher_init = (_load_gguf_unet, (unet_path, ops))
    return model

def _load_gguf_clip_patcher(clip_paths, clip_type, disable_dynamic=False):
    """Load the CLIP stack and return only its patcher (cached_patcher_init factory)."""
    clip = _load_gguf_clip(clip_paths, clip_type, disable_dynamic=disable_dynamic)
    return clip.patcher

def _load_gguf_clip(clip_paths, clip_type, disable_dynamic=False):
    """Load one or more text-encoder checkpoints (GGUF or torch) as a CLIP object.

    GGUF files go through gguf_clip_loader; everything else is loaded as a
    regular torch checkpoint. The resulting patcher is re-cloned as a GGUF
    patcher and tagged with a cached_patcher_init factory.
    """
    dynamic = not disable_dynamic and comfy.memory_management.aimdo_enabled

    clip_data = []
    for p in clip_paths:
        if p.endswith(".gguf"):
            clip_data.append(gguf_clip_loader(p))
            continue
        sd = comfy.utils.load_torch_file(p, safe_load=True)
        # NOTE: Scaled FP8 would require different custom ops, but only one can be active
        if not dynamic and "scaled_fp8" in sd:
            raise NotImplementedError(f"Mixing scaled FP8 with GGUF is not supported! Use regular CLIP loader or switch model(s)\n({p})")
        clip_data.append(sd)

    model_options = {
        "initial_device": comfy.model_management.text_encoder_offload_device(),
    }
    if not dynamic:
        # Dynamic loading handles dequant via QuantizedTensor layouts instead
        # of the GGML custom ops.
        model_options["custom_operations"] = GGMLOps

    clip = comfy.sd.load_text_encoder_state_dicts(
        clip_type=clip_type,
        state_dicts=clip_data,
        model_options=model_options,
        embedding_directory=folder_paths.get_folder_paths("embeddings"),
        disable_dynamic=disable_dynamic,
    )
    clip.patcher = _clone_patcher_to_gguf(clip.patcher)
    clip.patcher.cached_patcher_init = (_load_gguf_clip_patcher, (clip_paths, clip_type))
    return clip

class UnetLoaderGGUF:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -164,22 +267,8 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de
else:
ops.Linear.patch_dtype = getattr(torch, patch_dtype)

# init model
unet_path = folder_paths.get_full_path("unet", unet_name)
sd, extra = gguf_sd_loader(unet_path)

kwargs = {}
valid_params = inspect.signature(comfy.sd.load_diffusion_model_state_dict).parameters
if "metadata" in valid_params:
kwargs["metadata"] = extra.get("metadata", {})

model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": ops}, **kwargs,
)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
model = GGUFModelPatcher.clone(model)
model = _load_gguf_unet(unet_path, ops)
model.patch_on_device = patch_on_device
return (model,)

Expand Down Expand Up @@ -220,35 +309,10 @@ def get_filename_list(s):
files += folder_paths.get_filename_list("clip_gguf")
return sorted(files)

def load_data(self, ckpt_paths):
clip_data = []
for p in ckpt_paths:
if p.endswith(".gguf"):
sd = gguf_clip_loader(p)
else:
sd = comfy.utils.load_torch_file(p, safe_load=True)
if "scaled_fp8" in sd: # NOTE: Scaled FP8 would require different custom ops, but only one can be active
raise NotImplementedError(f"Mixing scaled FP8 with GGUF is not supported! Use regular CLIP loader or switch model(s)\n({p})")
clip_data.append(sd)
return clip_data

def load_patcher(self, clip_paths, clip_type, clip_data):
clip = comfy.sd.load_text_encoder_state_dicts(
clip_type = clip_type,
state_dicts = clip_data,
model_options = {
"custom_operations": GGMLOps,
"initial_device": comfy.model_management.text_encoder_offload_device()
},
embedding_directory = folder_paths.get_folder_paths("embeddings"),
)
clip.patcher = GGUFModelPatcher.clone(clip.patcher)
return clip

def load_clip(self, clip_name, type="stable_diffusion"):
clip_path = folder_paths.get_full_path("clip", clip_name)
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
return (self.load_patcher([clip_path], clip_type, self.load_data([clip_path])),)
return (_load_gguf_clip([clip_path], clip_type),)

class DualCLIPLoaderGGUF(CLIPLoaderGGUF):
@classmethod
Expand All @@ -270,7 +334,7 @@ def load_clip(self, clip_name1, clip_name2, type):
clip_path2 = folder_paths.get_full_path("clip", clip_name2)
clip_paths = (clip_path1, clip_path2)
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
return (_load_gguf_clip(clip_paths, clip_type),)

class TripleCLIPLoaderGGUF(CLIPLoaderGGUF):
@classmethod
Expand All @@ -292,7 +356,7 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, type="sd3"):
clip_path3 = folder_paths.get_full_path("clip", clip_name3)
clip_paths = (clip_path1, clip_path2, clip_path3)
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
return (_load_gguf_clip(clip_paths, clip_type),)

class QuadrupleCLIPLoaderGGUF(CLIPLoaderGGUF):
@classmethod
Expand All @@ -316,7 +380,7 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable
clip_path4 = folder_paths.get_full_path("clip", clip_name4)
clip_paths = (clip_path1, clip_path2, clip_path3, clip_path4)
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
return (_load_gguf_clip(clip_paths, clip_type),)

NODE_CLASS_MAPPINGS = {
"UnetLoaderGGUF": UnetLoaderGGUF,
Expand Down
83 changes: 83 additions & 0 deletions quant_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# GGML QuantizedTensor support for dynamic VRAM loading
import gguf
import torch
from dataclasses import dataclass

try:
    from comfy_kitchen.tensor import (
        QuantizedTensor,
        QuantizedLayout,
        BaseLayoutParams,
        register_layout_class,
    )
    _CK_AVAILABLE = True
except ImportError:
    _CK_AVAILABLE = False

    # Inert stand-ins so this module still imports when comfy_kitchen is not
    # installed. They only satisfy name resolution; make_quantized and the
    # GGMLLayout machinery are unusable in this state (guarded by
    # _CK_AVAILABLE).
    class QuantizedTensor:
        pass

    class QuantizedLayout:
        pass

    class BaseLayoutParams:
        pass

    def register_layout_class(name, cls):
        pass

from .dequant import dequantize_functions, TORCH_COMPATIBLE_QTYPES, is_quantized

if _CK_AVAILABLE:
    @dataclass(frozen=True)
    class GGMLLayoutParams(BaseLayoutParams):
        """Layout parameters for GGML-packed data.

        Inherits the common fields (orig_shape, orig_dtype, scale, ...) from
        comfy_kitchen's BaseLayoutParams — TODO confirm field names there.
        """
        tensor_type: int  # gguf.GGMLQuantizationType stored as int

    class GGMLLayout(QuantizedLayout):
        """comfy_kitchen layout that dequantizes GGML-packed tensors on demand."""
        Params = GGMLLayoutParams

        @classmethod
        def quantize(cls, tensor, **kwargs):
            # This layout is read-only: data arrives pre-quantized from GGUF files.
            raise NotImplementedError("Quantization to GGML format is not supported")

        @classmethod
        def dequantize(cls, qdata, params):
            """Return a dense tensor of params.orig_shape / params.orig_dtype."""
            qtype = gguf.GGMLQuantizationType(params.tensor_type)
            oshape = params.orig_shape

            # Types torch can represent natively: just reshape and cast.
            if qtype in TORCH_COMPATIBLE_QTYPES:
                return qdata.reshape(oshape).to(params.orig_dtype)

            # No torch dequant kernel available: slow numpy round-trip via gguf.
            if qtype not in dequantize_functions:
                from tqdm import tqdm
                tqdm.write(f"Falling back to numpy dequant for qtype: {qtype.name}")
                new = gguf.quants.dequantize(qdata.cpu().numpy(), qtype)
                return torch.from_numpy(new).reshape(oshape).to(device=qdata.device, dtype=params.orig_dtype)

            # Torch dequant path: reinterpret the packed bytes as
            # (n_blocks, type_size) and dequantize block-wise.
            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
            raw = qdata.reshape(-1).view(torch.uint8)
            n_blocks = raw.numel() // type_size
            blocks = raw.reshape((n_blocks, type_size))
            blocks = dequantize_functions[qtype](blocks, block_size, type_size, None)
            return blocks.reshape(oshape).to(params.orig_dtype)

        @classmethod
        def get_plain_tensors(cls, qtensor):
            # The packed byte tensor is the only underlying storage.
            return (qtensor._qdata,)

        @classmethod
        def state_dict_tensors(cls, qdata, params):
            return {"weight": qdata}

    register_layout_class("GGMLLayout", GGMLLayout)


def make_quantized(qdata, tensor_type, tensor_shape, orig_dtype=torch.float16):
    """Construct a GGML QuantizedTensor from raw packed data.

    Args:
        qdata: packed quantized bytes as a torch tensor, laid out exactly as
            read from the GGUF file.
        tensor_type: gguf.GGMLQuantizationType (or its int value) describing
            the packing of qdata.
        tensor_shape: logical (dequantized) shape of the tensor.
        orig_dtype: dtype produced when the tensor is dequantized.

    Returns:
        A comfy_kitchen QuantizedTensor using the "GGMLLayout" layout.

    Raises:
        RuntimeError: if comfy_kitchen is not installed. (Previously this
            path died with a NameError, since GGMLLayoutParams only exists
            when the import succeeded and the stub QuantizedTensor takes no
            arguments.)
    """
    if not _CK_AVAILABLE:
        raise RuntimeError(
            "comfy_kitchen is required to build GGML QuantizedTensors for dynamic loading"
        )
    params = GGMLLayoutParams(
        # scale is unused by GGMLLayout itself; presumably required by
        # BaseLayoutParams — TODO confirm against comfy_kitchen.
        scale=torch.ones((), dtype=torch.float32),
        orig_dtype=orig_dtype,
        orig_shape=tuple(tensor_shape),
        tensor_type=tensor_type.value if not isinstance(tensor_type, int) else tensor_type,
    )
    return QuantizedTensor(qdata, "GGMLLayout", params)