diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index ccc8347b7..5eab8a70b 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -850,7 +850,78 @@ std::string convert_diffusers_vae_to_original_sd1(std::string name) { return result; } -std::string convert_first_stage_model_name(std::string name, std::string prefix) { +std::string convert_diffusers_to_original_wan_vae(std::string name) { + static const std::vector> prefix_map = { + {"quant_conv.", "conv1."}, + {"post_quant_conv.", "conv2."}, + + {"decoder.up_blocks.0.resnets.0.", "decoder.upsamples.0.residual."}, + {"decoder.up_blocks.0.resnets.1.", "decoder.upsamples.1.residual."}, + {"decoder.up_blocks.0.resnets.2.", "decoder.upsamples.2.residual."}, + {"decoder.up_blocks.0.upsamplers.0.", "decoder.upsamples.3."}, + + {"decoder.up_blocks.1.resnets.0.conv_shortcut.", "decoder.upsamples.4.shortcut."}, + {"decoder.up_blocks.1.resnets.0.", "decoder.upsamples.4.residual."}, + {"decoder.up_blocks.1.resnets.1.", "decoder.upsamples.5.residual."}, + {"decoder.up_blocks.1.resnets.2.", "decoder.upsamples.6.residual."}, + {"decoder.up_blocks.1.upsamplers.0.", "decoder.upsamples.7."}, + {"decoder.up_blocks.2.resnets.0.", "decoder.upsamples.8.residual."}, + {"decoder.up_blocks.2.resnets.1.", "decoder.upsamples.9.residual."}, + {"decoder.up_blocks.2.resnets.2.", "decoder.upsamples.10.residual."}, + {"decoder.up_blocks.2.upsamplers.0.", "decoder.upsamples.11."}, + {"decoder.up_blocks.3.resnets.0.", "decoder.upsamples.12.residual."}, + {"decoder.up_blocks.3.resnets.1.", "decoder.upsamples.13.residual."}, + {"decoder.up_blocks.3.resnets.2.", "decoder.upsamples.14.residual."}, + + {"encoder.down_blocks.0.", "encoder.downsamples.0.residual."}, + {"encoder.down_blocks.1.", "encoder.downsamples.1.residual."}, + {"encoder.down_blocks.2.", "encoder.downsamples.2."}, + {"encoder.down_blocks.3.conv_shortcut.", "encoder.downsamples.3.shortcut."}, + {"encoder.down_blocks.3.", "encoder.downsamples.3.residual."}, + {"encoder.down_blocks.4.", "encoder.downsamples.4.residual."}, + {"encoder.down_blocks.5.", "encoder.downsamples.5."}, + {"encoder.down_blocks.6.conv_shortcut.", "encoder.downsamples.6.shortcut."}, + {"encoder.down_blocks.6.", "encoder.downsamples.6.residual."}, + {"encoder.down_blocks.7.", "encoder.downsamples.7.residual."}, + {"encoder.down_blocks.8.", "encoder.downsamples.8."}, + {"encoder.down_blocks.9.", "encoder.downsamples.9.residual."}, + {"encoder.down_blocks.10.", "encoder.downsamples.10.residual."}, + }; + + static const std::vector> shared_name_map = { + {".conv_in.", ".conv1."}, + {".norm_out.", ".head.0."}, + {".conv_out.", ".head.2."}, + + {".mid_block.attentions.0.", ".middle.1."}, + {".mid_block.resnets.0.", ".middle.0.residual."}, + {".mid_block.resnets.1.", ".middle.2.residual."}, + }; + + static const std::vector> resnet_name_map = { + {".norm1.", ".0."}, + {".conv1.", ".2."}, + {".norm2.", ".3."}, + {".conv2.", ".6."}, + }; + + replace_with_name_map(name, shared_name_map); + replace_with_prefix_map(name, prefix_map); + + // Only apply the ResNet-specific renaming if the tensor belongs to a ResNet block. + // This prevents generic ".conv1." or ".conv2." matching on top-level encoder/decoder convolutions. + if (name.find(".residual.") != std::string::npos) { + replace_with_name_map(name, resnet_name_map); + } + + + return name; +} + +std::string convert_first_stage_model_name(std::string name, std::string prefix, SDVersion version) { + if (sd_version_uses_wan_vae(version)) { + return convert_diffusers_to_original_wan_vae(name); + } static std::unordered_map vae_name_map = { {"decoder.post_quant_conv.", "post_quant_conv."}, {"encoder.quant_conv.", "quant_conv."}, @@ -1239,7 +1310,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) { { for (const auto& prefix : first_stage_model_prefix_vec) { if (starts_with(name, prefix)) { - name = convert_first_stage_model_name(name.substr(prefix.size()), prefix); + name = convert_first_stage_model_name(name.substr(prefix.size()), prefix, version); if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) { name = "tae." + name; } else {