diff --git a/CHANGELOG.md b/CHANGELOG.md index 229921847..b83ccc83b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.28] + +- feat(example): align server MTP support with llama.cpp by @abetlen in #2283 +- feat: update llama.cpp to ggml-org/llama.cpp@9e3b928fd +- feat(example): add OpenAI-compatible embeddings endpoint by @abetlen in #2281 + ## [0.3.27] - feat: update llama.cpp to ggml-org/llama.cpp@465b1f0e7 diff --git a/examples/server/README.md b/examples/server/README.md index 1f6e0d3db..ff04374fc 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1,7 +1,7 @@ # Server Example This example is an updated OpenAI-compatible web server that depends only on the low-level C bindings. -It supports batched inference, prompt caching, response parsing, `/v1/responses`, disk sequence caching, MTP, LoRA, and multimodal image/audio inputs. +It supports batched inference, prompt caching, response parsing, `/v1/responses`, `/v1/embeddings`, disk sequence caching, MTP, LoRA, and multimodal image/audio inputs. ## Setup @@ -46,6 +46,7 @@ The smallest checked-in example uses Qwen3.5 0.8B so the server can be started o | Config | Model | Notes | | --- | --- | --- | +| [`configs/bge-small-en-v1.5.json`](configs/bge-small-en-v1.5.json) | [`CompendiumLabs/bge-small-en-v1.5-gguf`](https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf) | Small embedding model config for `/v1/embeddings`. | | [`configs/qwen3.5-0.8b.json`](configs/qwen3.5-0.8b.json) | [`lmstudio-community/Qwen3.5-0.8B-GGUF`](https://huggingface.co/lmstudio-community/Qwen3.5-0.8B-GGUF) | Default small multimodal example. | | [`configs/gemma-4-12b-it-qat.json`](configs/gemma-4-12b-it-qat.json) | [`unsloth/gemma-4-12B-it-qat-GGUF`](https://huggingface.co/unsloth/gemma-4-12B-it-qat-GGUF) | Larger Gemma 4 QAT multimodal config with projector. | | [`configs/qwen3.6-27b.json`](configs/qwen3.6-27b.json) | [`unsloth/Qwen3.6-27B-GGUF`](https://huggingface.co/unsloth/Qwen3.6-27B-GGUF) | Larger Qwen3.6 multimodal config. | @@ -86,11 +87,33 @@ response = client.responses.create( print(response.output_text) ``` +### Embeddings + +Start the server with an embedding config before calling `/v1/embeddings`. + +```bash +cd examples/server +uv run --script server.py -C configs/bge-small-en-v1.5.json +``` + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-used") + +response = client.embeddings.create( + model="bge-small-en-v1.5", + input=["The food was delicious.", "The meal was excellent."], +) +print(len(response.data[0].embedding)) +``` + ## API Surface | Endpoint | Purpose | Reference | | --- | --- | --- | | `POST /v1/completions` | Legacy text completions with streaming, stop sequences, logprobs, penalties, seeds, and grammar-backed JSON output. | [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions) | +| `POST /v1/embeddings` | OpenAI-compatible embeddings for embedding-mode GGUF models, including string inputs, token inputs, base64 output, and dimensions truncation. | [OpenAI Embeddings API](https://platform.openai.com/docs/api-reference/embeddings) | | `POST /v1/chat/completions` | Chat completions with streaming, tools, forced tool choice, reasoning parsing, multimodal content parts, and structured response parsing. | [OpenAI Chat API](https://platform.openai.com/docs/api-reference/chat) | | `POST /v1/responses` | Stateless Responses API compatibility for clients that use response items and response events. | [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) | | `WS /v1/responses` | Stateful websocket Responses transport with per-connection `previous_response_id` replay. | [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) | @@ -190,6 +213,8 @@ Most model runtime fields map to `llama_model_params` or `llama_context_params` | `threads` | Decode thread count. | | `threads_batch` | Prefill and batch thread count. | | `kv_unified` | Selects unified or per-sequence memory layout. | +| `embedding` | Overrides embedding mode; omit to auto-detect pooled embedding GGUFs from model metadata. | +| `pooling_type` | Overrides pooled embedding behavior for embedding models, such as `1` for mean pooling. | | `store_logits` | Keeps logits after decode when needed by sampling or diagnostics. | | `use_mmap` | Memory maps model weights. | | `use_mlock` | Attempts to lock model pages into RAM. | @@ -409,6 +434,22 @@ Use MTP when the loaded model and llama.cpp build expose the required draft stat } ``` +By default `draft-mtp` creates the MTP context from the target model. +Set `draft_model_path` or `draft_model_from_pretrained` when the model uses a separate assistant GGUF. + +```json +{ + "model": { + "draft_model": "draft-mtp", + "draft_model_num_pred_tokens": 2, + "draft_model_from_pretrained": { + "repo_id": "example/gemma-assistant-GGUF", + "filename": "assistant.gguf" + } + } +} +``` + MTP currently applies to text-only requests. ## Disk Sequence Cache diff --git a/examples/server/configs/bge-small-en-v1.5.json b/examples/server/configs/bge-small-en-v1.5.json new file mode 100644 index 000000000..3fc8016df --- /dev/null +++ b/examples/server/configs/bge-small-en-v1.5.json @@ -0,0 +1,22 @@ +{ + "server": { + "host": "0.0.0.0", + "port": 8000 + }, + "model": { + "alias": "bge-small-en-v1.5", + "from_pretrained": { + "repo_id": "CompendiumLabs/bge-small-en-v1.5-gguf", + "filename": "bge-small-en-v1.5-q4_k_m.gguf" + }, + "n_ctx": 512, + "n_seq_max": 16, + "n_batch": 512, + "n_ubatch": 512, + "threads": 4, + "threads_batch": 8, + "kv_unified": true, + "store_logits": false, + "use_mmap": true + } +} diff --git a/examples/server/server.py b/examples/server/server.py index 590a21420..64a16f0bd 100644 --- a/examples/server/server.py +++ b/examples/server/server.py @@ -36,7 +36,6 @@ import copy import shutil import inspect -import importlib.util import sys import urllib.error import urllib.parse @@ -1262,27 +1261,33 @@ def __init__( self, *, model: "Model", + draft_model: Any, context_params: Any, num_pred_tokens: int, top_k: int, p_min: float, ) -> None: self.target_ctx = model.ctx - self.model = model.llama_model + self.model = draft_model self.n_seq_max = model.n_seq_max self.n_vocab = model.n_vocab - self.n_embd = int(llama_cpp.llama_model_n_embd(self.model)) + self.n_embd = int(llama_cpp.llama_model_n_embd_out(self.model)) + if self.n_embd <= 0: + self.n_embd = int(llama_cpp.llama_model_n_embd(self.model)) + if self.n_embd != model.n_embd: + raise RuntimeError( + "MTP draft model output embedding size must match target model " + f"embedding size ({self.n_embd} != {model.n_embd})" + ) self.num_pred_tokens = max(0, int(num_pred_tokens)) self.top_k = max(1, int(top_k)) self.p_min = max(0.0, min(1.0, float(p_min))) - ( - self.target_hidden_norm_weight, - self.draft_hidden_norm_weight, - self.hidden_norm_epsilon, - ) = self._load_hidden_norm_weights(model.model_path) self.ctx = llama_cpp.llama_init_from_model(self.model, context_params) if self.ctx is None: raise RuntimeError("failed to create MTP draft context") + ctx_other = llama_cpp_ext.llama_get_ctx_other(self.ctx) + self.is_mem_shared = bool(ctx_other and ctx_other == self.target_ctx) + self.sampled_batch_draft = not self.is_mem_shared self.n_batch = int(llama_cpp.llama_n_batch(self.ctx)) mem = llama_cpp.llama_get_memory(self.ctx) if mem is None: @@ -1313,135 +1318,13 @@ def __init__( self.decode_failures_total = 0 self.target_processing_enabled = False self.set_target_processing_enabled(True) - llama_cpp_ext.llama_set_embeddings_pre_norm( + llama_cpp_ext.llama_set_embeddings_nextn( self.ctx, True, True, ) self._init_samplers() - @staticmethod - def _load_gguf_reader() -> Any: - try: - from gguf import GGUFReader # type: ignore[import-not-found] - return GGUFReader - except ImportError: - pass - - gguf_init = ( - Path(__file__).resolve().parents[2] - / "vendor/llama.cpp/gguf-py/gguf/__init__.py" - ) - spec = importlib.util.spec_from_file_location( - "gguf", - gguf_init, - submodule_search_locations=[str(gguf_init.parent)], - ) - if spec is None or spec.loader is None: - raise RuntimeError("MTP requires the gguf Python reader from llama.cpp") - module = importlib.util.module_from_spec(spec) - previous = sys.modules.get("gguf") - # Package-relative imports in gguf-py expect the package to be registered. - sys.modules["gguf"] = module - try: - spec.loader.exec_module(module) - except Exception: - if previous is None: - sys.modules.pop("gguf", None) - else: - sys.modules["gguf"] = previous - raise - return module.GGUFReader - - @staticmethod - def _gguf_field_contents(reader: Any, name: str) -> Any: - field = reader.fields.get(name) - if field is None: - return None - return field.contents() - - def _load_hidden_norm_weights( - self, - model_path: str, - ) -> Tuple[np.ndarray, np.ndarray, float]: - GGUFReader = self._load_gguf_reader() - - target_weight: Optional[np.ndarray] = None - draft_weights: List[np.ndarray] = [] - reader = GGUFReader(model_path) - arch = self._gguf_field_contents(reader, "general.architecture") - if arch not in {"qwen35", "qwen35moe"}: - raise RuntimeError( - "draft-mtp currently supports qwen35/qwen35moe GGUF models, " - f"got {arch!r}" - ) - nextn_layers = self._gguf_field_contents( - reader, - f"{arch}.nextn_predict_layers", - ) - # The current MTP path follows llama.cpp's Qwen3.5 one-nextn-layer graph. - if int(nextn_layers or 0) != 1: - raise RuntimeError( - "draft-mtp currently supports exactly one Qwen3.5 nextn prediction layer" - ) - epsilon = self._gguf_field_contents( - reader, - f"{arch}.attention.layer_norm_rms_epsilon", - ) - if epsilon is None: - raise RuntimeError( - f"MTP requires {arch}.attention.layer_norm_rms_epsilon" - ) - for tensor in reader.tensors: - if tensor.name == "output_norm.weight": - target_weight = np.asarray(tensor.data, dtype=np.float32).copy() - elif tensor.name.endswith(".nextn.shared_head_norm.weight"): - draft_weights.append(np.asarray(tensor.data, dtype=np.float32).copy()) - if target_weight is None: - raise RuntimeError("MTP requires output_norm.weight in GGUF model") - if len(draft_weights) > 1: - raise RuntimeError( - "MTP requires at most one blk.*.nextn.shared_head_norm.weight in GGUF model" - ) - draft_weight = draft_weights[0] if draft_weights else target_weight - if target_weight.shape != (self.n_embd,): - raise RuntimeError( - "MTP target norm weight shape does not match model embedding size " - f"({target_weight.shape} != ({self.n_embd},))" - ) - if draft_weight.shape != (self.n_embd,): - raise RuntimeError( - "MTP draft norm weight shape does not match model embedding size " - f"({draft_weight.shape} != ({self.n_embd},))" - ) - return target_weight, draft_weight, float(epsilon) - - def _normalize_hidden_rows(self, rows: np.ndarray, weight: np.ndarray) -> np.ndarray: - rows = np.asarray(rows, dtype=np.float32) - scale = np.reciprocal( - np.sqrt( - np.mean(np.square(rows), axis=-1, keepdims=True) - + self.hidden_norm_epsilon - ) - ) - return rows * scale * weight.reshape(1, -1) - - def _normalize_target_hidden_rows(self, rows: np.ndarray) -> np.ndarray: - return self._normalize_hidden_rows(rows, self.target_hidden_norm_weight) - - def _normalize_draft_hidden_row( - self, - row: Union[np.ndarray, ctypes.POINTER(ctypes.c_float)], - ) -> np.ndarray: - if isinstance(row, np.ndarray): - row_array = row.reshape(1, -1) - else: - row_array = np.ctypeslib.as_array(row, shape=(self.n_embd,)).reshape(1, -1) - return self._normalize_hidden_rows( - row_array, - self.draft_hidden_norm_weight, - )[0] - def _init_samplers(self) -> None: for seq_id in range(self.n_seq_max): params = llama_cpp.llama_sampler_chain_default_params() @@ -1507,7 +1390,7 @@ def close(self) -> None: def set_target_processing_enabled(self, enabled: bool) -> None: if self.target_processing_enabled == enabled: return - llama_cpp_ext.llama_set_embeddings_pre_norm( + llama_cpp_ext.llama_set_embeddings_nextn( self.target_ctx, enabled, False, @@ -1641,14 +1524,13 @@ def process(self, batch: Any, /) -> None: ): return - h_tgt = llama_cpp_ext.llama_get_embeddings_pre_norm(self.target_ctx) + h_tgt = llama_cpp_ext.llama_get_embeddings_nextn(self.target_ctx) if not h_tgt: - raise RuntimeError("missing target pre-norm embeddings for MTP") + raise RuntimeError("missing target nextn embeddings for MTP") h_tgt_rows = np.ctypeslib.as_array( h_tgt, shape=(n_tokens, self.n_embd), ) - h_tgt_rows = self._normalize_target_hidden_rows(h_tgt_rows) previous_row_by_seq: Dict[int, int] = {} first_pos_by_seq: Dict[int, int] = {} @@ -1708,17 +1590,17 @@ def _process_rows( self.ready[seq_id] and self.ready_pos[seq_id] == first_pos ) aligned_by_seq.setdefault(seq_id, aligned) - mtp_pos = ( - pos - 1 - if previous_row_by_seq.get(seq_id) is None - else int(batch.pos[previous_row_by_seq[seq_id]]) - ) - if aligned and mtp_pos >= 0 and mtp_pos >= self.context_pos[seq_id]: - previous_row = previous_row_by_seq.get(seq_id) + previous_row = previous_row_by_seq.get(seq_id) + if ( + aligned + and not self.is_mem_shared + and pos >= 0 + and pos >= self.context_pos[seq_id] + ): slot = int(self.batch.n_tokens) self._add_batch_token( token=int(batch.token[index]), - pos=mtp_pos, + pos=pos, seq_id=seq_id, logits=False, ) @@ -1726,7 +1608,7 @@ def _process_rows( self._set_batch_embedding_row(slot, self.pending_h[seq_id]) else: self._set_batch_embedding_row(slot, h_tgt_rows[previous_row]) - added_pos_by_seq[seq_id] = mtp_pos + added_pos_by_seq[seq_id] = pos previous_row_by_seq[seq_id] = index target_rows_by_seq.setdefault(seq_id, []).append(index) @@ -1760,15 +1642,15 @@ def draft( n_past = int(input_ids.size) - 1 if self.ready_pos[seq_id] != n_past: return np.array([], dtype=np.intc) - first_pos = n_past - 1 + first_pos = n_past if first_pos < 0: return np.array([], dtype=np.intc) token = int(input_ids[-1]) drafted: List[int] = [] - if self.context_pos[seq_id] > first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] > first_pos: self.truncate(seq_id, first_pos) - if self.context_pos[seq_id] < first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] < first_pos: self.ready[seq_id] = False return np.array([], dtype=np.intc) @@ -1782,9 +1664,11 @@ def draft( ) self._set_batch_embedding_row(0, self.pending_h[seq_id]) if not self._try_decode_batch(): - self.truncate(seq_id, first_pos) + if not self.is_mem_shared: + self.truncate(seq_id, first_pos) return np.array([], dtype=np.intc) - self.context_pos[seq_id] = n_past + if not self.is_mem_shared: + self.context_pos[seq_id] = first_pos + 1 while len(drafted) < n_predict: sampled_token = self._sample_token(seq_id=seq_id) @@ -1794,26 +1678,28 @@ def draft( drafted.append(token) if len(drafted) >= n_predict: break - h_row = llama_cpp_ext.llama_get_embeddings_pre_norm_ith(self.ctx, 0) + h_row = llama_cpp_ext.llama_get_embeddings_nextn_ith(self.ctx, 0) if not h_row: break - h_row = self._normalize_draft_hidden_row(h_row) self._clear_batch() self._add_batch_token( token=token, - pos=first_pos + len(drafted), + pos=first_pos if self.is_mem_shared else first_pos + len(drafted), seq_id=seq_id, logits=True, ) self._set_batch_embedding_row(0, h_row) if not self._try_decode_batch(): break - self.context_pos[seq_id] = first_pos + len(drafted) + 1 + if not self.is_mem_shared: + self.context_pos[seq_id] = first_pos + len(drafted) + 1 if not drafted: - self.truncate(seq_id, n_past) + if not self.is_mem_shared: + self.truncate(seq_id, n_past) return np.array([], dtype=np.intc) - self.truncate(seq_id, n_past) + if not self.is_mem_shared: + self.truncate(seq_id, n_past) return np.asarray(drafted, dtype=np.intc) def draft_many( @@ -1843,12 +1729,12 @@ def draft_many( n_past = int(input_ids.size) - 1 if self.ready_pos[seq_id] != n_past: continue - first_pos = n_past - 1 + first_pos = n_past if first_pos < 0: continue - if self.context_pos[seq_id] > first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] > first_pos: self.truncate(seq_id, first_pos) - if self.context_pos[seq_id] < first_pos: + if not self.is_mem_shared and self.context_pos[seq_id] < first_pos: self.ready[seq_id] = False continue self._reset_sampler(seq_id) @@ -1904,10 +1790,11 @@ def draft_many( ): if sampled_token is None: continue - self.context_pos[representative.seq_id] = max( - self.context_pos[representative.seq_id], - representative.keep_len, - ) + if not self.is_mem_shared: + self.context_pos[representative.seq_id] = max( + self.context_pos[representative.seq_id], + representative.first_pos + 1, + ) for state in grouped[representative.cache_key]: state.drafted.append(sampled_token) active = [] @@ -1917,7 +1804,11 @@ def draft_many( for row, state in enumerate(active): self._add_batch_token( token=state.token, - pos=state.first_pos + len(state.drafted), + pos=( + state.first_pos + if self.is_mem_shared + else state.first_pos + len(state.drafted) + ), seq_id=state.seq_id, logits=True, ) @@ -1932,14 +1823,19 @@ def draft_many( for row, state in enumerate(active) ] for row, (state, sampled_token) in enumerate(zip(active, sampled_tokens)): - decoded_pos = state.first_pos + len(state.drafted) - self.context_pos[state.seq_id] = max( - self.context_pos[state.seq_id], - decoded_pos + 1, + decoded_pos = ( + state.first_pos + if self.is_mem_shared + else state.first_pos + len(state.drafted) ) + if not self.is_mem_shared: + self.context_pos[state.seq_id] = max( + self.context_pos[state.seq_id], + decoded_pos + 1, + ) if sampled_token is None: continue - h_row_ptr = llama_cpp_ext.llama_get_embeddings_pre_norm_ith( + h_row_ptr = llama_cpp_ext.llama_get_embeddings_nextn_ith( self.ctx, row ) state.drafted.append(sampled_token) @@ -1947,14 +1843,17 @@ def draft_many( continue if not h_row_ptr: continue - h_row = self._normalize_draft_hidden_row(h_row_ptr) state.token = sampled_token - state.embedding = h_row + state.embedding = np.ctypeslib.as_array( + h_row_ptr, + shape=(self.n_embd,), + ).copy() next_active.append(state) active = next_active finally: - for state in touched: - self.truncate(state.seq_id, state.keep_len) + if not self.is_mem_shared: + for state in touched: + self.truncate(state.seq_id, state.keep_len) for state in touched: if state.drafted: @@ -2036,7 +1935,7 @@ def _build_sampled_batch_plan( pending_token = update.pending_token if ( pending_token is None - or start_pos <= 0 + or start_pos < 0 or target_count <= 0 or target_count > len(tokens) or target_count > len(row_indices) @@ -2046,9 +1945,7 @@ def _build_sampled_batch_plan( continue for target_index in range(sample_index + 1): - mtp_pos = start_pos + target_index - 1 - if mtp_pos < 0: - continue + mtp_pos = start_pos + target_index source_row = ( None if target_index == 0 @@ -2063,7 +1960,7 @@ def _build_sampled_batch_plan( ) ) - actual_pos = start_pos + sample_index + actual_pos = start_pos + sample_index + 1 pending_rows.append( self.SampledPendingRow( update_index=update_index, @@ -2090,6 +1987,8 @@ def _decode_sampled_context_rows( self._clear_batch() decoded_context_rows: List[Tuple[int, int]] = [] for row in context_rows: + if self.is_mem_shared: + continue if row.draft_pos < self.context_pos[row.seq_id]: continue if row.source_row is None: @@ -2124,7 +2023,10 @@ def _decode_sampled_pending_rows( self._clear_batch() for pending_index, row in enumerate(pending_rows): - if row.draft_pos < self.context_pos[row.seq_id]: + if ( + not self.is_mem_shared + and row.draft_pos < self.context_pos[row.seq_id] + ): continue is_sample_pending = ( pending_index @@ -2144,17 +2046,18 @@ def _decode_sampled_pending_rows( update_index=row.update_index, seq_id=row.seq_id, output_index=slot, - keep_len=row.draft_pos + 1, - ready_pos=row.draft_pos + 1, + keep_len=row.draft_pos, + ready_pos=row.draft_pos, ) ) self._decode_batch() - for row in pending_rows: - self.context_pos[row.seq_id] = max( - self.context_pos[row.seq_id], - row.draft_pos + 1, - ) + if not self.is_mem_shared: + for row in pending_rows: + self.context_pos[row.seq_id] = max( + self.context_pos[row.seq_id], + row.draft_pos + 1, + ) return sampled_outputs @@ -2184,7 +2087,7 @@ def _start_sampled_draft_states( if n_predict <= 0: continue if n_predict > 1: - h_row = llama_cpp_ext.llama_get_embeddings_pre_norm_ith( + h_row = llama_cpp_ext.llama_get_embeddings_nextn_ith( self.ctx, output.output_index ) if h_row: @@ -2193,11 +2096,18 @@ def _start_sampled_draft_states( update_index=output.update_index, seq_id=seq_id, keep_len=output.keep_len, - pos=output.keep_len, + pos=( + output.keep_len + if self.is_mem_shared + else output.keep_len + 1 + ), token=sampled_token, drafted=[sampled_token], n_predict=n_predict, - embedding=self._normalize_draft_hidden_row(h_row), + embedding=np.ctypeslib.as_array( + h_row, + shape=(self.n_embd,), + ).copy(), ) ) results[output.update_index] = np.asarray([sampled_token], dtype=np.intc) @@ -2235,31 +2145,36 @@ def _extend_sampled_draft_states( for batch_row, (state, sampled_token) in enumerate( zip(active, sampled_tokens) ): - self.context_pos[state.seq_id] = max( - self.context_pos[state.seq_id], - state.pos + 1, - ) + if not self.is_mem_shared: + self.context_pos[state.seq_id] = max( + self.context_pos[state.seq_id], + state.pos + 1, + ) if sampled_token is None: continue state.drafted.append(sampled_token) if len(state.drafted) >= state.n_predict: continue - h_row = llama_cpp_ext.llama_get_embeddings_pre_norm_ith( + h_row = llama_cpp_ext.llama_get_embeddings_nextn_ith( self.ctx, batch_row ) if not h_row: continue - h_row = self._normalize_draft_hidden_row(h_row) state.token = sampled_token - state.embedding = h_row - state.pos += 1 + state.embedding = np.ctypeslib.as_array( + h_row, + shape=(self.n_embd,), + ).copy() + if not self.is_mem_shared: + state.pos += 1 next_active.append(state) active = next_active finally: for state in touched: cleanup_keep_len_by_seq[state.seq_id] = state.keep_len - for seq_id, keep_len in cleanup_keep_len_by_seq.items(): - self._truncate_memory(seq_id, keep_len) + if not self.is_mem_shared: + for seq_id, keep_len in cleanup_keep_len_by_seq.items(): + self._truncate_memory(seq_id, keep_len) for state in touched: if state.drafted: @@ -2276,9 +2191,9 @@ def process_sampled_batch( results = [np.array([], dtype=np.intc) for _ in updates] if self.num_pred_tokens <= 0 or not updates: return results - h_tgt = llama_cpp_ext.llama_get_embeddings_pre_norm(self.target_ctx) + h_tgt = llama_cpp_ext.llama_get_embeddings_nextn(self.target_ctx) if not h_tgt: - raise RuntimeError("missing target pre-norm embeddings for MTP") + raise RuntimeError("missing target nextn embeddings for MTP") n_target_rows = max( ( max(update.row_indices) + 1 @@ -2290,7 +2205,6 @@ def process_sampled_batch( if n_target_rows <= 0: return results h_tgt_rows = np.ctypeslib.as_array(h_tgt, shape=(n_target_rows, self.n_embd)) - h_tgt_rows = self._normalize_target_hidden_rows(h_tgt_rows) plan = self._build_sampled_batch_plan(updates) if not plan.context_rows and not plan.pending_rows: @@ -2313,8 +2227,9 @@ def process_sampled_batch( self.ready_pos[output.seq_id] = output.ready_pos cleanup_keep_len_by_seq[output.seq_id] = output.keep_len - for seq_id, keep_len in cleanup_keep_len_by_seq.items(): - self._truncate_memory(seq_id, keep_len) + if not self.is_mem_shared: + for seq_id, keep_len in cleanup_keep_len_by_seq.items(): + self._truncate_memory(seq_id, keep_len) if sampled_outputs: active = self._start_sampled_draft_states( @@ -2344,6 +2259,9 @@ def accept(self, seq_id: int, accepted_draft_tokens: int) -> None: def _truncate_memory(self, seq_id: int, keep_len: int) -> None: if seq_id < 0 or seq_id >= self.n_seq_max: return + if self.is_mem_shared: + self.context_pos[seq_id] = min(self.context_pos[seq_id], keep_len) + return if not llama_cpp.llama_memory_seq_rm( self.mem, seq_id, @@ -2386,13 +2304,14 @@ def copy_sequence( or dest_seq_id >= self.n_seq_max ): return - llama_cpp.llama_memory_seq_cp( - self.mem, - source_seq_id, - dest_seq_id, - p0, - p1, - ) + if not self.is_mem_shared: + llama_cpp.llama_memory_seq_cp( + self.mem, + source_seq_id, + dest_seq_id, + p0, + p1, + ) source_ready_pos = self.ready_pos[source_seq_id] copied_full_ready_state = p1 < 0 or p1 == source_ready_pos if self.ready[source_seq_id] and copied_full_ready_state: @@ -2457,6 +2376,7 @@ class CompletionChunk(TypedDict): CompletionStream = Generator[CompletionChunk, None, OpenAICompletion] CompletionPrompt = Union[str, List[int], List[str], List[List[int]]] +EmbeddingInput = Union[str, List[str], List[int], List[List[int]]] class CreateCompletionRequest(BaseModel): @@ -2522,6 +2442,126 @@ def normalized_prompt(self) -> List[Union[str, List[int]]]: raise ValueError("prompt must be a string, token ids, list of strings, or list of token-id lists") +class CreateEmbeddingRequest(BaseModel): + model_config = ConfigDict(extra="ignore") + + input: EmbeddingInput + model: str + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = Field(default=None, ge=1) + user: Optional[str] = None + + @staticmethod + def _validate_text_input(text: str) -> str: + if text == "": + raise ValueError("embedding input must not contain empty strings") + return text + + @staticmethod + def _validate_token_input(tokens: List[int]) -> List[int]: + if not tokens: + raise ValueError("embedding token input must not be empty") + if len(tokens) > 2048: + raise ValueError("embedding token input must not exceed 2048 tokens") + return tokens + + @model_validator(mode="after") + def validate_after(self) -> "CreateEmbeddingRequest": + self.normalized_input() + return self + + def normalized_input(self) -> List[Union[str, List[int]]]: + if isinstance(self.input, str): + return [self._validate_text_input(self.input)] + if all(isinstance(token, int) for token in self.input): + return [self._validate_token_input(cast(List[int], self.input))] + if all(isinstance(item, str) for item in self.input): + if len(self.input) > 2048: + raise ValueError("embedding input array must not exceed 2048 items") + return [ + self._validate_text_input(item) + for item in cast(List[str], self.input) + ] + if all( + isinstance(item, list) + and all(isinstance(token, int) for token in item) + for item in self.input + ): + if len(self.input) > 2048: + raise ValueError("embedding input array must not exceed 2048 items") + return [ + self._validate_token_input(item) + for item in cast(List[List[int]], self.input) + ] + raise ValueError( + "embedding input must be a string, list of strings, token ids, or list of token-id lists" + ) + + +class EmbeddingDataResponse(BaseModel): + object: Literal["embedding"] = "embedding" + embedding: Union[List[float], str] + index: int + + +class EmbeddingUsageResponse(BaseModel): + prompt_tokens: int + total_tokens: int + + +class CreateEmbeddingResponse(BaseModel): + object: Literal["list"] = "list" + data: List[EmbeddingDataResponse] + model: str + usage: EmbeddingUsageResponse + + @staticmethod + def encode_embedding( + embedding: Sequence[float], + encoding_format: Literal["float", "base64"], + dimensions: Optional[int], + ) -> Union[List[float], str]: + if dimensions is not None: + if dimensions > len(embedding): + raise CompletionRequestValidationError( + f"dimensions ({dimensions}) exceeds embedding size ({len(embedding)})" + ) + embedding = embedding[:dimensions] + if encoding_format == "float": + return [float(value) for value in embedding] + array = np.asarray(embedding, dtype=np.float32) + return base64.b64encode(array.tobytes()).decode("ascii") + + @classmethod + def from_embeddings( + cls, + *, + model: str, + embeddings: Sequence[Sequence[float]], + total_tokens: int, + encoding_format: Literal["float", "base64"], + dimensions: Optional[int], + ) -> "CreateEmbeddingResponse": + return cls( + data=[ + EmbeddingDataResponse( + embedding=cls.encode_embedding( + embedding, + encoding_format, + dimensions, + ), + index=index, + ) + for index, embedding in enumerate(embeddings) + ], + model=model, + usage=EmbeddingUsageResponse( + prompt_tokens=total_tokens, + total_tokens=total_tokens, + ), + ) + + class ChatCompletionFunctionCall(BaseModel): model_config = ConfigDict( extra="allow", @@ -3214,6 +3254,7 @@ class ModelOptions(BaseModel): rope_scaling_type: Optional[int] = None pooling_type: Optional[int] = None attention_type: Optional[int] = None + embedding: Optional[bool] = None rope_freq_base: Optional[float] = None rope_freq_scale: Optional[float] = None yarn_ext_factor: Optional[float] = None @@ -3232,6 +3273,10 @@ class ModelOptions(BaseModel): max_output_tokens: Optional[int] = Field(default=None, ge=0) kv_unified: bool = True draft_model: Optional[Literal["prompt-lookup-decoding", "draft-mtp"]] = None + draft_model_path: Optional[str] = None + draft_model_from_pretrained: Optional[ + "ConfigFile.FromPretrainedOptions" + ] = None draft_model_num_pred_tokens: int = 16 draft_model_max_ngram_size: int = 2 draft_model_top_k: int = Field(default=1, ge=1) @@ -3246,8 +3291,21 @@ class ModelOptions(BaseModel): def validate_source(self) -> "ConfigFile.ModelOptions": if (self.path is None) == (self.from_pretrained is None): raise ValueError("exactly one of model.path or model.from_pretrained is required") + if ( + self.draft_model_path is not None + and self.draft_model_from_pretrained is not None + ): + raise ValueError( + "model.draft_model_path and model.draft_model_from_pretrained " + "are mutually exclusive" + ) return self + def resolve_draft_model_path(self) -> Optional[str]: + if self.draft_model_from_pretrained is not None: + return self.draft_model_from_pretrained.resolve_model_path() + return self.draft_model_path + @field_validator("chat_template", mode="before") @classmethod def normalize_chat_template(cls, value: Any) -> Any: @@ -10408,6 +10466,7 @@ def __init__( rope_scaling_type: Optional[int] = None, pooling_type: Optional[int] = None, attention_type: Optional[int] = None, + embedding: Optional[bool] = None, rope_freq_base: Optional[float] = None, rope_freq_scale: Optional[float] = None, yarn_ext_factor: Optional[float] = None, @@ -10426,6 +10485,7 @@ def __init__( max_seq_len: Optional[int] = None, max_output_tokens: Optional[int] = None, draft_model: Optional[str] = None, + draft_model_path: Optional[str] = None, draft_model_num_pred_tokens: int = 16, draft_model_max_ngram_size: int = 2, draft_model_top_k: int = 1, @@ -10450,6 +10510,7 @@ def __init__( self._lora_adapters: List[Any] = [] self._lora_adapter_array: Optional[Any] = None self._lora_scales_array: Optional[Any] = None + self.draft_llama_model: Optional[Any] = None model_params, self._c_tensor_split, self._kv_overrides_array = ( self.build_model_params( n_gpu_layers=n_gpu_layers, @@ -10473,9 +10534,13 @@ def __init__( if vocab is None: raise RuntimeError("failed to access model vocabulary") self.vocab = vocab - if llama_cpp.llama_model_has_encoder(llama_model): + embedding = self.resolve_embedding_mode(llama_model, embedding) + self.embedding = embedding + self.has_encoder = bool(llama_cpp.llama_model_has_encoder(llama_model)) + self.has_decoder = bool(llama_cpp.llama_model_has_decoder(llama_model)) + if self.has_encoder and not embedding: raise RuntimeError("encoder models are not supported") - if not llama_cpp.llama_model_has_decoder(llama_model): + if not self.has_decoder and not (embedding and self.has_encoder): raise RuntimeError("decoder is required") if llama_cpp.llama_model_is_recurrent(llama_model): self.memory_model = "recurrent" @@ -10503,11 +10568,6 @@ def __init__( "speculative decoding is only supported for attention models" ) n_ctx_train = int(llama_cpp.llama_model_n_ctx_train(llama_model)) - target_n_rs_seq = ( - max(1, draft_model_num_pred_tokens) - if normalized_draft_model == "draft-mtp" - else None - ) context_params = self.build_context_params( n_ctx=n_ctx if n_ctx is not None else n_ctx_train, @@ -10519,6 +10579,7 @@ def __init__( rope_scaling_type=rope_scaling_type, pooling_type=pooling_type, attention_type=attention_type, + embedding=embedding, rope_freq_base=rope_freq_base, rope_freq_scale=rope_freq_scale, yarn_ext_factor=yarn_ext_factor, @@ -10534,7 +10595,7 @@ def __init__( type_k=type_k, type_v=type_v, kv_unified=kv_unified, - n_rs_seq=target_n_rs_seq, + n_rs_seq=None, ctx_type=None, ) ctx = llama_cpp.llama_init_from_model(llama_model, context_params) @@ -10542,7 +10603,7 @@ def __init__( raise RuntimeError("failed to create context") self.ctx = ctx mem = llama_cpp.llama_get_memory(ctx) - if mem is None: + if mem is None and not embedding: raise RuntimeError("failed to access model memory") self.mem = mem self.n_ctx = int(llama_cpp.llama_n_ctx(ctx)) @@ -10564,14 +10625,13 @@ def __init__( "MTP requires runtime n_batch to fit the pending token plus draft tokens " f"(required {required_mtp_batch}, got {self.n_batch})" ) - if target_n_rs_seq is not None and self.n_rs_seq < target_n_rs_seq: - raise RuntimeError( - "MTP requires retained recurrent-state slots for rollback " - f"(required {target_n_rs_seq}, got {self.n_rs_seq})" - ) self.n_ctx_train = n_ctx_train self.n_vocab = int(llama_cpp.llama_vocab_n_tokens(self.vocab)) + self.n_embd = int(llama_cpp.llama_model_n_embd(self.llama_model)) self.n_embd_inp = int(llama_cpp.llama_model_n_embd_inp(self.llama_model)) + self.n_embd_out = int(llama_cpp.llama_model_n_embd_out(self.llama_model)) + if self.n_embd_out <= 0: + self.n_embd_out = self.n_embd self.kv_unified = kv_unified self.max_seq_len_limit = min(self.request_context_limit, self.n_ctx_train) if max_seq_len is None: @@ -10615,6 +10675,22 @@ def __init__( num_pred_tokens=draft_model_num_pred_tokens, ) elif normalized_draft_model == "draft-mtp": + draft_llama_model = self.llama_model + if draft_model_path is not None: + draft_llama_model = llama_cpp.llama_model_load_from_file( + draft_model_path.encode("utf-8"), + model_params, + ) + if draft_llama_model is None: + llama_cpp.llama_batch_free(self.batch) + llama_cpp.llama_free(self.ctx) + self._free_lora_adapters() + llama_cpp.llama_model_free(self.llama_model) + if self.backend_initialized: + llama_cpp.llama_backend_free() + self.backend_initialized = False + raise RuntimeError(f"failed to load MTP draft model: {draft_model_path}") + self.draft_llama_model = draft_llama_model if self.n_ubatch < self.n_seq_max: mtp_n_batch = self.n_batch else: @@ -10644,6 +10720,7 @@ def __init__( rope_scaling_type=rope_scaling_type, pooling_type=pooling_type, attention_type=attention_type, + embedding=embedding, rope_freq_base=rope_freq_base, rope_freq_scale=rope_freq_scale, yarn_ext_factor=yarn_ext_factor, @@ -10659,13 +10736,15 @@ def __init__( type_k=type_k, type_v=type_v, kv_unified=kv_unified, - n_rs_seq=target_n_rs_seq, + n_rs_seq=0, ctx_type=llama_cpp.LLAMA_CONTEXT_TYPE_MTP, n_outputs_max=min(mtp_n_batch, self.n_seq_max), + ctx_other=self.ctx, ) try: self.draft_provider = MTPDraftProvider( model=self, + draft_model=draft_llama_model, context_params=mtp_context_params, num_pred_tokens=draft_model_num_pred_tokens, top_k=draft_model_top_k, @@ -10675,6 +10754,9 @@ def __init__( llama_cpp.llama_batch_free(self.batch) llama_cpp.llama_free(self.ctx) self._free_lora_adapters() + if self.draft_llama_model is not None: + llama_cpp.llama_model_free(self.draft_llama_model) + self.draft_llama_model = None llama_cpp.llama_model_free(self.llama_model) if self.backend_initialized: llama_cpp.llama_backend_free() @@ -10685,7 +10767,10 @@ def __init__( try: self._load_lora_adapters(self.loras) self._apply_lora_adapters(self.ctx, "target") - if isinstance(self.draft_provider, MTPDraftProvider): + if ( + isinstance(self.draft_provider, MTPDraftProvider) + and self.draft_llama_model is None + ): self._apply_lora_adapters(self.draft_provider.ctx, "MTP draft") except BaseException: if self.draft_provider is not None: @@ -10693,6 +10778,9 @@ def __init__( llama_cpp.llama_batch_free(self.batch) llama_cpp.llama_free(self.ctx) self._free_lora_adapters() + if self.draft_llama_model is not None: + llama_cpp.llama_model_free(self.draft_llama_model) + self.draft_llama_model = None llama_cpp.llama_model_free(self.llama_model) if self.backend_initialized: llama_cpp.llama_backend_free() @@ -10786,6 +10874,7 @@ def build_context_params( rope_scaling_type: Optional[int], pooling_type: Optional[int], attention_type: Optional[int], + embedding: bool, rope_freq_base: Optional[float], rope_freq_scale: Optional[float], yarn_ext_factor: Optional[float], @@ -10804,6 +10893,7 @@ def build_context_params( n_rs_seq: Optional[int] = None, ctx_type: Optional[int] = None, n_outputs_max: Optional[int] = None, + ctx_other: Optional[Any] = None, ) -> Any: context_params = llama_cpp.llama_context_default_params() if n_ctx is not None: @@ -10824,12 +10914,15 @@ def build_context_params( context_params.ctx_type = ctx_type if n_outputs_max is not None: context_params.n_outputs_max = n_outputs_max + if ctx_other is not None: + context_params.ctx_other = ctx_other if rope_scaling_type is not None: context_params.rope_scaling_type = rope_scaling_type if pooling_type is not None: context_params.pooling_type = pooling_type if attention_type is not None: context_params.attention_type = attention_type + context_params.embeddings = embedding if rope_freq_base is not None: context_params.rope_freq_base = rope_freq_base if rope_freq_scale is not None: @@ -10935,19 +11028,42 @@ def close(self) -> None: llama_cpp.llama_batch_free(self.batch) llama_cpp.llama_free(self.ctx) self._free_lora_adapters() + if self.draft_llama_model is not None: + llama_cpp.llama_model_free(self.draft_llama_model) + self.draft_llama_model = None llama_cpp.llama_model_free(self.llama_model) if self.backend_initialized: llama_cpp.llama_backend_free() self.backend_initialized = False - def _meta_value(self, key: str) -> Optional[str]: + @staticmethod + def _model_meta_key_by_index(llama_model: Any, index: int) -> Optional[str]: + capacity = 256 + while True: + buffer = ctypes.create_string_buffer(capacity) + count = int( + llama_cpp.llama_model_meta_key_by_index( + llama_model, + index, + cast(Any, buffer), + capacity, + ) + ) + if count < 0: + return None + if count < capacity: + return buffer.value.decode("utf-8", errors="ignore") + capacity = count + 1 + + @staticmethod + def _model_meta_value(llama_model: Any, key: str) -> Optional[str]: encoded = key.encode("utf-8") capacity = 256 while True: buffer = ctypes.create_string_buffer(capacity) count = int( llama_cpp.llama_model_meta_val_str( - self.llama_model, + llama_model, encoded, cast(Any, buffer), capacity, @@ -10959,6 +11075,50 @@ def _meta_value(self, key: str) -> Optional[str]: return buffer.value.decode("utf-8", errors="ignore") capacity = count + 1 + @staticmethod + def _parse_pooling_type(value: str) -> Optional[int]: + normalized = value.strip().lower() + try: + return int(normalized) + except ValueError: + return { + "none": llama_cpp.LLAMA_POOLING_TYPE_NONE, + "mean": llama_cpp.LLAMA_POOLING_TYPE_MEAN, + "cls": llama_cpp.LLAMA_POOLING_TYPE_CLS, + "last": llama_cpp.LLAMA_POOLING_TYPE_LAST, + "rank": llama_cpp.LLAMA_POOLING_TYPE_RANK, + }.get(normalized) + + @classmethod + def detect_embedding_model(cls, llama_model: Any) -> bool: + for index in range(int(llama_cpp.llama_model_meta_count(llama_model))): + key = cls._model_meta_key_by_index(llama_model, index) + if key is None or not key.endswith(".pooling_type"): + continue + value = cls._model_meta_value(llama_model, key) + if value is None: + continue + pooling_type = cls._parse_pooling_type(value) + return pooling_type in { + llama_cpp.LLAMA_POOLING_TYPE_MEAN, + llama_cpp.LLAMA_POOLING_TYPE_CLS, + llama_cpp.LLAMA_POOLING_TYPE_LAST, + } + return False + + @classmethod + def resolve_embedding_mode( + cls, + llama_model: Any, + embedding: Optional[bool], + ) -> bool: + if embedding is not None: + return embedding + return cls.detect_embedding_model(llama_model) + + def _meta_value(self, key: str) -> Optional[str]: + return self._model_meta_value(self.llama_model, key) + def _build_chat_formatter(self) -> Optional[Jinja2ChatFormatter]: template_text = self.chat_template_override if template_text is None: @@ -11131,6 +11291,10 @@ def clear_batch(self) -> None: self._embedding_batch = None self._embedding_batch_refs = [] + def clear_memory(self) -> None: + if self.mem is not None: + llama_cpp.llama_memory_clear(self.mem, True) + def add_batch_tokens( self, *, @@ -11206,9 +11370,14 @@ def add_batch_embeddings( def decode(self) -> None: batch = self._embedding_batch if self._embedding_batch is not None else self.batch - result = int(llama_cpp.llama_decode(self.ctx, batch)) + if self.embedding and self.has_encoder: + operation = "llama_encode" + result = int(llama_cpp.llama_encode(self.ctx, batch)) + else: + operation = "llama_decode" + result = int(llama_cpp.llama_decode(self.ctx, batch)) if result != 0: - raise RuntimeError(f"llama_decode failed with code {result}") + raise RuntimeError(f"{operation} failed with code {result}") def process_draft_batch(self) -> None: if self.draft_provider is not None: @@ -11242,6 +11411,103 @@ def logits(self, output_index: int) -> np.ndarray: raise RuntimeError(f"missing logits output {output_index}") return np.ctypeslib.as_array(ptr, shape=(self.n_vocab,)).copy() + def embed( + self, + inputs: Sequence[Union[str, List[int]]], + ) -> Tuple[List[List[float]], int]: + if not self.embedding: + raise CompletionRequestValidationError( + "model.embedding must be true to use /v1/embeddings" + ) + pooling_type = int(llama_cpp.llama_pooling_type(self.ctx)) + if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE: + raise CompletionRequestValidationError( + "/v1/embeddings requires a pooled embedding model; " + "set model.pooling_type to MEAN, CLS, or LAST" + ) + if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_RANK: + raise CompletionRequestValidationError( + "/v1/embeddings does not support reranking pooling" + ) + if len(inputs) > 2048: + raise CompletionRequestValidationError( + "embedding input batch size exceeds 2048" + ) + + embeddings: List[List[float]] = [] + total_tokens = 0 + batch_sizes: List[int] = [] + batch_token_count = 0 + + def decode_embedding_batch() -> None: + nonlocal batch_token_count + if not batch_sizes: + return + self.clear_memory() + self.decode() + self.clear_batch() + for seq_id in range(len(batch_sizes)): + ptr = llama_cpp.llama_get_embeddings_seq( + self.ctx, + llama_cpp.llama_seq_id(seq_id), + ) + if not ptr: + raise RuntimeError(f"missing embedding output for input {seq_id}") + embeddings.append( + np.ctypeslib.as_array(ptr, shape=(self.n_embd_out,)).astype( + float + ).tolist() + ) + batch_sizes.clear() + batch_token_count = 0 + + try: + self.clear_batch() + self.clear_memory() + for input_item in inputs: + tokens = ( + self.tokenize(input_item) + if isinstance(input_item, str) + else list(input_item) + ) + n_tokens = len(tokens) + if n_tokens == 0: + raise CompletionRequestValidationError( + "embedding input must not be empty" + ) + if n_tokens > self.n_ctx_seq: + raise CompletionRequestValidationError( + f"embedding input has {n_tokens} tokens, exceeding n_ctx_seq ({self.n_ctx_seq})" + ) + if n_tokens > self.n_batch: + raise CompletionRequestValidationError( + f"embedding input has {n_tokens} tokens, exceeding n_batch ({self.n_batch})" + ) + if total_tokens + n_tokens > 300_000: + raise CompletionRequestValidationError( + "embedding request exceeds 300000 total tokens" + ) + if ( + batch_token_count + n_tokens > self.n_batch + or len(batch_sizes) >= self.n_seq_max + ): + decode_embedding_batch() + seq_id = len(batch_sizes) + self.add_batch_tokens( + seq_id=seq_id, + start_pos=0, + tokens=tokens, + output_indices=[0] * n_tokens, + ) + batch_sizes.append(n_tokens) + batch_token_count += n_tokens + total_tokens += n_tokens + decode_embedding_batch() + finally: + self.clear_batch() + self.clear_memory() + return embeddings, total_tokens + class SequenceDiskCache(SequenceCache): """Directory-backed cache for serialized llama.cpp sequence state.""" @@ -12065,6 +12331,37 @@ def build_memory_policy(self) -> MemoryPolicy: return PartitionedAttentionMemoryPolicy(self) return UnifiedAttentionMemoryPolicy(self) + def clear_resident_state(self) -> None: + self.model.clear_memory() + self.model.clear_batch() + self.radix_trie = RadixTrie() + self.sequence_history = SequenceHistory() + self.checkpoint_logits.clear() + self.claimed_sequences.clear() + self.free_sequences.clear() + self.unused_sequences = list(range(self.model.n_seq_max - 1, -1, -1)) + for seq_id in range(self.model.n_seq_max): + self.model.truncate_draft_sequence(seq_id, 0) + + def create_embedding( + self, + payload: CreateEmbeddingRequest, + ) -> CreateEmbeddingResponse: + if not self.is_idle(): + raise RuntimeError("embedding requests require an idle scheduler") + self.clear_resident_state() + try: + embeddings, total_tokens = self.model.embed(payload.normalized_input()) + return CreateEmbeddingResponse.from_embeddings( + model=payload.model, + embeddings=embeddings, + total_tokens=total_tokens, + encoding_format=payload.encoding_format, + dimensions=payload.dimensions, + ) + finally: + self.clear_resident_state() + @staticmethod def request_needs_prompt_logits(request: CompletionRequest) -> bool: return request.payload.max_tokens != 0 and request.effective_max_len > len( @@ -14420,6 +14717,39 @@ def run_callback() -> None: raise error_box["error"] return result_box.get("result") + def call_on_idle_scheduler(self, callback: Callable[[], Any]) -> Any: + result_box: Dict[str, Any] = {} + error_box: Dict[str, BaseException] = {} + done = threading.Event() + + def run_callback() -> None: + if not self.scheduler.is_idle(): + with self.condition: + self.commands.appendleft(run_callback) + self.condition.notify_all() + return + try: + result_box["result"] = callback() + except BaseException as exc: # noqa: BLE001 + error_box["error"] = exc + finally: + done.set() + + self.enqueue(run_callback) + done.wait() + if "error" in error_box: + raise error_box["error"] + return result_box.get("result") + + def create_embedding( + self, + payload: CreateEmbeddingRequest, + ) -> CreateEmbeddingResponse: + embedding = self.call_on_idle_scheduler( + lambda: self.scheduler.create_embedding(payload) + ) + return cast(CreateEmbeddingResponse, embedding) + def render_prometheus_metrics(self) -> str: metrics = self.call_on_scheduler(self.scheduler.render_prometheus_metrics) return cast(str, metrics) @@ -14874,6 +15204,17 @@ async def create_completion( # pyright: ignore[reportUnusedFunction] return result return JSONResponse(result.model_dump(mode="json", exclude_none=True)) + @app.post("/v1/embeddings", response_model=CreateEmbeddingResponse) + async def create_embedding( # pyright: ignore[reportUnusedFunction] + body: CreateEmbeddingRequest, + ) -> JSONResponse: + service: CompletionService = app.state.service + try: + embedding = await asyncio.to_thread(service.create_embedding, body) + except CompletionRequestValidationError as exc: + raise bad_request(exc) from exc + return JSONResponse(embedding.model_dump(mode="json", exclude_none=True)) + @app.post("/v1/chat/completions") async def create_chat_completion( # pyright: ignore[reportUnusedFunction] http_request: Request, body: CreateChatCompletionRequest @@ -15250,6 +15591,7 @@ def main() -> None: rope_scaling_type=config.model.rope_scaling_type, pooling_type=config.model.pooling_type, attention_type=config.model.attention_type, + embedding=config.model.embedding, rope_freq_base=config.model.rope_freq_base, rope_freq_scale=config.model.rope_freq_scale, yarn_ext_factor=config.model.yarn_ext_factor, @@ -15268,6 +15610,7 @@ def main() -> None: max_seq_len=config.model.max_seq_len, max_output_tokens=config.model.max_output_tokens, draft_model=config.model.draft_model, + draft_model_path=config.model.resolve_draft_model_path(), draft_model_num_pred_tokens=config.model.draft_model_num_pred_tokens, draft_model_max_ngram_size=config.model.draft_model_max_ngram_size, draft_model_top_k=config.model.draft_model_top_k, diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 369f24ca8..13668893f 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.3.27" +__version__ = "0.3.28" diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 44c25a519..21f85c81c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -949,6 +949,10 @@ class llama_sampler_seq_config(ctypes.Structure): # // ref: https://github.com/ggml-org/llama.cpp/pull/14363 # struct llama_sampler_seq_config * samplers; # size_t n_samplers; +# +# // a source/target/parent context +# // can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts +# struct llama_context * ctx_other; # }; class llama_context_params(ctypes.Structure): """Parameters for llama_context @@ -989,6 +993,7 @@ class llama_context_params(ctypes.Structure): kv_unified (bool): use a unified buffer across the input sequences when computing the attention samplers (ctypes.POINTER(llama_sampler_seq_config)): backend sampler chain configuration n_samplers (int): number of backend sampler chain configurations + ctx_other (llama_context_p): source, target, or parent context """ if TYPE_CHECKING: @@ -1027,6 +1032,7 @@ class llama_context_params(ctypes.Structure): kv_unified: bool samplers: ctypes.POINTER(llama_sampler_seq_config) n_samplers: int + ctx_other: llama_context_p _fields_ = [ ("n_ctx", ctypes.c_uint32), @@ -1064,6 +1070,7 @@ class llama_context_params(ctypes.Structure): ("kv_unified", ctypes.c_bool), ("samplers", ctypes.POINTER(llama_sampler_seq_config)), ("n_samplers", ctypes.c_size_t), + ("ctx_other", llama_context_p_ctypes), ] diff --git a/llama_cpp/llama_cpp_ext.py b/llama_cpp/llama_cpp_ext.py index f6ab9197b..284811086 100644 --- a/llama_cpp/llama_cpp_ext.py +++ b/llama_cpp/llama_cpp_ext.py @@ -42,58 +42,76 @@ def decorator(f): return decorator -# LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); +# LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); @_ctypes_function_from_names( ( - "llama_set_embeddings_pre_norm", - "_Z29llama_set_embeddings_pre_normP13llama_contextbb", - "?llama_set_embeddings_pre_norm@@YAXPEAUllama_context@@_N1@Z", + "llama_set_embeddings_nextn", + "_Z26llama_set_embeddings_nextnP13llama_contextbb", + "?llama_set_embeddings_nextn@@YAXPEAUllama_context@@_N1@Z", ), [llama_cpp.llama_context_p_ctypes, ctypes.c_bool, ctypes.c_bool], None, ) -def llama_set_embeddings_pre_norm( +def llama_set_embeddings_nextn( ctx: llama_cpp.llama_context_p, value: bool, masked: bool, /, ): - """Set whether the context outputs pre-norm embeddings or not.""" + """Set whether the context outputs nextn embeddings or not.""" ... -# LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx); +# LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); @_ctypes_function_from_names( ( - "llama_get_embeddings_pre_norm", - "_Z29llama_get_embeddings_pre_normP13llama_context", - "?llama_get_embeddings_pre_norm@@YAPEAMPEAUllama_context@@@Z", + "llama_get_embeddings_nextn", + "_Z26llama_get_embeddings_nextnP13llama_context", + "?llama_get_embeddings_nextn@@YAPEAMPEAUllama_context@@@Z", ), [llama_cpp.llama_context_p_ctypes], ctypes.POINTER(ctypes.c_float), ) -def llama_get_embeddings_pre_norm( +def llama_get_embeddings_nextn( ctx: llama_cpp.llama_context_p, /, ): - """Get the pre-norm embeddings from the last evaluation.""" + """Get the nextn embeddings from the last evaluation.""" ... -# LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); +# LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); @_ctypes_function_from_names( ( - "llama_get_embeddings_pre_norm_ith", - "_Z33llama_get_embeddings_pre_norm_ithP13llama_contexti", - "?llama_get_embeddings_pre_norm_ith@@YAPEAMPEAUllama_context@@H@Z", + "llama_get_embeddings_nextn_ith", + "_Z30llama_get_embeddings_nextn_ithP13llama_contexti", + "?llama_get_embeddings_nextn_ith@@YAPEAMPEAUllama_context@@H@Z", ), [llama_cpp.llama_context_p_ctypes, ctypes.c_int32], ctypes.POINTER(ctypes.c_float), ) -def llama_get_embeddings_pre_norm_ith( +def llama_get_embeddings_nextn_ith( ctx: llama_cpp.llama_context_p, i: Union[ctypes.c_int32, int], /, ): - """Get the pre-norm embeddings for the ith output row from the last evaluation.""" + """Get the nextn embeddings for the ith output row from the last evaluation.""" + ... + + +# LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); +@_ctypes_function_from_names( + ( + "llama_get_ctx_other", + "_Z19llama_get_ctx_otherP13llama_context", + "?llama_get_ctx_other@@YAPEAUllama_context@@PEAU1@@Z", + ), + [llama_cpp.llama_context_p_ctypes], + llama_cpp.llama_context_p_ctypes, +) +def llama_get_ctx_other( + ctx: llama_cpp.llama_context_p, + /, +): + """Get the context linked through llama_context_params.ctx_other.""" ... diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 465b1f0e7..9e3b928fd 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 465b1f0e75c590426cff3ca998bcd25297071a5b +Subproject commit 9e3b928fd8c9d14dbf15a8768b9fdd7e5c721d66