diff --git a/src/conditioning/conditioner.hpp b/src/conditioning/conditioner.hpp index ae1a5b5b3..e037fe76b 100644 --- a/src/conditioning/conditioner.hpp +++ b/src/conditioning/conditioner.hpp @@ -1518,7 +1518,7 @@ struct LLMEmbedder : public Conditioner { arch = LLM::LLMArch::GPT_OSS_20B; } else if (sd_version_is_pid(version)) { arch = LLM::LLMArch::GEMMA2_2B; - } else if (sd_version_is_ideogram4(version) || sd_version_is_boogu_image(version)) { + } else if (sd_version_is_ideogram4(version) || sd_version_is_boogu_image(version) || sd_version_is_krea2(version)) { arch = LLM::LLMArch::QWEN3_VL; } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) { arch = LLM::LLMArch::QWEN3; @@ -1837,6 +1837,17 @@ struct LLMEmbedder : public Conditioner { prompt_attn_range.second = static_cast(prompt.size()); prompt += "<|im_end|>\n"; } + } else if (sd_version_is_krea2(version)) { + prompt_template_encode_start_idx = 34; + out_layers = {2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35}; + + prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; } else if (sd_version_is_longcat(version)) { spell_quotes = true; diff --git a/src/core/ggml_extend.hpp b/src/core/ggml_extend.hpp index a3dda16b2..65196813b 100644 --- a/src/core/ggml_extend.hpp +++ b/src/core/ggml_extend.hpp @@ -1382,7 +1382,16 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx, if (!ggml_backend_supports_op(backend, kqv)) { kqv = nullptr; } else { - kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0); + kqv = ggml_view_4d(ctx, + kqv, + d_head, + n_head, + L_q, + N, + kqv->nb[1], + kqv->nb[2], + kqv->nb[1] * n_head, + 0); } } } diff --git a/src/model.h b/src/model.h index d02ed65b8..cce309138 100644 --- a/src/model.h +++ b/src/model.h @@ -49,6 +49,7 @@ enum SDVersion { VERSION_LONGCAT, VERSION_PID, VERSION_IDEOGRAM4, + VERSION_KREA2, VERSION_ESRGAN, VERSION_COUNT, }; @@ -186,6 +187,13 @@ static inline bool sd_version_is_ideogram4(SDVersion version) { return false; } +static inline bool sd_version_is_krea2(SDVersion version) { + if (version == VERSION_KREA2) { + return true; + } + return false; +} + static inline bool sd_version_uses_flux_vae(SDVersion version) { if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_boogu_image(version) || sd_version_is_longcat(version)) { return true; @@ -226,7 +234,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_lens(version) || sd_version_is_longcat(version) || sd_version_is_pid(version) || - sd_version_is_ideogram4(version)) { + sd_version_is_ideogram4(version) || + sd_version_is_krea2(version)) { return true; } return false; diff --git a/src/model/diffusion/krea2.hpp b/src/model/diffusion/krea2.hpp new file mode 100644 index 000000000..02e655590 --- /dev/null +++ b/src/model/diffusion/krea2.hpp @@ -0,0 +1,683 @@ +#ifndef __SD_MODEL_DIFFUSION_KREA2_HPP__ +#define __SD_MODEL_DIFFUSION_KREA2_HPP__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/ggml_extend.hpp" +#include "core/ggml_graph_cut.h" +#include "model/common/rope.hpp" +#include "model/diffusion/dit.hpp" +#include "model/diffusion/flux.hpp" +#include "model/diffusion/model.hpp" +#include "model_loader.h" + +namespace Krea2 { + constexpr int KREA2_GRAPH_SIZE = 65536; + + struct Krea2Config { + int patch_size = 2; + int64_t in_channels = 16; + int64_t out_channels = 16; + int64_t features = 6144; + int64_t timestep_dim = 256; + int64_t text_dim = 2560; + int64_t text_layers = 12; + int64_t layers = 28; + int64_t heads = 48; + int64_t kv_heads = 12; + int64_t text_heads = 20; + int64_t text_kv_heads = 20; + int64_t mlp_multiplier = 4; + float theta = 1000.f; + float norm_eps = 1e-5f; + std::vector axes_dim = {32, 48, 48}; + int axes_dim_sum = 128; + + int64_t head_dim() const { + return features / heads; + } + + static int64_t count_blocks(const String2TensorStorage& tensor_storage_map, + const std::string& prefix, + const std::string& block_prefix) { + int64_t count = 0; + std::string full_prefix = prefix.empty() ? block_prefix : prefix + "." + block_prefix; + for (const auto& [name, _] : tensor_storage_map) { + if (!starts_with(name, full_prefix)) { + continue; + } + std::string tail = name.substr(full_prefix.size()); + size_t dot = tail.find('.'); + if (dot == std::string::npos) { + continue; + } + int block_index = std::atoi(tail.substr(0, dot).c_str()); + count = std::max(count, block_index + 1); + } + return count; + } + + void update_axes_dim() { + int64_t dim_head = head_dim(); + int64_t unit = dim_head / 16; + axes_dim = { + static_cast(dim_head - 12 * unit), + static_cast(6 * unit), + static_cast(6 * unit), + }; + axes_dim_sum = axes_dim[0] + axes_dim[1] + axes_dim[2]; + } + + static Krea2Config detect_from_weights(const String2TensorStorage& tensor_storage_map, + const std::string& prefix) { + Krea2Config config; + int64_t detected_head_dim = 0; + int64_t detected_text_head_dim = 0; + + for (const auto& [name, tensor_storage] : tensor_storage_map) { + if (!starts_with(name, prefix)) { + continue; + } + if (ends_with(name, "first.weight") && tensor_storage.n_dims == 2) { + config.in_channels = tensor_storage.ne[0] / (config.patch_size * config.patch_size); + config.out_channels = config.in_channels; + config.features = tensor_storage.ne[1]; + } else if (ends_with(name, "blocks.0.attn.qknorm.qnorm.scale") && tensor_storage.n_dims == 1) { + detected_head_dim = tensor_storage.ne[0]; + } else if (ends_with(name, "blocks.0.attn.wq.weight") && tensor_storage.n_dims == 2) { + if (detected_head_dim > 0) { + config.heads = tensor_storage.ne[1] / detected_head_dim; + } + } else if (ends_with(name, "blocks.0.attn.wk.weight") && tensor_storage.n_dims == 2) { + if (detected_head_dim > 0) { + config.kv_heads = tensor_storage.ne[1] / detected_head_dim; + } + } else if (ends_with(name, "txtfusion.projector.weight") && tensor_storage.n_dims == 2) { + config.text_layers = tensor_storage.ne[0]; + } else if (ends_with(name, "txtfusion.layerwise_blocks.0.prenorm.scale") && tensor_storage.n_dims == 1) { + config.text_dim = tensor_storage.ne[0]; + } else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.qknorm.qnorm.scale") && tensor_storage.n_dims == 1) { + detected_text_head_dim = tensor_storage.ne[0]; + } else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.wq.weight") && tensor_storage.n_dims == 2) { + if (detected_text_head_dim > 0) { + config.text_heads = tensor_storage.ne[1] / detected_text_head_dim; + } + } else if (ends_with(name, "txtfusion.layerwise_blocks.0.attn.wk.weight") && tensor_storage.n_dims == 2) { + if (detected_text_head_dim > 0) { + config.text_kv_heads = tensor_storage.ne[1] / detected_text_head_dim; + } + } else if (ends_with(name, "last.linear.weight") && tensor_storage.n_dims == 2) { + config.out_channels = tensor_storage.ne[1] / (config.patch_size * config.patch_size); + } + } + + config.layers = std::max(1, count_blocks(tensor_storage_map, prefix, "blocks.")); + if (detected_head_dim > 0 && config.features > 0) { + config.heads = config.features / detected_head_dim; + } + if (detected_head_dim > 0) { + std::string wk_name = prefix.empty() ? "blocks.0.attn.wk.weight" : prefix + ".blocks.0.attn.wk.weight"; + auto it = tensor_storage_map.find(wk_name); + if (it != tensor_storage_map.end() && it->second.n_dims == 2) { + config.kv_heads = it->second.ne[1] / detected_head_dim; + } + } + if (detected_text_head_dim > 0 && config.text_dim > 0) { + config.text_heads = config.text_dim / detected_text_head_dim; + } + if (detected_text_head_dim > 0) { + std::string wk_name = prefix.empty() ? "txtfusion.layerwise_blocks.0.attn.wk.weight" : prefix + ".txtfusion.layerwise_blocks.0.attn.wk.weight"; + auto it = tensor_storage_map.find(wk_name); + if (it != tensor_storage_map.end() && it->second.n_dims == 2) { + config.text_kv_heads = it->second.ne[1] / detected_text_head_dim; + } + } + config.update_axes_dim(); + + LOG_DEBUG("krea2: layers=%" PRId64 ", features=%" PRId64 ", heads=%" PRId64 ", kv_heads=%" PRId64 ", text_dim=%" PRId64 ", text_layers=%" PRId64 ", text_heads=%" PRId64 ", text_kv_heads=%" PRId64 ", channels=%" PRId64, + config.layers, + config.features, + config.heads, + config.kv_heads, + config.text_dim, + config.text_layers, + config.text_heads, + config.text_kv_heads, + config.in_channels); + return config; + } + }; + + __STATIC_INLINE__ int64_t ceil_to_multiple(int64_t value, int64_t multiple) { + return ((value + multiple - 1) / multiple) * multiple; + } + + class KreaRMSNorm : public UnaryBlock { + protected: + int64_t hidden_size; + float eps; + std::string prefix; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + GGML_UNUSED(tensor_storage_map); + this->prefix = prefix; + params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + + public: + KreaRMSNorm(int64_t hidden_size, float eps = 1e-5f) + : hidden_size(hidden_size), + eps(eps) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + ggml_tensor* scale = params["scale"]; + scale = ggml_add(ctx->ggml_ctx, scale, ggml_ext_ones(ctx->ggml_ctx, scale->ne[0], 1, 1, 1)); + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + x = ggml_mul_inplace(ctx->ggml_ctx, x, scale); + return x; + } + }; + + class KreaSwiGLU : public UnaryBlock { + public: + KreaSwiGLU(int64_t features, int64_t multiplier) { + int64_t mlp_dim = ceil_to_multiple(((2 * features) / 3) * multiplier, 128); + blocks["gate"] = std::make_shared(features, mlp_dim, false); + blocks["up"] = std::make_shared(features, mlp_dim, false); + blocks["down"] = std::make_shared(mlp_dim, features, false); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto gate = std::dynamic_pointer_cast(blocks["gate"]); + auto up = std::dynamic_pointer_cast(blocks["up"]); + auto down = std::dynamic_pointer_cast(blocks["down"]); + + auto gated = ggml_silu(ctx->ggml_ctx, gate->forward(ctx, x)); + auto up_x = up->forward(ctx, x); + x = ggml_mul(ctx->ggml_ctx, gated, up_x); + return down->forward(ctx, x); + } + }; + + class KreaAttention : public GGMLBlock { + protected: + int64_t features; + int64_t heads; + int64_t kv_heads; + int64_t head_dim_; + + ggml_tensor* attention_no_rope(GGMLRunnerContext* ctx, + ggml_tensor* q, + ggml_tensor* k, + ggml_tensor* v, + ggml_tensor* mask) { + int64_t Lq = q->ne[2]; + int64_t Lk = k->ne[2]; + int64_t N = q->ne[3]; + q = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim_ * heads, Lq, N); + k = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim_ * kv_heads, Lk, N); + v = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim_ * kv_heads, Lk, N); + return ggml_ext_attention_ext(ctx->ggml_ctx, + ctx->backend, + q, + k, + v, + heads, + mask, + false, + ctx->flash_attn_enabled); + } + + public: + KreaAttention(int64_t features, + int64_t heads, + int64_t kv_heads, + float eps = 1e-5f) + : features(features), + heads(heads), + kv_heads(kv_heads), + head_dim_(features / heads) { + blocks["wq"] = std::make_shared(features, heads * head_dim_, false); + blocks["wk"] = std::make_shared(features, kv_heads * head_dim_, false); + blocks["wv"] = std::make_shared(features, kv_heads * head_dim_, false); + blocks["gate"] = std::make_shared(features, features, false); + blocks["qknorm.qnorm"] = std::make_shared(head_dim_, eps); + blocks["qknorm.knorm"] = std::make_shared(head_dim_, eps); + blocks["wo"] = std::make_shared(features, features, false); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* pe = nullptr, + ggml_tensor* mask = nullptr) { + auto wq = std::dynamic_pointer_cast(blocks["wq"]); + auto wk = std::dynamic_pointer_cast(blocks["wk"]); + auto wv = std::dynamic_pointer_cast(blocks["wv"]); + auto gate = std::dynamic_pointer_cast(blocks["gate"]); + auto qnorm = std::dynamic_pointer_cast(blocks["qknorm.qnorm"]); + auto knorm = std::dynamic_pointer_cast(blocks["qknorm.knorm"]); + auto wo = std::dynamic_pointer_cast(blocks["wo"]); + + if (sd_backend_is(ctx->backend, "Vulkan")) { + wo->set_force_prec_f32(true); + } + + int64_t L = x->ne[1]; + int64_t N = x->ne[2]; + + auto q = wq->forward(ctx, x); + q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim_, heads, L, N); + auto k = wk->forward(ctx, x); + k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim_, kv_heads, L, N); + auto v = wv->forward(ctx, x); + v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim_, kv_heads, L, N); + + q = qnorm->forward(ctx, q); + k = knorm->forward(ctx, k); + + auto out = pe != nullptr ? Rope::attention(ctx, q, k, v, pe, mask) + : attention_no_rope(ctx, q, k, v, mask); + out = ggml_mul(ctx->ggml_ctx, out, ggml_sigmoid(ctx->ggml_ctx, gate->forward(ctx, x))); + out = wo->forward(ctx, out); + return out; + } + }; + + class KreaDoubleSharedModulation : public GGMLBlock { + protected: + int64_t dim; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + GGML_UNUSED(tensor_storage_map); + GGML_UNUSED(prefix); + params["lin"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim * 6); + } + + public: + KreaDoubleSharedModulation(int64_t dim) + : dim(dim) {} + + std::vector forward(GGMLRunnerContext* ctx, ggml_tensor* vec) { + auto lin = ggml_repeat(ctx->ggml_ctx, params["lin"], vec); + auto out = ggml_add(ctx->ggml_ctx, vec, lin); + return ggml_ext_chunk(ctx->ggml_ctx, out, 6, 0); + } + }; + + class KreaFinalModulation : public GGMLBlock { + protected: + int64_t dim; + + void init_params(ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + GGML_UNUSED(tensor_storage_map); + GGML_UNUSED(prefix); + params["lin"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, 2); + } + + public: + KreaFinalModulation(int64_t dim) + : dim(dim) {} + + std::vector forward(GGMLRunnerContext* ctx, ggml_tensor* vec) { + auto out = ggml_add(ctx->ggml_ctx, params["lin"], vec); + return ggml_ext_chunk(ctx->ggml_ctx, out, 2, 1); + } + }; + + class KreaTextFusionBlock : public UnaryBlock { + public: + KreaTextFusionBlock(int64_t dim, + int64_t heads, + int64_t kv_heads, + int64_t multiplier, + float eps) { + blocks["prenorm"] = std::make_shared(dim, eps); + blocks["postnorm"] = std::make_shared(dim, eps); + blocks["attn"] = std::make_shared(dim, heads, kv_heads, eps); + blocks["mlp"] = std::make_shared(dim, multiplier); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto prenorm = std::dynamic_pointer_cast(blocks["prenorm"]); + auto postnorm = std::dynamic_pointer_cast(blocks["postnorm"]); + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + x = ggml_add(ctx->ggml_ctx, x, attn->forward(ctx, prenorm->forward(ctx, x))); + x = ggml_add(ctx->ggml_ctx, x, mlp->forward(ctx, postnorm->forward(ctx, x))); + return x; + } + }; + + class KreaTextFusionTransformer : public UnaryBlock { + protected: + Krea2Config config; + + public: + explicit KreaTextFusionTransformer(Krea2Config config) + : config(std::move(config)) { + for (int i = 0; i < 2; ++i) { + blocks["layerwise_blocks." + std::to_string(i)] = std::make_shared(this->config.text_dim, + this->config.text_heads, + this->config.text_kv_heads, + this->config.mlp_multiplier, + this->config.norm_eps); + blocks["refiner_blocks." + std::to_string(i)] = std::make_shared(this->config.text_dim, + this->config.text_heads, + this->config.text_kv_heads, + this->config.mlp_multiplier, + this->config.norm_eps); + } + blocks["projector"] = std::make_shared(this->config.text_layers, 1, false); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* context) override { + int64_t text_tokens = context->ne[1]; + int64_t batch = context->ne[2]; + + context = ggml_reshape_3d(ctx->ggml_ctx, + context, + config.text_dim, + config.text_layers, + text_tokens * batch); + + for (int i = 0; i < 2; ++i) { + auto block = std::dynamic_pointer_cast(blocks["layerwise_blocks." + std::to_string(i)]); + context = block->forward(ctx, context); + } + + context = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context, 1, 0, 2, 3)); + auto projector = std::dynamic_pointer_cast(blocks["projector"]); + context = projector->forward(ctx, context); + context = ggml_reshape_3d(ctx->ggml_ctx, context, config.text_dim, text_tokens, batch); + + for (int i = 0; i < 2; ++i) { + auto block = std::dynamic_pointer_cast(blocks["refiner_blocks." + std::to_string(i)]); + context = block->forward(ctx, context); + } + return context; + } + }; + + class KreaSingleStreamBlock : public UnaryBlock { + public: + explicit KreaSingleStreamBlock(Krea2Config config) { + blocks["mod"] = std::make_shared(config.features); + blocks["prenorm"] = std::make_shared(config.features, config.norm_eps); + blocks["postnorm"] = std::make_shared(config.features, config.norm_eps); + blocks["attn"] = std::make_shared(config.features, config.heads, config.kv_heads, config.norm_eps); + blocks["mlp"] = std::make_shared(config.features, config.mlp_multiplier); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* vec, + ggml_tensor* pe) { + auto mod = std::dynamic_pointer_cast(blocks["mod"]); + auto prenorm = std::dynamic_pointer_cast(blocks["prenorm"]); + auto postnorm = std::dynamic_pointer_cast(blocks["postnorm"]); + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + auto mods = mod->forward(ctx, vec); + auto attn_input = Flux::modulate(ctx->ggml_ctx, + prenorm->forward(ctx, x), + mods[1], + mods[0], + true); + auto attn_out = attn->forward(ctx, attn_input, pe); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, mods[2])); + + auto mlp_input = Flux::modulate(ctx->ggml_ctx, + postnorm->forward(ctx, x), + mods[4], + mods[3], + true); + auto mlp_out = mlp->forward(ctx, mlp_input); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, mods[5])); + return x; + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + GGML_UNUSED(ctx); + GGML_UNUSED(x); + GGML_ABORT("KreaSingleStreamBlock requires conditioning"); + return nullptr; + } + }; + + class KreaTimeMLP : public UnaryBlock { + public: + explicit KreaTimeMLP(Krea2Config config) { + blocks["0"] = std::make_shared(config.timestep_dim, config.features, true); + blocks["2"] = std::make_shared(config.features, config.features, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto linear_0 = std::dynamic_pointer_cast(blocks["0"]); + auto linear_2 = std::dynamic_pointer_cast(blocks["2"]); + x = linear_0->forward(ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, false); + x = linear_2->forward(ctx, x); + return x; + } + }; + + class KreaTProj : public UnaryBlock { + public: + explicit KreaTProj(Krea2Config config) { + blocks["1"] = std::make_shared(config.features, config.features * 6, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto linear_1 = std::dynamic_pointer_cast(blocks["1"]); + x = ggml_ext_gelu(ctx->ggml_ctx, x, false); + x = linear_1->forward(ctx, x); + return x; + } + }; + + class KreaTextMLP : public UnaryBlock { + public: + explicit KreaTextMLP(Krea2Config config) { + blocks["0"] = std::make_shared(config.text_dim, config.norm_eps); + blocks["1"] = std::make_shared(config.text_dim, config.features, true); + blocks["3"] = std::make_shared(config.features, config.features, true); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + auto norm = std::dynamic_pointer_cast(blocks["0"]); + auto linear_1 = std::dynamic_pointer_cast(blocks["1"]); + auto linear_3 = std::dynamic_pointer_cast(blocks["3"]); + x = norm->forward(ctx, x); + x = linear_1->forward(ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, true); + x = linear_3->forward(ctx, x); + return x; + } + }; + + class KreaLastLayer : public GGMLBlock { + public: + explicit KreaLastLayer(Krea2Config config) { + blocks["norm"] = std::make_shared(config.features, config.norm_eps); + blocks["linear"] = std::make_shared(config.features, config.patch_size * config.patch_size * config.out_channels, true); + blocks["modulation"] = std::make_shared(config.features); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x, ggml_tensor* vec) { + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + auto mods = modulation->forward(ctx, vec); + x = Flux::modulate(ctx->ggml_ctx, + norm->forward(ctx, x), + mods[1], + mods[0], + true); + x = linear->forward(ctx, x); + return x; + } + }; + + class Krea2Model : public GGMLBlock { + protected: + Krea2Config config; + + public: + Krea2Model() = default; + explicit Krea2Model(Krea2Config config) + : config(std::move(config)) { + blocks["first"] = std::make_shared(this->config.patch_size * this->config.patch_size * this->config.in_channels, + this->config.features, + true); + blocks["tmlp"] = std::make_shared(this->config); + blocks["txtfusion"] = std::make_shared(this->config); + blocks["txtmlp"] = std::make_shared(this->config); + blocks["tproj"] = std::make_shared(this->config); + for (int i = 0; i < this->config.layers; ++i) { + blocks["blocks." + std::to_string(i)] = std::make_shared(this->config); + } + blocks["last"] = std::make_shared(this->config); + } + + ggml_tensor* forward(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* timestep, + ggml_tensor* context, + ggml_tensor* pe) { + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t N = x->ne[3]; + GGML_ASSERT(N == 1); + + auto first = std::dynamic_pointer_cast(blocks["first"]); + auto tmlp = std::dynamic_pointer_cast(blocks["tmlp"]); + auto txtfusion = std::dynamic_pointer_cast(blocks["txtfusion"]); + auto txtmlp = std::dynamic_pointer_cast(blocks["txtmlp"]); + auto tproj = std::dynamic_pointer_cast(blocks["tproj"]); + auto last = std::dynamic_pointer_cast(blocks["last"]); + + auto img = DiT::pad_and_patchify(ctx, x, config.patch_size, config.patch_size, true); + int64_t img_len = img->ne[1]; + img = first->forward(ctx, img); + + auto t = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, static_cast(config.timestep_dim), 10000, 1000.f); + t = tmlp->forward(ctx, t); + t = ggml_reshape_3d(ctx->ggml_ctx, t, t->ne[0], 1, t->ne[1]); + auto tvec = tproj->forward(ctx, t); + + auto txt = txtfusion->forward(ctx, context); + txt = txtmlp->forward(ctx, txt); + int64_t txt_len = txt->ne[1]; + + auto hidden_states = ggml_concat(ctx->ggml_ctx, txt, img, 1); + for (int i = 0; i < config.layers; ++i) { + auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]); + hidden_states = block->forward(ctx, hidden_states, tvec, pe); + sd::ggml_graph_cut::mark_graph_cut(hidden_states, "krea2.blocks." + std::to_string(i), "hidden_states"); + } + + hidden_states = last->forward(ctx, hidden_states, t); + hidden_states = ggml_ext_slice(ctx->ggml_ctx, hidden_states, 1, txt_len, txt_len + img_len); + hidden_states = DiT::unpatchify_and_crop(ctx->ggml_ctx, hidden_states, H, W, config.patch_size, config.patch_size, true); + return hidden_states; + } + }; + + __STATIC_INLINE__ std::vector gen_krea2_pe(int h, + int w, + int patch_size, + int bs, + int context_len, + float theta, + const std::vector& axes_dim) { + auto txt_ids = Rope::gen_flux_txt_ids(bs, context_len, 3, {}); + auto img_ids = Rope::gen_flux_img_ids(h, w, patch_size, bs, 3, 0, 0, 0, false); + auto ids = Rope::concat_ids(txt_ids, img_ids, bs); + return Rope::embed_nd(ids, bs, theta, axes_dim); + } + + struct Krea2Runner : public DiffusionModelRunner { + Krea2Config config; + Krea2Model model; + std::vector pe_vec; + + Krea2Runner(ggml_backend_t backend, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "", + std::shared_ptr weight_manager = nullptr) + : DiffusionModelRunner(backend, prefix, weight_manager), + config(Krea2Config::detect_from_weights(tensor_storage_map, prefix)) { + model = Krea2Model(config); + model.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "krea2"; + } + + void get_param_tensors(std::map& tensors, const std::string& prefix) override { + model.get_param_tensors(tensors, prefix); + } + + ggml_cgraph* build_graph(const sd::Tensor& x_tensor, + const sd::Tensor& timesteps_tensor, + const sd::Tensor& context_tensor) { + ggml_cgraph* gf = new_graph_custom(KREA2_GRAPH_SIZE); + ggml_tensor* x = make_input(x_tensor); + ggml_tensor* timesteps = make_input(timesteps_tensor); + GGML_ASSERT(x->ne[3] == 1); + GGML_ASSERT(!context_tensor.empty()); + ggml_tensor* context = make_input(context_tensor); + + pe_vec = gen_krea2_pe(static_cast(x->ne[1]), + static_cast(x->ne[0]), + config.patch_size, + static_cast(x->ne[3]), + static_cast(context->ne[1]), + config.theta, + config.axes_dim); + int pos_len = static_cast(pe_vec.size() / config.axes_dim_sum / 2); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, config.axes_dim_sum / 2, pos_len); + set_backend_tensor_data(pe, pe_vec.data()); + + auto runner_ctx = get_context(); + ggml_tensor* out = model.forward(&runner_ctx, x, timesteps, context, pe); + ggml_build_forward_expand(gf, out); + return gf; + } + + sd::Tensor compute(int n_threads, + const sd::Tensor& x, + const sd::Tensor& timesteps, + const sd::Tensor& context) { + auto get_graph = [&]() -> ggml_cgraph* { + return build_graph(x, timesteps, context); + }; + return restore_trailing_singleton_dims(GGMLRunner::compute(get_graph, n_threads, false, false, false), x.dim()); + } + + sd::Tensor compute(int n_threads, + const DiffusionParams& diffusion_params) override { + GGML_ASSERT(diffusion_params.x != nullptr); + GGML_ASSERT(diffusion_params.timesteps != nullptr); + return compute(n_threads, + *diffusion_params.x, + *diffusion_params.timesteps, + tensor_or_empty(diffusion_params.context)); + } + }; +} // namespace Krea2 + +#endif // __SD_MODEL_DIFFUSION_KREA2_HPP__ diff --git a/src/model_loader.cpp b/src/model_loader.cpp index 2fd854a84..33c056b35 100644 --- a/src/model_loader.cpp +++ b/src/model_loader.cpp @@ -453,6 +453,10 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("embed_image_indicator.weight") != std::string::npos) { return VERSION_IDEOGRAM4; } + if (tensor_storage.name.find("model.diffusion_model.txtfusion.projector.weight") != std::string::npos || + tensor_storage.name.find("model.diffusion_model.text_fusion.projector.weight") != std::string::npos) { + return VERSION_KREA2; + } if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { return VERSION_CHROMA_RADIANCE; } diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index da2a8d5ed..ccc8347b7 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -704,6 +704,38 @@ std::string convert_other_dit_to_original_anima(std::string name) { return name; } +std::string convert_diffusers_dit_to_original_krea2(std::string name) { + static const std::vector> prefix_map = { + {"img_in.", "first."}, + {"time_embed.linear_1.", "tmlp.0."}, + {"time_embed.linear_2.", "tmlp.2."}, + {"time_mod_proj.", "tproj.1."}, + {"txt_in.linear_1.", "txtmlp.1."}, + {"txt_in.linear_2.", "txtmlp.3."}, + {"text_fusion.", "txtfusion."}, + {"transformer_blocks.", "blocks."}, + {"final_layer.", "last."}, + }; + static const std::vector> name_map = { + {"attn.to_out.0.", "attn.wo."}, + {"attn.to_out.", "attn.wo."}, + {"attn.to_gate.", "attn.gate."}, + {"attn.to_q.", "attn.wq."}, + {"attn.to_k.", "attn.wk."}, + {"attn.to_v.", "attn.wv."}, + {"ff.gate.", "mlp.gate."}, + {"ff.up.", "mlp.up."}, + {"ff.down.", "mlp.down."}, + {"txt_in.norm.", "txtmlp.0."}, + {"last.norm.weight", "last.norm.scale"}, + {"last.modulation.weight", "last.modulation.lin"}, + }; + + replace_with_prefix_map(name, prefix_map); + replace_with_name_map(name, name_map); + return name; +} + std::string convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) { if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) { name = convert_diffusers_unet_to_original_sd1(name); @@ -717,6 +749,8 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_dit_to_original_lumina2(name); } else if (sd_version_is_anima(version)) { name = convert_other_dit_to_original_anima(name); + } else if (sd_version_is_krea2(version)) { + name = convert_diffusers_dit_to_original_krea2(name); } return name; } @@ -1175,7 +1209,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) { replace_with_prefix_map(name, prefix_map); - if (sd_version_is_boogu_image(version) && starts_with(name, "text_encoders.llm.visual.")) { + if ((sd_version_is_boogu_image(version) || sd_version_is_krea2(version)) && starts_with(name, "text_encoders.llm.visual.")) { name = convert_qwen3_vl_vision_name(std::move(name)); } diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index c4fe44210..63b3e90d3 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -26,6 +26,7 @@ #include "model/diffusion/flux.hpp" #include "model/diffusion/hidream_o1.hpp" #include "model/diffusion/ideogram4.hpp" +#include "model/diffusion/krea2.hpp" #include "model/diffusion/lens.hpp" #include "model/diffusion/ltxv.hpp" #include "model/diffusion/mmdit.hpp" @@ -95,6 +96,7 @@ const char* model_version_to_str[] = { "Longcat-Image", "PiD", "Ideogram 4", + "Krea2", "ESRGAN", }; @@ -645,6 +647,17 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", model_manager); + } else if (sd_version_is_krea2(version)) { + cond_stage_model = std::make_shared(backend_for(SDBackendModule::TE), + tensor_storage_map, + version, + "", + false, + model_manager); + diffusion_model = std::make_shared(backend_for(SDBackendModule::DIFFUSION), + tensor_storage_map, + "model.diffusion_model", + model_manager); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -881,6 +894,7 @@ class StableDiffusionGGML { auto create_tae = [&](bool decode_only) -> std::shared_ptr { if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || + sd_version_is_krea2(version) || sd_version_is_anima(version) || sd_version_is_ltxav(version)) { return std::make_shared(backend_for(SDBackendModule::VAE), @@ -921,6 +935,7 @@ class StableDiffusionGGML { model_manager); } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || + sd_version_is_krea2(version) || sd_version_is_anima(version)) { return std::make_shared(backend_for(SDBackendModule::VAE), tensor_storage_map, @@ -1267,7 +1282,8 @@ class StableDiffusionGGML { } else if (sd_version_is_flux(version) || sd_version_is_longcat(version) || sd_version_is_lens(version) || - sd_version_is_ltxav(version)) { + sd_version_is_ltxav(version) || + sd_version_is_krea2(version)) { pred_type = FLUX_FLOW_PRED; default_flow_shift = 1.0f; // TODO: validate @@ -1283,6 +1299,8 @@ class StableDiffusionGGML { default_flow_shift = 1.83f; } else if (sd_version_is_ltxav(version)) { default_flow_shift = 2.37f; + } else if (sd_version_is_krea2(version)) { + default_flow_shift = 1.15f; } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; @@ -1724,7 +1742,7 @@ class StableDiffusionGGML { } else if (sd_version_uses_flux_vae(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; - } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { + } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version) || sd_version_is_krea2(version)) { latent_rgb_proj = wan_21_latent_rgb_proj; latent_rgb_bias = wan_21_latent_rgb_bias; } else {