diff --git a/.gitignore b/.gitignore
index f8ceb1560a1..f48ce4cacd1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,6 +82,7 @@ models/*
 models-mnt
 !models/.editorconfig
 !models/ggml-vocab-*.gguf*
+!models/templates

 # Zig
 zig-out/
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 2c51b6c42d4..3f5cefe007c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1900,6 +1900,7 @@ def prepare_tensors(self):
     "MixtralForCausalLM",
     "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration",
+    "VoxtralForConditionalGeneration",
     "LlamaModel")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
@@ -1912,6 +1913,11 @@ def __init__(self, *args, **kwargs):
         self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)

     def set_vocab(self):
+        path_tekken_json = self.dir_model / "tekken.json"
+        path_tokenizer_json = self.dir_model / "tokenizer.json"
+        if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
+            return self.set_vocab_tekken()
+
         try:
             self._set_vocab_sentencepiece()
         except FileNotFoundError:
@@ -1944,6 +1950,52 @@ def set_vocab(self):
         if self.hparams.get("vocab_size", 32000) == 49152:
             self.gguf_writer.add_add_bos_token(False)

+    def set_vocab_tekken(self):
+        vocab = gguf.vocab.MistralVocab(self.dir_model)
+        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
+
+        tokens = []
+        scores = []
+        toktypes = []
+
+        for text, score, toktype in vocab.all_tokens():
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        assert len(tokens) == vocab.vocab_size, (
+            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
+        )
+
+        if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
+            self.gguf_writer.add_tokenizer_pre("tekken")
+            self.gguf_writer.add_token_merges(
+                vocab.extract_vocab_merges_from_model()
+            )
+
+        logger.info(
+            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
+        )
+
+        self.gguf_writer.add_bos_token_id(vocab.bos_id)
+        self.gguf_writer.add_eos_token_id(vocab.eos_id)
+        self.gguf_writer.add_unk_token_id(vocab.unk_id)
+        self.gguf_writer.add_pad_token_id(vocab.pad_id)
+
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_vocab_size(vocab.vocab_size)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(False)
+
+        script_dir = Path(__file__).parent
+        template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
+        with open(template_path, "r", encoding="utf-8") as f:
+            template = f.read()
+        self.gguf_writer.add_chat_template(template)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -1971,12 +2023,13 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
-        is_vision_tensor = "vision_tower" in name \
+        is_multimodal_tensor = "vision_tower" in name \
             or "vision_model" in name \
+            or "audio_tower" in name \
             or "model.connector" in name \
             or "multi_modal_projector" in name

-        if is_vision_tensor:
+        if is_multimodal_tensor:
             return [] # skip vision tensors
         elif self.hf_arch == "LlamaModel":
             name = "model." + name
@@ -7231,9 +7284,10 @@ class WhisperEncoderModel(MmprojModel):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.hparams["hidden_size"] = self.hparams["d_model"]
-        self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
-        self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
+        if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
+            self.hparams["hidden_size"] = self.hparams["d_model"]
+            self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
+            self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -7272,9 +7326,21 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


+@ModelBase.register("VoxtralForConditionalGeneration")
+class VoxtralWhisperEncoderModel(WhisperEncoderModel):
+    has_vision_encoder = False # no vision encoder
+    has_audio_encoder = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VOXTRAL)
+        self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size
+
+
 @ModelBase.register("FalconH1ForCausalLM")
 class FalconH1Model(Mamba2Model):
     model_arch = gguf.MODEL_ARCH.FALCON_H1
@@ -7589,6 +7655,88 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]


+@ModelBase.register("SmallThinkerForCausalLM")
+class SmallThinkerModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SMALLTHINKER
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+        if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        if (self.hparams.get('moe_primary_router_apply_softmax')):
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+        sliding_window_layout = self.hparams.get("sliding_window_layout")
+        if sliding_window_layout:
+            for i in sliding_window_layout:
+                if i != 0:
+                    sliding_window = self.hparams.get("sliding_window_size")
+                    if sliding_window:
+                        self.gguf_writer.add_sliding_window(sliding_window)
+                    break
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down", "gate", "up"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 ###### CONVERSION LOGIC ######
diff --git a/docs/multimodal.md b/docs/multimodal.md
index edbd081df79..e2e12d07df1 100644
--- a/docs/multimodal.md
+++ b/docs/multimodal.md
@@ -97,6 +97,9 @@ NOTE: some models may require large context window, for example: `-c 8192`
 # Qwen2-Audio and SeaLLM-Audio
 # note: no pre-quantized GGUF this model, as they have very poor result
 # ref: https://github.com/ggml-org/llama.cpp/pull/13760
+
+# Mistral's Voxtral
+(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
 ```

 **Mixed modalities**:
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index e9b5c306365..109253838f2 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:

         if (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
             }
         }

-        K     += gridDim.y*D * nb11;
-        V     += gridDim.y*D * nb21;
-        maskh += gridDim.y*D;
-
         __syncthreads();
     }

diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 6a4bdc0ff9a..2cf2e408e92 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:

         if (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
             }
         }

-        K     += gridDim.y*D * nb11;
-        V     += gridDim.y*D * nb21;
-        maskh += gridDim.y*D;
-
         __syncthreads();
     }

diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index f839a42bc90..410a67b0195 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -28,6 +28,7 @@
 #include "mmvq.hpp"
 #include "norm.hpp"
 #include "outprod.hpp"
+#include "quantize.hpp"
 #include "quants.hpp"
 #include "rope.hpp"
 #include "set_rows.hpp"
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index a023d6fb452..b08941c328b 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -44,6 +44,7 @@
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/quantize.hpp"
 #include "ggml.h"

 static bool g_sycl_loaded = false;
@@ -1373,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(


-template <int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
-
-template <int ElementsPerWI>
-static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
-                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
-    /*
-        Quantizes and reorders the resultant q8 tensor in a per row fashion
-        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
-    */
-
-    auto subgroup_id = it.get_group(0);
-    auto wi_id = it.get_local_id(0);
-
-    const int num_blocks_per_row = kx / QK8_1;
-    auto row = subgroup_id / num_blocks_per_row;
-    auto col = subgroup_id % num_blocks_per_row;
-
-    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
-    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
-
-    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
-    auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
-
-    sycl::vec<float, ElementsPerWI> wi_f32_vals;
-    sycl::vec<int8_t, ElementsPerWI> quantized_values;
-
-    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
-    wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
-
-    float sum = 0.0f;
-    float amax = 0.0f;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        sum += wi_f32_vals[i];
-        amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
-        quantized_values[i] = 0;
-    }
-    sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
-    amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
-    float d = amax == 0 ? 1 : amax / 127;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
-    }
-
-    d = amax == 0 ? 0 : d;
-
-    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
-    if (wi_id == 0) {
-        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
-    }
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1770,32 +1657,6 @@ static void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }

-static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
-                                   bool reorder_q8_tensor, queue_ptr stream) {
-    if (reorder_q8_tensor) {
-        auto local_range = std::size_t(WARP_SIZE);
-        auto num_quant_blocks = ky * (kx / QK8_1);
-        auto global_range = num_quant_blocks * local_range;
-        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
-                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
-                             });
-    } else {
-        const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-        const sycl::range<3> num_blocks(1, ky, block_num_x);
-        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-        static_assert(QK8_1 % WARP_SIZE == 0);
-        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-        {
-            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
-            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
-                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-                                 });
-        }
-    }
-}

 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                            float *dst, const int ncols_x,
@@ -2372,10 +2233,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {

     peer_access_enabled = enable_peer_access;
 }

+template