Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,14 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F
# filter and strip prefix
has_prefix = False
if handle_prefix is not None:
prefix_len = len(handle_prefix)
tensor_names = set(tensor.name for tensor in reader.tensors)
prefix_len = len(handle_prefix)
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
# Some stable-diffusion.cpp exports (e.g. anima) prefix every tensor name with "net." instead
if (not has_prefix) and (not is_text_model) and tensor_names and all(s.startswith("net.") for s in tensor_names):
handle_prefix = "net."
prefix_len = len(handle_prefix)
has_prefix = True

tensors = []
for tensor in reader.tensors:
Expand All @@ -97,12 +102,22 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F
if is_text_model:
raise ValueError(f"This gguf file is incompatible with llama.cpp!\nConsider using safetensors or a compatible gguf file\n({path})")
compat = "sd.cpp" if arch_str is None else arch_str
# import here to avoid changes to convert.py breaking regular models
from .tools.convert import detect_arch
try:
arch_str = detect_arch(set(val[0] for val in tensors)).arch
except Exception as e:
raise ValueError(f"This model is not currently supported - ({e})")
tensor_keys = set(val[0] for val in tensors)
# stable-diffusion.cpp qwen-image tensor names overlap some legacy flux/sd3 markers,
# so check for this key set first and only fall back to detect_arch otherwise
if {
"img_in.weight",
"proj_out.weight",
"time_text_embed.timestep_embedder.linear_1.weight",
"norm_out.linear.weight",
}.issubset(tensor_keys):
arch_str = "qwen_image"
else:
# import here to avoid changes to convert.py breaking regular models
from .tools.convert import detect_arch
try:
arch_str = detect_arch(tensor_keys).arch
except Exception as e:
raise ValueError(f"This model is not currently supported - ({e})")
elif arch_str not in TXT_ARCH_LIST and is_text_model:
if type_str not in VIS_TYPE_LIST:
raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
Expand All @@ -112,6 +127,17 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F
if compat:
logging.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")

wan_dim = None
if compat == "sd.cpp" and arch_str == "wan":
# Last dim of head.modulation; used below to restore the collapsed Conv3d
# patch_embedding.weight shape in sd.cpp exports
head_mod = next((t for k, t in tensors if k == "head.modulation"), None)
if head_mod is not None:
mod_shape = get_orig_shape(reader, head_mod.name)
if mod_shape is None:
mod_shape = torch.Size(tuple(int(v) for v in reversed(head_mod.shape)))
if len(mod_shape) >= 1:
wan_dim = int(mod_shape[-1])

# main loading loop
state_dict = {}
qtype_dict = {}
Expand All @@ -132,6 +158,23 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F
if any([tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")]):
while len(shape) > 2 and shape[-1] == 1:
shape = shape[:-1]
# Workaround for stable-diffusion.cpp Lumina2 pad token shape
if compat == "sd.cpp" and arch_str == "lumina2":
if len(shape) == 1 and sd_key in {"x_pad_token", "cap_pad_token"}:
shape = torch.Size((1, shape[0]))
# Workaround for stable-diffusion.cpp Wan 2.1 shape collapse
if compat == "sd.cpp" and arch_str == "wan":
if len(shape) == 2 and sd_key.endswith(".modulation"):
shape = torch.Size((1, shape[0], shape[1]))
if (
len(shape) == 4
and sd_key.endswith("patch_embedding.weight")
and shape[1] == 1
and wan_dim is not None
and shape[0] % wan_dim == 0
):
in_dim = shape[0] // wan_dim
shape = torch.Size((wan_dim, in_dim, 1, shape[2], shape[3]))

# add to state dict
if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
Expand All @@ -142,6 +185,12 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F
if len(shape) <= 1 and tensor.tensor_type == gguf.GGMLQuantizationType.BF16:
state_dict[sd_key] = dequantize_tensor(state_dict[sd_key], dtype=torch.float32)

if compat == "sd.cpp" and len(shape) <= 1 and is_quantized(state_dict[sd_key]):
state_dict[sd_key] = dequantize_tensor(state_dict[sd_key], dtype=torch.float32)
if compat == "sd.cpp" and arch_str == "wan":
if sd_key.endswith(".modulation") and is_quantized(state_dict[sd_key]):
state_dict[sd_key] = dequantize_tensor(state_dict[sd_key], dtype=torch.float32)

# keep track of loaded tensor types
tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type))
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
Expand Down