diff --git a/loader.py b/loader.py index 7cefb11..3f063f1 100644 --- a/loader.py +++ b/loader.py @@ -8,6 +8,7 @@ from .ops import GGMLTensor from .dequant import is_quantized, dequantize_tensor +from .quant_ops import make_quantized IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"} TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"} @@ -67,7 +68,7 @@ def get_gguf_metadata(reader): continue return metadata -def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False): +def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False, dynamic=False): """ Read state dict as fake tensors """ @@ -134,9 +135,15 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F shape = shape[:-1] # add to state dict - if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}: - torch_tensor = torch_tensor.view(*shape) - state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape) + if dynamic: + if tensor.tensor_type not in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}: + state_dict[sd_key] = make_quantized(torch_tensor, tensor.tensor_type, shape) + else: + state_dict[sd_key] = torch_tensor.view(*shape) + else: + if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}: + torch_tensor = torch_tensor.view(*shape) + state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape) # 1D tensors shouldn't be quantized, this is a fix for BF16 if len(shape) <= 1 and tensor.tensor_type == gguf.GGMLQuantizationType.BF16: diff --git a/nodes.py b/nodes.py index 4683514..d5ab69f 100644 --- a/nodes.py +++ b/nodes.py @@ -11,6 +11,7 @@ import comfy.utils import comfy.model_patcher import comfy.model_management +import comfy.memory_management import folder_paths from .ops import GGMLOps, move_patch_to_device @@ -32,6 +33,19 @@ def update_folder_names_and_paths(key, targets=[]): update_folder_names_and_paths("unet_gguf", ["diffusion_models", "unet"]) update_folder_names_and_paths("clip_gguf", ["text_encoders", "clip"]) +def _clone_as_gguf_model_patcher(self, *args, model_override=None, **kwargs): + if model_override is None: + model_override = self.get_clone_model_override() + mmap_released = model_override[2] if len(model_override) > 2 else False + src_cls = self.__class__ + self.__class__ = GGUFModelPatcher + n = comfy.model_patcher.ModelPatcher.clone(self, *args, model_override=model_override, **kwargs) + n.__class__ = GGUFModelPatcher + self.__class__ = src_cls + n.patch_on_device = getattr(self, "patch_on_device", False) + n.mmap_released = mmap_released + return n + class GGUFModelPatcher(comfy.model_patcher.ModelPatcher): patch_on_device = False @@ -89,6 +103,9 @@ def pin_weight_to_device(self, key): mmap_released = False named_modules_to_munmap = {} + def get_clone_model_override(self): + return (*super().get_clone_model_override(), self.mmap_released) + def load(self, *args, force_patch_weights=False, **kwargs): if not self.mmap_released: self.named_modules_to_munmap = dict(self.model.named_modules()) @@ -120,18 +137,104 @@ def load(self, *args, force_patch_weights=False, **kwargs): self.named_modules_to_munmap = {} def clone(self, *args, **kwargs): - src_cls = self.__class__ - self.__class__ = GGUFModelPatcher - n = super().clone(*args, **kwargs) - n.__class__ = GGUFModelPatcher - self.__class__ = src_cls - # GGUF specific clone values below - n.patch_on_device = getattr(self, "patch_on_device", False) - n.mmap_released = getattr(self, "mmap_released", False) - if src_cls != GGUFModelPatcher: + n = _clone_as_gguf_model_patcher(self, *args, **kwargs) + if self.__class__ != GGUFModelPatcher: n.size = 0 # force recalc return n +class GGUFModelPatcherDynamic(comfy.model_patcher.ModelPatcherDynamic): + patch_on_device = False + + def load(self, *args, **kwargs): + super().load(*args, **kwargs) + # GGML can't requantize after LoRA - demote lowvram_function to weight_function + for n, m in self.model.named_modules(): + for param_key in ("weight", "bias"): + attr = param_key + "_lowvram_function" + fn = getattr(m, attr, None) + if fn is not None: + setattr(m, attr, None) + fns = getattr(m, param_key + "_function", []) + fns.append(fn) + setattr(m, param_key + "_function", fns) + if self.patch_on_device: + for key in self.patches: + self.patches[key] = move_patch_to_device(self.patches[key], self.load_device) + + def clone(self, disable_dynamic=False, model_override=None): + if disable_dynamic: + if model_override is None: + temp = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model_override = temp.get_clone_model_override() + n = _clone_as_gguf_model_patcher(self, model_override=model_override) + return n + n = super().clone(disable_dynamic=disable_dynamic, model_override=model_override) + n.patch_on_device = self.patch_on_device + return n + +def _clone_patcher_to_gguf(model_patcher): + if model_patcher.is_dynamic(): + src_cls = model_patcher.__class__ + model_patcher.__class__ = GGUFModelPatcherDynamic + n = model_patcher.clone() + model_patcher.__class__ = src_cls + return n + else: + return GGUFModelPatcher.clone(model_patcher) + +def _load_gguf_unet(unet_path, ops, disable_dynamic=False): + dynamic = not disable_dynamic and comfy.memory_management.aimdo_enabled + sd, extra = gguf_sd_loader(unet_path, dynamic=dynamic) + + kwargs = {} + valid_params = inspect.signature(comfy.sd.load_diffusion_model_state_dict).parameters + if "metadata" in valid_params: + kwargs["metadata"] = extra.get("metadata", {}) + + model = comfy.sd.load_diffusion_model_state_dict( + sd, model_options={} if dynamic else { "custom_operations" : ops }, disable_dynamic=disable_dynamic, **kwargs, + ) + if model is None: + logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) + raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + model = _clone_patcher_to_gguf(model) + + model.cached_patcher_init = (_load_gguf_unet, (unet_path, ops)) + + return model + +def _load_gguf_clip_patcher(clip_paths, clip_type, disable_dynamic=False): + return _load_gguf_clip(clip_paths, clip_type, disable_dynamic=disable_dynamic).patcher + +def _load_gguf_clip(clip_paths, clip_type, disable_dynamic=False): + dynamic = not disable_dynamic and comfy.memory_management.aimdo_enabled + + clip_data = [] + for p in clip_paths: + if p.endswith(".gguf"): + sd = gguf_clip_loader(p) + else: + sd = comfy.utils.load_torch_file(p, safe_load=True) + if not dynamic and "scaled_fp8" in sd: # NOTE: Scaled FP8 would require different custom ops, but only one can be active + raise NotImplementedError(f"Mixing scaled FP8 with GGUF is not supported! Use regular CLIP loader or switch model(s)\n({p})") + clip_data.append(sd) + + model_options = {"initial_device": comfy.model_management.text_encoder_offload_device()} + if not dynamic: + model_options["custom_operations"] = GGMLOps + + clip = comfy.sd.load_text_encoder_state_dicts( + clip_type = clip_type, + state_dicts = clip_data, + model_options = model_options, + embedding_directory = folder_paths.get_folder_paths("embeddings"), + disable_dynamic = disable_dynamic, + ) + clip.patcher = _clone_patcher_to_gguf(clip.patcher) + + clip.patcher.cached_patcher_init = (_load_gguf_clip_patcher, (clip_paths, clip_type)) + return clip + class UnetLoaderGGUF: @classmethod def INPUT_TYPES(s): @@ -164,22 +267,8 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de else: ops.Linear.patch_dtype = getattr(torch, patch_dtype) - # init model unet_path = folder_paths.get_full_path("unet", unet_name) - sd, extra = gguf_sd_loader(unet_path) - - kwargs = {} - valid_params = inspect.signature(comfy.sd.load_diffusion_model_state_dict).parameters - if "metadata" in valid_params: - kwargs["metadata"] = extra.get("metadata", {}) - - model = comfy.sd.load_diffusion_model_state_dict( - sd, model_options={"custom_operations": ops}, **kwargs, - ) - if model is None: - logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) - raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) - model = GGUFModelPatcher.clone(model) + model = _load_gguf_unet(unet_path, ops) model.patch_on_device = patch_on_device return (model,) @@ -220,35 +309,10 @@ def get_filename_list(s): files += folder_paths.get_filename_list("clip_gguf") return sorted(files) - def load_data(self, ckpt_paths): - clip_data = [] - for p in ckpt_paths: - if p.endswith(".gguf"): - sd = gguf_clip_loader(p) - else: - sd = comfy.utils.load_torch_file(p, safe_load=True) - if "scaled_fp8" in sd: # NOTE: Scaled FP8 would require different custom ops, but only one can be active - raise NotImplementedError(f"Mixing scaled FP8 with GGUF is not supported! Use regular CLIP loader or switch model(s)\n({p})") - clip_data.append(sd) - return clip_data - - def load_patcher(self, clip_paths, clip_type, clip_data): - clip = comfy.sd.load_text_encoder_state_dicts( - clip_type = clip_type, - state_dicts = clip_data, - model_options = { - "custom_operations": GGMLOps, - "initial_device": comfy.model_management.text_encoder_offload_device() - }, - embedding_directory = folder_paths.get_folder_paths("embeddings"), - ) - clip.patcher = GGUFModelPatcher.clone(clip.patcher) - return clip - def load_clip(self, clip_name, type="stable_diffusion"): clip_path = folder_paths.get_full_path("clip", clip_name) clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) - return (self.load_patcher([clip_path], clip_type, self.load_data([clip_path])),) + return (_load_gguf_clip([clip_path], clip_type),) class DualCLIPLoaderGGUF(CLIPLoaderGGUF): @classmethod @@ -270,7 +334,7 @@ def load_clip(self, clip_name1, clip_name2, type): clip_path2 = folder_paths.get_full_path("clip", clip_name2) clip_paths = (clip_path1, clip_path2) clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) - return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) + return (_load_gguf_clip(clip_paths, clip_type),) class TripleCLIPLoaderGGUF(CLIPLoaderGGUF): @classmethod @@ -292,7 +356,7 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, type="sd3"): clip_path3 = folder_paths.get_full_path("clip", clip_name3) clip_paths = (clip_path1, clip_path2, clip_path3) clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) - return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) + return (_load_gguf_clip(clip_paths, clip_type),) class QuadrupleCLIPLoaderGGUF(CLIPLoaderGGUF): @classmethod @@ -316,7 +380,7 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable clip_path4 = folder_paths.get_full_path("clip", clip_name4) clip_paths = (clip_path1, clip_path2, clip_path3, clip_path4) clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) - return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) + return (_load_gguf_clip(clip_paths, clip_type),) NODE_CLASS_MAPPINGS = { "UnetLoaderGGUF": UnetLoaderGGUF, diff --git a/quant_ops.py b/quant_ops.py new file mode 100644 index 0000000..856d381 --- /dev/null +++ b/quant_ops.py @@ -0,0 +1,83 @@ +# GGML QuantizedTensor support for dynamic VRAM loading +import gguf +import torch +from dataclasses import dataclass + +try: + from comfy_kitchen.tensor import ( + QuantizedTensor, + QuantizedLayout, + BaseLayoutParams, + register_layout_class, + ) + _CK_AVAILABLE = True +except ImportError: + _CK_AVAILABLE = False + + class QuantizedTensor: + pass + + class QuantizedLayout: + pass + + class BaseLayoutParams: + pass + + def register_layout_class(name, cls): + pass + +from .dequant import dequantize_functions, TORCH_COMPATIBLE_QTYPES, is_quantized + +if _CK_AVAILABLE: + @dataclass(frozen=True) + class GGMLLayoutParams(BaseLayoutParams): + tensor_type: int # gguf.GGMLQuantizationType stored as int + + class GGMLLayout(QuantizedLayout): + Params = GGMLLayoutParams + + @classmethod + def quantize(cls, tensor, **kwargs): + raise NotImplementedError("Quantization to GGML format is not supported") + + @classmethod + def dequantize(cls, qdata, params): + qtype = gguf.GGMLQuantizationType(params.tensor_type) + oshape = params.orig_shape + + if qtype in TORCH_COMPATIBLE_QTYPES: + return qdata.reshape(oshape).to(params.orig_dtype) + + if qtype not in dequantize_functions: + from tqdm import tqdm + tqdm.write(f"Falling back to numpy dequant for qtype: {qtype.name}") + new = gguf.quants.dequantize(qdata.cpu().numpy(), qtype) + return torch.from_numpy(new).reshape(oshape).to(device=qdata.device, dtype=params.orig_dtype) + + block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] + raw = qdata.reshape(-1).view(torch.uint8) + n_blocks = raw.numel() // type_size + blocks = raw.reshape((n_blocks, type_size)) + blocks = dequantize_functions[qtype](blocks, block_size, type_size, None) + return blocks.reshape(oshape).to(params.orig_dtype) + + @classmethod + def get_plain_tensors(cls, qtensor): + return (qtensor._qdata,) + + @classmethod + def state_dict_tensors(cls, qdata, params): + return {"weight": qdata} + + register_layout_class("GGMLLayout", GGMLLayout) + + +def make_quantized(qdata, tensor_type, tensor_shape, orig_dtype=torch.float16): + """Construct a GGML QuantizedTensor from raw packed data.""" + params = GGMLLayoutParams( + scale=torch.ones((), dtype=torch.float32), + orig_dtype=orig_dtype, + orig_shape=tuple(tensor_shape), + tensor_type=tensor_type.value if not isinstance(tensor_type, int) else tensor_type, + ) + return QuantizedTensor(qdata, "GGMLLayout", params)