Skip to content

Commit a52702f

Browse files
authored
feat(example): use MTMD batch encoding (#2301)
1 parent 565d3c5 commit a52702f

3 files changed

Lines changed: 182 additions & 75 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- feat(example): use MTMD batch encoding by @abetlen in #2301
1011
- feat(example): support server video inputs and Gemma text tool calls by @abetlen in #2291
1112
- feat: update llama.cpp to ggml-org/llama.cpp@f05cf4676
1213
- fix(example): support multi-step Responses tool streaming by @abetlen in #2288

examples/server/README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ See [Hugging Face response parsing](https://huggingface.co/docs/transformers/cha
291291

292292
## Multimodal `model.mtmd`
293293

294-
`model.mtmd` loads a llama.cpp multimodal projector and enables OpenAI-style image and audio content parts.
294+
`model.mtmd` loads a llama.cpp multimodal projector and enables OpenAI-style image, audio, and video content parts.
295295

296296
```json
297297
{
@@ -305,8 +305,10 @@ See [Hugging Face response parsing](https://huggingface.co/docs/transformers/cha
305305
"path": ".cache/mtmd-embeddings",
306306
"max_bytes": 1073741824
307307
},
308+
"batch_max_tokens": 1024,
308309
"image_max_bytes": 20971520,
309310
"audio_max_bytes": 104857600,
311+
"video_max_bytes": 536870912,
310312
"image_timeout_seconds": 10.0
311313
}
312314
}
@@ -317,11 +319,13 @@ See [Hugging Face response parsing](https://huggingface.co/docs/transformers/cha
317319
| --- | --- |
318320
| `mmproj_path` | Local multimodal projector path. |
319321
| `mmproj_from_pretrained` | Hugging Face projector source. |
320-
| `embedding_cache.path` | Directory for cached image and audio embeddings. |
322+
| `embedding_cache.path` | Directory for cached image, audio, and video embeddings. |
321323
| `embedding_cache.max_bytes` | Maximum embedding cache size. |
324+
| `batch_max_tokens` | Maximum number of media output tokens per MTMD projector-side encode batch. |
322325
| `image_max_bytes` | Maximum image payload size. |
323326
| `audio_max_bytes` | Maximum audio payload size. |
324-
| `image_timeout_seconds` | Timeout for remote image and audio URL fetches. |
327+
| `video_max_bytes` | Maximum video payload size. |
328+
| `image_timeout_seconds` | Timeout for remote image, audio, and video URL fetches. |
325329

326330
Send image inputs with OpenAI chat content parts.
327331

examples/server/server.py

Lines changed: 174 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3223,6 +3223,7 @@ class MTMDOptions(BaseModel):
32233223
embedding_cache: Optional["ConfigFile.MTMDEmbeddingCacheOptions"] = None
32243224
allowed_media_domains: Optional[List[str]] = None
32253225
allowed_local_media_path: Optional[str] = None
3226+
batch_max_tokens: int = Field(default=1024, ge=1)
32263227
image_max_bytes: int = Field(default=20 * 1024 * 1024, ge=1)
32273228
audio_max_bytes: int = Field(default=100 * 1024 * 1024, ge=1)
32283229
video_max_bytes: int = Field(default=512 * 1024 * 1024, ge=1)
@@ -10410,6 +10411,21 @@ class MTMDLoadedMedia:
1041010411

1041110412

1041210413
class MTMDProcessor:
10414+
@dataclass
10415+
class MediaChunk:
10416+
kind: Literal["image", "audio", "video"]
10417+
key: str
10418+
chunk: Any
10419+
n_tokens: int
10420+
decode_n_pos: int
10421+
non_causal: bool
10422+
embeddings: Optional[np.ndarray] = None
10423+
10424+
@dataclass
10425+
class ParsedChunk:
10426+
text_tokens: Optional[List[int]] = None
10427+
media: Optional["MTMDProcessor.MediaChunk"] = None
10428+
1041310429
def __init__(
1041410430
self,
1041510431
*,
@@ -10422,6 +10438,7 @@ def __init__(
1042210438
n_ubatch: int,
1042310439
n_threads_batch: int,
1042410440
mmproj_path: str,
10441+
batch_max_tokens: int,
1042510442
embedding_cache: Optional[MTMDEmbeddingCache],
1042610443
allowed_media_domains: Optional[List[str]],
1042710444
allowed_local_media_path: Optional[str],
@@ -10437,6 +10454,7 @@ def __init__(
1043710454
self.n_ubatch = n_ubatch
1043810455
self.mmproj_path = mmproj_path
1043910456
self.embedding_cache = embedding_cache
10457+
self.batch_max_tokens = batch_max_tokens
1044010458
self.model_fingerprint = MTMDEmbeddingCache.fingerprint_file(model_path)
1044110459
self.mmproj_fingerprint = MTMDEmbeddingCache.fingerprint_file(mmproj_path)
1044210460
self.allowed_media_domains = (
@@ -10456,6 +10474,7 @@ def __init__(
1045610474
self.lock = threading.Lock()
1045710475
params = mtmd_cpp.mtmd_context_params_default()
1045810476
params.n_threads = max(1, n_threads_batch)
10477+
params.batch_max_tokens = batch_max_tokens
1045910478
self.ctx = mtmd_cpp.mtmd_init_from_file(
1046010479
mmproj_path.encode("utf-8"),
1046110480
llama_model,
@@ -10705,37 +10724,91 @@ def _media_identity_tokens(
1070510724
tokens.append(-1 - (int.from_bytes(digest[:4], "little") & 0x3FFFFFFF))
1070610725
return tokens
1070710726

10708-
def _encode_media_chunk(
10709-
self,
10710-
*,
10711-
kind: Literal["image", "audio", "video"],
10712-
key: str,
10713-
chunk: Any,
10714-
) -> np.ndarray:
10715-
n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
10716-
if self.embedding_cache is not None:
10717-
cached = self.embedding_cache.load(key)
10718-
if (
10719-
cached is not None
10720-
and cached.embeddings.shape == (n_tokens, self.n_embd_inp)
10721-
):
10722-
return cached.embeddings
10723-
result = int(mtmd_cpp.mtmd_encode_chunk(self.ctx, chunk))
10724-
if result != 0:
10725-
raise CompletionRequestValidationError(
10726-
f"failed to encode {kind} chunk: error code {result}"
10727-
)
10728-
output = mtmd_cpp.mtmd_get_output_embd(self.ctx)
10729-
if output is None:
10730-
raise CompletionRequestValidationError(f"MTMD {kind} encoder returned no embeddings")
10727+
def _embeddings_from_pointer(self, output: Any, n_tokens: int) -> np.ndarray:
1073110728
flat = np.ctypeslib.as_array(output, shape=(n_tokens * self.n_embd_inp,))
10732-
embeddings = np.array(flat, dtype=np.float32, copy=True).reshape(
10729+
return np.array(flat, dtype=np.float32, copy=True).reshape(
1073310730
n_tokens,
1073410731
self.n_embd_inp,
1073510732
)
10736-
if self.embedding_cache is not None:
10737-
self.embedding_cache.save(key, embeddings)
10738-
return embeddings
10733+
10734+
def _load_cached_media_chunk(self, media_chunk: "MTMDProcessor.MediaChunk") -> bool:
10735+
if self.embedding_cache is None:
10736+
return False
10737+
cached = self.embedding_cache.load(media_chunk.key)
10738+
if cached is None or cached.embeddings.shape != (
10739+
media_chunk.n_tokens,
10740+
self.n_embd_inp,
10741+
):
10742+
return False
10743+
media_chunk.embeddings = cached.embeddings
10744+
return True
10745+
10746+
def _save_media_chunk(self, media_chunk: "MTMDProcessor.MediaChunk") -> None:
10747+
if self.embedding_cache is None or media_chunk.embeddings is None:
10748+
return
10749+
self.embedding_cache.save(media_chunk.key, media_chunk.embeddings)
10750+
10751+
def _encode_media_batch(
10752+
self,
10753+
media_chunks: Sequence["MTMDProcessor.MediaChunk"],
10754+
start_index: int,
10755+
) -> int:
10756+
batch = mtmd_cpp.mtmd_batch_init(self.ctx)
10757+
if batch is None:
10758+
raise CompletionRequestValidationError("failed to create MTMD media batch")
10759+
try:
10760+
first = media_chunks[start_index]
10761+
result = int(mtmd_cpp.mtmd_batch_add_chunk(batch, first.chunk))
10762+
if result != 0:
10763+
raise CompletionRequestValidationError(
10764+
f"failed to add {first.kind} chunk to MTMD batch: error code {result}"
10765+
)
10766+
group = [first]
10767+
next_index = start_index + 1
10768+
while next_index < len(media_chunks):
10769+
candidate = media_chunks[next_index]
10770+
result = int(mtmd_cpp.mtmd_batch_add_chunk(batch, candidate.chunk))
10771+
if result == 0:
10772+
group.append(candidate)
10773+
next_index += 1
10774+
continue
10775+
if result in {2, 3}:
10776+
break
10777+
raise CompletionRequestValidationError(
10778+
f"failed to add {candidate.kind} chunk to MTMD batch: error code {result}"
10779+
)
10780+
result = int(mtmd_cpp.mtmd_batch_encode(batch))
10781+
if result != 0:
10782+
raise CompletionRequestValidationError(
10783+
f"failed to encode MTMD media batch: error code {result}"
10784+
)
10785+
for media_chunk in group:
10786+
output = mtmd_cpp.mtmd_batch_get_output_embd(batch, media_chunk.chunk)
10787+
if output is None:
10788+
raise CompletionRequestValidationError(
10789+
f"MTMD {media_chunk.kind} encoder returned no embeddings"
10790+
)
10791+
media_chunk.embeddings = self._embeddings_from_pointer(
10792+
output,
10793+
media_chunk.n_tokens,
10794+
)
10795+
self._save_media_chunk(media_chunk)
10796+
return len(group)
10797+
finally:
10798+
mtmd_cpp.mtmd_batch_free(batch)
10799+
10800+
def _encode_media_chunks(
10801+
self,
10802+
media_chunks: Sequence["MTMDProcessor.MediaChunk"],
10803+
) -> None:
10804+
uncached = [
10805+
media_chunk
10806+
for media_chunk in media_chunks
10807+
if not self._load_cached_media_chunk(media_chunk)
10808+
]
10809+
index = 0
10810+
while index < len(uncached):
10811+
index += self._encode_media_batch(uncached, index)
1073910812

1074010813
def _positions_for_chunk(self, chunk: Any, start_pos: int) -> np.ndarray:
1074110814
n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
@@ -10858,12 +10931,8 @@ def _build_prompt_plan_locked(
1085810931
raise CompletionRequestValidationError(
1085910932
f"failed to tokenize MTMD prompt: error code {result}"
1086010933
)
10861-
segments: List[PromptSegment] = []
10862-
identity_tokens: List[int] = []
10863-
text_tokens: List[int] = []
10864-
text_token_index_by_pos: Dict[int, int] = {}
10865-
identity_pos = 0
10866-
decode_pos = 0
10934+
parsed_chunks: List[MTMDProcessor.ParsedChunk] = []
10935+
media_chunks: List[MTMDProcessor.MediaChunk] = []
1086710936
video_index = 0
1086810937
used_media_keys = set()
1086910938
n_chunks = int(mtmd_cpp.mtmd_input_chunks_size(chunks))
@@ -10884,24 +10953,9 @@ def _build_prompt_plan_locked(
1088410953
else []
1088510954
)
1088610955
if tokens:
10887-
start_pos = identity_pos
10888-
segments.append(
10889-
PromptSegment(
10890-
kind="text",
10891-
start_pos=start_pos,
10892-
n_pos=len(tokens),
10893-
identity_tokens=list(tokens),
10894-
decode_start_pos=decode_pos,
10895-
decode_n_pos=len(tokens),
10896-
text_tokens=list(tokens),
10897-
)
10956+
parsed_chunks.append(
10957+
MTMDProcessor.ParsedChunk(text_tokens=tokens)
1089810958
)
10899-
for offset, token in enumerate(tokens):
10900-
text_token_index_by_pos[start_pos + offset] = len(text_tokens)
10901-
text_tokens.append(token)
10902-
identity_tokens.extend(tokens)
10903-
identity_pos += len(tokens)
10904-
decode_pos += len(tokens)
1090510959
continue
1090610960
if chunk_type == mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE:
1090710961
chunk_kind: Literal["image", "audio"] = "image"
@@ -10951,37 +11005,84 @@ def _build_prompt_plan_locked(
1095111005
decode_n_pos = int(mtmd_cpp.mtmd_input_chunk_get_n_pos(chunk))
1095211006
if decode_n_pos <= 0:
1095311007
raise CompletionRequestValidationError("MTMD media chunk has no decoder positions")
10954-
embeddings = self._encode_media_chunk(kind=kind, key=key, chunk=chunk)
10955-
n_tokens = int(embeddings.shape[0])
11008+
n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
1095611009
if n_tokens <= 0:
10957-
raise CompletionRequestValidationError("MTMD media chunk has no embeddings")
11010+
raise CompletionRequestValidationError("MTMD media chunk has no embedding tokens")
1095811011
non_causal = bool(mtmd_cpp.mtmd_decode_use_non_causal(self.ctx, chunk))
10959-
segment_identity = self._media_identity_tokens(kind, key, n_tokens)
10960-
positions = self._positions_for_chunk(chunk, decode_pos)
10961-
segment = PromptSegment(
11012+
media_chunk = MTMDProcessor.MediaChunk(
1096211013
kind=kind,
10963-
start_pos=identity_pos,
10964-
n_pos=n_tokens,
10965-
identity_tokens=segment_identity,
10966-
decode_start_pos=decode_pos,
11014+
key=key,
11015+
chunk=chunk,
11016+
n_tokens=n_tokens,
1096711017
decode_n_pos=decode_n_pos,
10968-
media=PromptSegment.Media(
10969-
embeddings=embeddings,
10970-
positions=positions,
10971-
non_causal=non_causal,
10972-
),
11018+
non_causal=non_causal,
1097311019
)
10974-
if non_causal and embeddings.shape[0] > min(self.n_batch, self.n_ubatch):
11020+
parsed_chunks.append(MTMDProcessor.ParsedChunk(media=media_chunk))
11021+
media_chunks.append(media_chunk)
11022+
if used_media_keys != {media.key for media in loaded_media}:
11023+
raise CompletionRequestValidationError("not all media inputs were consumed by MTMD")
11024+
self._encode_media_chunks(media_chunks)
11025+
segments: List[PromptSegment] = []
11026+
identity_tokens: List[int] = []
11027+
text_tokens: List[int] = []
11028+
text_token_index_by_pos: Dict[int, int] = {}
11029+
identity_pos = 0
11030+
decode_pos = 0
11031+
for parsed_chunk in parsed_chunks:
11032+
if parsed_chunk.text_tokens is not None:
11033+
tokens = parsed_chunk.text_tokens
11034+
start_pos = identity_pos
11035+
segments.append(
11036+
PromptSegment(
11037+
kind="text",
11038+
start_pos=start_pos,
11039+
n_pos=len(tokens),
11040+
identity_tokens=list(tokens),
11041+
decode_start_pos=decode_pos,
11042+
decode_n_pos=len(tokens),
11043+
text_tokens=list(tokens),
11044+
)
11045+
)
11046+
for offset, token in enumerate(tokens):
11047+
text_token_index_by_pos[start_pos + offset] = len(text_tokens)
11048+
text_tokens.append(token)
11049+
identity_tokens.extend(tokens)
11050+
identity_pos += len(tokens)
11051+
decode_pos += len(tokens)
11052+
continue
11053+
media_chunk = parsed_chunk.media
11054+
if media_chunk is None or media_chunk.embeddings is None:
11055+
raise CompletionRequestValidationError("MTMD media chunk has no embeddings")
11056+
embeddings = media_chunk.embeddings
11057+
if media_chunk.non_causal and embeddings.shape[0] > min(self.n_batch, self.n_ubatch):
1097511058
raise CompletionRequestValidationError(
10976-
f"non-causal {kind} embedding chunk exceeds model batch limits; "
11059+
f"non-causal {media_chunk.kind} embedding chunk exceeds model batch limits; "
1097711060
"increase n_batch and n_ubatch"
1097811061
)
10979-
segments.append(segment)
11062+
segment_identity = self._media_identity_tokens(
11063+
media_chunk.kind,
11064+
media_chunk.key,
11065+
media_chunk.n_tokens,
11066+
)
11067+
positions = self._positions_for_chunk(media_chunk.chunk, decode_pos)
11068+
segments.append(
11069+
PromptSegment(
11070+
kind=media_chunk.kind,
11071+
start_pos=identity_pos,
11072+
n_pos=media_chunk.n_tokens,
11073+
identity_tokens=segment_identity,
11074+
decode_start_pos=decode_pos,
11075+
decode_n_pos=media_chunk.decode_n_pos,
11076+
media=PromptSegment.Media(
11077+
embeddings=embeddings,
11078+
positions=positions,
11079+
non_causal=media_chunk.non_causal,
11080+
),
11081+
)
11082+
)
1098011083
identity_tokens.extend(segment_identity)
10981-
identity_pos += n_tokens
10982-
decode_pos += decode_n_pos
10983-
if used_media_keys != {media.key for media in loaded_media}:
10984-
raise CompletionRequestValidationError("not all media inputs were consumed by MTMD")
11084+
identity_pos += media_chunk.n_tokens
11085+
decode_pos += media_chunk.decode_n_pos
1098511086
return PromptPlan(
1098611087
text=prompt,
1098711088
generation_prompt=generation_prompt,
@@ -16211,6 +16312,7 @@ def main() -> None:
1621116312
n_ubatch=model.n_ubatch,
1621216313
n_threads_batch=model.n_threads_batch,
1621316314
mmproj_path=mmproj_path,
16315+
batch_max_tokens=config.model.mtmd.batch_max_tokens,
1621416316
embedding_cache=embedding_cache,
1621516317
allowed_media_domains=config.model.mtmd.allowed_media_domains,
1621616318
allowed_local_media_path=config.model.mtmd.allowed_local_media_path,

0 commit comments

Comments
 (0)