@@ -3223,6 +3223,7 @@ class MTMDOptions(BaseModel):
32233223 embedding_cache: Optional["ConfigFile.MTMDEmbeddingCacheOptions"] = None
32243224 allowed_media_domains: Optional[List[str]] = None
32253225 allowed_local_media_path: Optional[str] = None
3226+ batch_max_tokens: int = Field(default=1024, ge=1)
32263227 image_max_bytes: int = Field(default=20 * 1024 * 1024, ge=1)
32273228 audio_max_bytes: int = Field(default=100 * 1024 * 1024, ge=1)
32283229 video_max_bytes: int = Field(default=512 * 1024 * 1024, ge=1)
@@ -10410,6 +10411,21 @@ class MTMDLoadedMedia:
1041010411
1041110412
1041210413class MTMDProcessor:
10414+ @dataclass
10415+ class MediaChunk:
10416+ kind: Literal["image", "audio", "video"]
10417+ key: str
10418+ chunk: Any
10419+ n_tokens: int
10420+ decode_n_pos: int
10421+ non_causal: bool
10422+ embeddings: Optional[np.ndarray] = None
10423+
10424+ @dataclass
10425+ class ParsedChunk:
10426+ text_tokens: Optional[List[int]] = None
10427+ media: Optional["MTMDProcessor.MediaChunk"] = None
10428+
1041310429 def __init__(
1041410430 self,
1041510431 *,
@@ -10422,6 +10438,7 @@ def __init__(
1042210438 n_ubatch: int,
1042310439 n_threads_batch: int,
1042410440 mmproj_path: str,
10441+ batch_max_tokens: int,
1042510442 embedding_cache: Optional[MTMDEmbeddingCache],
1042610443 allowed_media_domains: Optional[List[str]],
1042710444 allowed_local_media_path: Optional[str],
@@ -10437,6 +10454,7 @@ def __init__(
1043710454 self.n_ubatch = n_ubatch
1043810455 self.mmproj_path = mmproj_path
1043910456 self.embedding_cache = embedding_cache
10457+ self.batch_max_tokens = batch_max_tokens
1044010458 self.model_fingerprint = MTMDEmbeddingCache.fingerprint_file(model_path)
1044110459 self.mmproj_fingerprint = MTMDEmbeddingCache.fingerprint_file(mmproj_path)
1044210460 self.allowed_media_domains = (
@@ -10456,6 +10474,7 @@ def __init__(
1045610474 self.lock = threading.Lock()
1045710475 params = mtmd_cpp.mtmd_context_params_default()
1045810476 params.n_threads = max(1, n_threads_batch)
10477+ params.batch_max_tokens = batch_max_tokens
1045910478 self.ctx = mtmd_cpp.mtmd_init_from_file(
1046010479 mmproj_path.encode("utf-8"),
1046110480 llama_model,
@@ -10705,37 +10724,91 @@ def _media_identity_tokens(
1070510724 tokens.append(-1 - (int.from_bytes(digest[:4], "little") & 0x3FFFFFFF))
1070610725 return tokens
1070710726
10708- def _encode_media_chunk(
10709- self,
10710- *,
10711- kind: Literal["image", "audio", "video"],
10712- key: str,
10713- chunk: Any,
10714- ) -> np.ndarray:
10715- n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
10716- if self.embedding_cache is not None:
10717- cached = self.embedding_cache.load(key)
10718- if (
10719- cached is not None
10720- and cached.embeddings.shape == (n_tokens, self.n_embd_inp)
10721- ):
10722- return cached.embeddings
10723- result = int(mtmd_cpp.mtmd_encode_chunk(self.ctx, chunk))
10724- if result != 0:
10725- raise CompletionRequestValidationError(
10726- f"failed to encode {kind} chunk: error code {result}"
10727- )
10728- output = mtmd_cpp.mtmd_get_output_embd(self.ctx)
10729- if output is None:
10730- raise CompletionRequestValidationError(f"MTMD {kind} encoder returned no embeddings")
10727+ def _embeddings_from_pointer(self, output: Any, n_tokens: int) -> np.ndarray:
1073110728 flat = np.ctypeslib.as_array(output, shape=(n_tokens * self.n_embd_inp,))
10732- embeddings = np.array(flat, dtype=np.float32, copy=True).reshape(
10729+ return np.array(flat, dtype=np.float32, copy=True).reshape(
1073310730 n_tokens,
1073410731 self.n_embd_inp,
1073510732 )
10736- if self.embedding_cache is not None:
10737- self.embedding_cache.save(key, embeddings)
10738- return embeddings
10733+
10734+ def _load_cached_media_chunk(self, media_chunk: "MTMDProcessor.MediaChunk") -> bool:
10735+ if self.embedding_cache is None:
10736+ return False
10737+ cached = self.embedding_cache.load(media_chunk.key)
10738+ if cached is None or cached.embeddings.shape != (
10739+ media_chunk.n_tokens,
10740+ self.n_embd_inp,
10741+ ):
10742+ return False
10743+ media_chunk.embeddings = cached.embeddings
10744+ return True
10745+
10746+ def _save_media_chunk(self, media_chunk: "MTMDProcessor.MediaChunk") -> None:
10747+ if self.embedding_cache is None or media_chunk.embeddings is None:
10748+ return
10749+ self.embedding_cache.save(media_chunk.key, media_chunk.embeddings)
10750+
10751+ def _encode_media_batch(
10752+ self,
10753+ media_chunks: Sequence["MTMDProcessor.MediaChunk"],
10754+ start_index: int,
10755+ ) -> int:
10756+ batch = mtmd_cpp.mtmd_batch_init(self.ctx)
10757+ if batch is None:
10758+ raise CompletionRequestValidationError("failed to create MTMD media batch")
10759+ try:
10760+ first = media_chunks[start_index]
10761+ result = int(mtmd_cpp.mtmd_batch_add_chunk(batch, first.chunk))
10762+ if result != 0:
10763+ raise CompletionRequestValidationError(
10764+ f"failed to add {first.kind} chunk to MTMD batch: error code {result}"
10765+ )
10766+ group = [first]
10767+ next_index = start_index + 1
10768+ while next_index < len(media_chunks):
10769+ candidate = media_chunks[next_index]
10770+ result = int(mtmd_cpp.mtmd_batch_add_chunk(batch, candidate.chunk))
10771+ if result == 0:
10772+ group.append(candidate)
10773+ next_index += 1
10774+ continue
10775+ if result in {2, 3}:
10776+ break
10777+ raise CompletionRequestValidationError(
10778+ f"failed to add {candidate.kind} chunk to MTMD batch: error code {result}"
10779+ )
10780+ result = int(mtmd_cpp.mtmd_batch_encode(batch))
10781+ if result != 0:
10782+ raise CompletionRequestValidationError(
10783+ f"failed to encode MTMD media batch: error code {result}"
10784+ )
10785+ for media_chunk in group:
10786+ output = mtmd_cpp.mtmd_batch_get_output_embd(batch, media_chunk.chunk)
10787+ if output is None:
10788+ raise CompletionRequestValidationError(
10789+ f"MTMD {media_chunk.kind} encoder returned no embeddings"
10790+ )
10791+ media_chunk.embeddings = self._embeddings_from_pointer(
10792+ output,
10793+ media_chunk.n_tokens,
10794+ )
10795+ self._save_media_chunk(media_chunk)
10796+ return len(group)
10797+ finally:
10798+ mtmd_cpp.mtmd_batch_free(batch)
10799+
10800+ def _encode_media_chunks(
10801+ self,
10802+ media_chunks: Sequence["MTMDProcessor.MediaChunk"],
10803+ ) -> None:
10804+ uncached = [
10805+ media_chunk
10806+ for media_chunk in media_chunks
10807+ if not self._load_cached_media_chunk(media_chunk)
10808+ ]
10809+ index = 0
10810+ while index < len(uncached):
10811+ index += self._encode_media_batch(uncached, index)
1073910812
1074010813 def _positions_for_chunk(self, chunk: Any, start_pos: int) -> np.ndarray:
1074110814 n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
@@ -10858,12 +10931,8 @@ def _build_prompt_plan_locked(
1085810931 raise CompletionRequestValidationError(
1085910932 f"failed to tokenize MTMD prompt: error code {result}"
1086010933 )
10861- segments: List[PromptSegment] = []
10862- identity_tokens: List[int] = []
10863- text_tokens: List[int] = []
10864- text_token_index_by_pos: Dict[int, int] = {}
10865- identity_pos = 0
10866- decode_pos = 0
10934+ parsed_chunks: List[MTMDProcessor.ParsedChunk] = []
10935+ media_chunks: List[MTMDProcessor.MediaChunk] = []
1086710936 video_index = 0
1086810937 used_media_keys = set()
1086910938 n_chunks = int(mtmd_cpp.mtmd_input_chunks_size(chunks))
@@ -10884,24 +10953,9 @@ def _build_prompt_plan_locked(
1088410953 else []
1088510954 )
1088610955 if tokens:
10887- start_pos = identity_pos
10888- segments.append(
10889- PromptSegment(
10890- kind="text",
10891- start_pos=start_pos,
10892- n_pos=len(tokens),
10893- identity_tokens=list(tokens),
10894- decode_start_pos=decode_pos,
10895- decode_n_pos=len(tokens),
10896- text_tokens=list(tokens),
10897- )
10956+ parsed_chunks.append(
10957+ MTMDProcessor.ParsedChunk(text_tokens=tokens)
1089810958 )
10899- for offset, token in enumerate(tokens):
10900- text_token_index_by_pos[start_pos + offset] = len(text_tokens)
10901- text_tokens.append(token)
10902- identity_tokens.extend(tokens)
10903- identity_pos += len(tokens)
10904- decode_pos += len(tokens)
1090510959 continue
1090610960 if chunk_type == mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE:
1090710961 chunk_kind: Literal["image", "audio"] = "image"
@@ -10951,37 +11005,84 @@ def _build_prompt_plan_locked(
1095111005 decode_n_pos = int(mtmd_cpp.mtmd_input_chunk_get_n_pos(chunk))
1095211006 if decode_n_pos <= 0:
1095311007 raise CompletionRequestValidationError("MTMD media chunk has no decoder positions")
10954- embeddings = self._encode_media_chunk(kind=kind, key=key, chunk=chunk)
10955- n_tokens = int(embeddings.shape[0])
11008+ n_tokens = int(mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk))
1095611009 if n_tokens <= 0:
10957- raise CompletionRequestValidationError("MTMD media chunk has no embeddings")
11010+ raise CompletionRequestValidationError("MTMD media chunk has no embedding tokens")
1095811011 non_causal = bool(mtmd_cpp.mtmd_decode_use_non_causal(self.ctx, chunk))
10959- segment_identity = self._media_identity_tokens(kind, key, n_tokens)
10960- positions = self._positions_for_chunk(chunk, decode_pos)
10961- segment = PromptSegment(
11012+ media_chunk = MTMDProcessor.MediaChunk(
1096211013 kind=kind,
10963- start_pos=identity_pos,
10964- n_pos=n_tokens,
10965- identity_tokens=segment_identity,
10966- decode_start_pos=decode_pos,
11014+ key=key,
11015+ chunk=chunk,
11016+ n_tokens=n_tokens,
1096711017 decode_n_pos=decode_n_pos,
10968- media=PromptSegment.Media(
10969- embeddings=embeddings,
10970- positions=positions,
10971- non_causal=non_causal,
10972- ),
11018+ non_causal=non_causal,
1097311019 )
10974- if non_causal and embeddings.shape[0] > min(self.n_batch, self.n_ubatch):
11020+ parsed_chunks.append(MTMDProcessor.ParsedChunk(media=media_chunk))
11021+ media_chunks.append(media_chunk)
11022+ if used_media_keys != {media.key for media in loaded_media}:
11023+ raise CompletionRequestValidationError("not all media inputs were consumed by MTMD")
11024+ self._encode_media_chunks(media_chunks)
11025+ segments: List[PromptSegment] = []
11026+ identity_tokens: List[int] = []
11027+ text_tokens: List[int] = []
11028+ text_token_index_by_pos: Dict[int, int] = {}
11029+ identity_pos = 0
11030+ decode_pos = 0
11031+ for parsed_chunk in parsed_chunks:
11032+ if parsed_chunk.text_tokens is not None:
11033+ tokens = parsed_chunk.text_tokens
11034+ start_pos = identity_pos
11035+ segments.append(
11036+ PromptSegment(
11037+ kind="text",
11038+ start_pos=start_pos,
11039+ n_pos=len(tokens),
11040+ identity_tokens=list(tokens),
11041+ decode_start_pos=decode_pos,
11042+ decode_n_pos=len(tokens),
11043+ text_tokens=list(tokens),
11044+ )
11045+ )
11046+ for offset, token in enumerate(tokens):
11047+ text_token_index_by_pos[start_pos + offset] = len(text_tokens)
11048+ text_tokens.append(token)
11049+ identity_tokens.extend(tokens)
11050+ identity_pos += len(tokens)
11051+ decode_pos += len(tokens)
11052+ continue
11053+ media_chunk = parsed_chunk.media
11054+ if media_chunk is None or media_chunk.embeddings is None:
11055+ raise CompletionRequestValidationError("MTMD media chunk has no embeddings")
11056+ embeddings = media_chunk.embeddings
11057+ if media_chunk.non_causal and embeddings.shape[0] > min(self.n_batch, self.n_ubatch):
1097511058 raise CompletionRequestValidationError(
10976- f"non-causal {kind} embedding chunk exceeds model batch limits; "
11059+ f"non-causal {media_chunk.kind} embedding chunk exceeds model batch limits; "
1097711060 "increase n_batch and n_ubatch"
1097811061 )
10979- segments.append(segment)
11062+ segment_identity = self._media_identity_tokens(
11063+ media_chunk.kind,
11064+ media_chunk.key,
11065+ media_chunk.n_tokens,
11066+ )
11067+ positions = self._positions_for_chunk(media_chunk.chunk, decode_pos)
11068+ segments.append(
11069+ PromptSegment(
11070+ kind=media_chunk.kind,
11071+ start_pos=identity_pos,
11072+ n_pos=media_chunk.n_tokens,
11073+ identity_tokens=segment_identity,
11074+ decode_start_pos=decode_pos,
11075+ decode_n_pos=media_chunk.decode_n_pos,
11076+ media=PromptSegment.Media(
11077+ embeddings=embeddings,
11078+ positions=positions,
11079+ non_causal=media_chunk.non_causal,
11080+ ),
11081+ )
11082+ )
1098011083 identity_tokens.extend(segment_identity)
10981- identity_pos += n_tokens
10982- decode_pos += decode_n_pos
10983- if used_media_keys != {media.key for media in loaded_media}:
10984- raise CompletionRequestValidationError("not all media inputs were consumed by MTMD")
11084+ identity_pos += media_chunk.n_tokens
11085+ decode_pos += media_chunk.decode_n_pos
1098511086 return PromptPlan(
1098611087 text=prompt,
1098711088 generation_prompt=generation_prompt,
@@ -16211,6 +16312,7 @@ def main() -> None:
1621116312 n_ubatch=model.n_ubatch,
1621216313 n_threads_batch=model.n_threads_batch,
1621316314 mmproj_path=mmproj_path,
16315+ batch_max_tokens=config.model.mtmd.batch_max_tokens,
1621416316 embedding_cache=embedding_cache,
1621516317 allowed_media_domains=config.model.mtmd.allowed_media_domains,
1621616318 allowed_local_media_path=config.model.mtmd.allowed_local_media_path,
0 commit comments