Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,26 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
}
};

/** RMS_NORM + MUL **/

struct ggml_webgpu_rms_norm_mul_pipeline_key {
bool inplace;
bool src_overlap;

bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
return inplace == other.inplace && src_overlap == other.src_overlap;
}
};

struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.src_overlap);
return seed;
}
};

/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
Expand Down Expand Up @@ -517,7 +537,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
size_t bytes_per_kv = 0;
if (!key.kv_direct) {
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
}
Expand Down Expand Up @@ -755,16 +775,17 @@ class ggml_webgpu_shader_lib {
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
row_norm_pipelines; // op/inplace

std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
unary_pipelines; // type/op/inplace
unary_pipelines; // type/op/inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
scale_pipelines; // inplace
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
solve_tri_pipelines; // type
solve_tri_pipelines; // type
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
ssm_conv_pipelines; // type/vectorized
ssm_conv_pipelines; // type/vectorized
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
webgpu_pipeline,
ggml_webgpu_gated_delta_net_pipeline_key_hash>
Expand Down Expand Up @@ -813,6 +834,11 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
conv2d_pipelines;

std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
webgpu_pipeline,
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
rms_norm_mul_pipelines;

public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }

Expand Down Expand Up @@ -1828,6 +1854,39 @@ class ggml_webgpu_shader_lib {
return unary_pipelines[key];
}

webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
key.inplace = context.inplace;
key.src_overlap = context.src_overlap;

auto it = rms_norm_mul_pipelines.find(key);
if (it != rms_norm_mul_pipelines.end()) {
return it->second;
}

std::vector<std::string> defines;
std::string op_name = "RMS_NORM_MUL";
std::string variant = op_name;

if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
} else if (key.src_overlap) {
defines.push_back("SRC_OVERLAP");
variant += "_src_overlap";
}

defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
rms_norm_mul_pipelines[key] = pipeline;
return rms_norm_mul_pipelines[key];
}

webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {};
key.type = context.dst->type;
Expand Down
Loading
Loading