Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated 202 files
7 changes: 2 additions & 5 deletions src/model/diffusion/anima.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ namespace Anima {
k4 = k_norm->forward(ctx, k4);

ggml_tensor* attn_out = nullptr;
float scale = (sd_backend_is(ctx->backend, "Vulkan") && ctx->flash_attn_enabled) ? 1.0f / 32.0f : 1.0f;
if (pe_q != nullptr || pe_k != nullptr) {
if (pe_q == nullptr) {
pe_q = pe_k;
Expand All @@ -245,8 +244,7 @@ namespace Anima {
num_heads,
nullptr,
true,
ctx->flash_attn_enabled,
scale);
ctx->flash_attn_enabled);
} else {
auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N);
auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N);
Expand All @@ -258,8 +256,7 @@ namespace Anima {
num_heads,
nullptr,
false,
ctx->flash_attn_enabled,
scale);
ctx->flash_attn_enabled);
}

return out_proj->forward(ctx, attn_out);
Expand Down
4 changes: 1 addition & 3 deletions src/model/diffusion/ernie_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ namespace ErnieImage {
int64_t S = x->ne[1];
int64_t N = x->ne[2];

float scale = (sd_backend_is(ctx->backend, "Vulkan") && ctx->flash_attn_enabled) ? 1.0f / 32.0f : 1.0f;

auto q = to_q->forward(ctx, x);
auto k = to_k->forward(ctx, x);
auto v = to_v->forward(ctx, x);
Expand All @@ -184,7 +182,7 @@ namespace ErnieImage {
k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, heads, S, head_dim]
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]);

x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled, scale); // [N, S, hidden_size]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, ctx->flash_attn_enabled); // [N, S, hidden_size]
x = to_out_0->forward(ctx, x);
return x;
}
Expand Down
Loading