From b815b3731d7d4795a9abcfcfc9c862f1c82ee4fe Mon Sep 17 00:00:00 2001 From: OceanLi <122793010+ohdearquant@users.noreply.github.com> Date: Mon, 25 May 2026 12:32:19 -0400 Subject: [PATCH] fix(inference): inject Qwen BPE EOS and align GPT-4 pre-tokenizer regex Parse post_processor for EOS injection (id=151643 from TemplateProcessing) and detect regex-based Split pre-tokenizer in tokenizer.json. Implements a hand-coded GPT-4 regex pattern that attaches leading punctuation to following letter runs (e.g. ".com", "/path"), matching HF behavior. Qwen3-Embedding-0.6B parity moves from 0/10 to 10/10. Co-Authored-By: Claude Opus 4.6 --- crates/inference/src/tokenizer/bpe.rs | 229 +++++++++++++++++++++++++- 1 file changed, 227 insertions(+), 2 deletions(-) diff --git a/crates/inference/src/tokenizer/bpe.rs b/crates/inference/src/tokenizer/bpe.rs index 96472f72..2d8d8e49 100644 --- a/crates/inference/src/tokenizer/bpe.rs +++ b/crates/inference/src/tokenizer/bpe.rs @@ -8,7 +8,7 @@ use crate::error::InferenceError; use crate::tokenizer::common::{ JsonValue, ThreadSafeLruCache, TokenizedInput, Tokenizer, invert_vocab, json_object_to_vocab, json_path, known_special_id, pad_ids, parse_added_tokens, parse_json, - push_eos_preserving_limit, vocab_txt_to_map, + parse_post_processor_flags, push_eos_preserving_limit, vocab_txt_to_map, }; use std::cmp::Ordering; use std::collections::{BinaryHeap, HashMap}; @@ -20,6 +20,12 @@ use tracing::warn; const DEFAULT_BPE_CACHE_CAPACITY: usize = 8_192; const DEFAULT_BPE_MAX_SEQ_LEN: usize = 4_096; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PreTokenizeMode { + ByteLevel, + Gpt4Regex, +} + /// **Unstable**: byte-level BPE tokenizer for Qwen-family models; API evolving. #[derive(Debug, Clone)] pub struct BpeTokenizer { @@ -40,6 +46,7 @@ struct BpeInner { eos_id: Option, add_bos: bool, add_eos: bool, + pre_tokenize_mode: PreTokenizeMode, max_seq_len: usize, cache: ThreadSafeLruCache>, } @@ -160,6 +167,18 @@ impl BpeTokenizer { tokenizer = tokenizer.with_unk_token(unk_token); } + let pp = parse_post_processor_flags(&root); + if pp.add_eos { + tokenizer = tokenizer.with_add_eos(); + if let Some(eos_id) = pp.eos_id { + tokenizer = tokenizer.with_eos_id(eos_id); + } + } + + if detect_gpt4_regex_pretokenizer(&root) { + tokenizer = tokenizer.with_pre_tokenize_mode(PreTokenizeMode::Gpt4Regex); + } + Ok(tokenizer) } @@ -230,6 +249,7 @@ impl BpeTokenizer { eos_id, add_bos: false, add_eos: false, + pre_tokenize_mode: PreTokenizeMode::ByteLevel, max_seq_len, cache: ThreadSafeLruCache::new(cache_capacity), }; @@ -254,6 +274,7 @@ impl BpeTokenizer { eos_id: self.inner.eos_id, add_bos: self.inner.add_bos, add_eos: self.inner.add_eos, + pre_tokenize_mode: self.inner.pre_tokenize_mode, max_seq_len, cache: self.inner.cache.clone(), }; @@ -278,6 +299,7 @@ impl BpeTokenizer { eos_id: self.inner.eos_id, add_bos: self.inner.add_bos, add_eos: self.inner.add_eos, + pre_tokenize_mode: self.inner.pre_tokenize_mode, max_seq_len: self.inner.max_seq_len, cache: self.inner.cache.clone(), }; @@ -306,6 +328,53 @@ impl BpeTokenizer { eos_id: self.inner.eos_id, add_bos: self.inner.add_bos, add_eos: true, + pre_tokenize_mode: self.inner.pre_tokenize_mode, + max_seq_len: self.inner.max_seq_len, + cache: self.inner.cache.clone(), + }; + Self { + inner: Arc::new(inner), + } + } + + fn with_eos_id(self, eos_id: u32) -> Self { + let inner = BpeInner { + vocab: self.inner.vocab.clone(), + id_to_token: self.inner.id_to_token.clone(), + merges: self.inner.merges.clone(), + byte_encoder: self.inner.byte_encoder.clone(), + special_tokens: self.inner.special_tokens.clone(), + special_tokens_sorted: self.inner.special_tokens_sorted.clone(), + pad_id: self.inner.pad_id, + unk_id: self.inner.unk_id, + bos_id: self.inner.bos_id, + eos_id: Some(eos_id), + add_bos: self.inner.add_bos, + add_eos: self.inner.add_eos, + pre_tokenize_mode: self.inner.pre_tokenize_mode, + max_seq_len: self.inner.max_seq_len, + cache: self.inner.cache.clone(), + }; + Self { + inner: Arc::new(inner), + } + } + + fn with_pre_tokenize_mode(self, mode: PreTokenizeMode) -> Self { + let inner = BpeInner { + vocab: self.inner.vocab.clone(), + id_to_token: self.inner.id_to_token.clone(), + merges: self.inner.merges.clone(), + byte_encoder: self.inner.byte_encoder.clone(), + special_tokens: self.inner.special_tokens.clone(), + special_tokens_sorted: self.inner.special_tokens_sorted.clone(), + pad_id: self.inner.pad_id, + unk_id: self.inner.unk_id, + bos_id: self.inner.bos_id, + eos_id: self.inner.eos_id, + add_bos: self.inner.add_bos, + add_eos: self.inner.add_eos, + pre_tokenize_mode: mode, max_seq_len: self.inner.max_seq_len, cache: self.inner.cache.clone(), }; @@ -427,7 +496,11 @@ impl BpeTokenizer { } fn tokenize_regular_segment_into(&self, text: &str, scratch: &mut TokenizeScratch) { - for piece in byte_level_pretokenize(text) { + let pieces = match self.inner.pre_tokenize_mode { + PreTokenizeMode::ByteLevel => byte_level_pretokenize(text), + PreTokenizeMode::Gpt4Regex => gpt4_regex_pretokenize(text), + }; + for piece in pieces { if let Some(cached) = self.inner.cache.get(&piece) { scratch.ids.extend(cached.iter().copied()); continue; @@ -858,6 +931,158 @@ pub fn byte_decode_token(token_str: &str) -> String { String::from_utf8_lossy(&bytes).to_string() } +fn detect_gpt4_regex_pretokenizer(root: &JsonValue) -> bool { + let Some(pt) = root.get("pre_tokenizer") else { + return false; + }; + has_regex_split(pt) +} + +fn has_regex_split(pt: &JsonValue) -> bool { + let pt_type = pt.get("type").and_then(JsonValue::as_str).unwrap_or(""); + match pt_type { + "Split" => pt.get("pattern").and_then(|p| p.get("Regex")).is_some(), + "Sequence" => pt + .get("pretokenizers") + .and_then(JsonValue::as_array) + .is_some_and(|arr| arr.iter().any(has_regex_split)), + _ => false, + } +} + +fn gpt4_regex_pretokenize(text: &str) -> Vec { + let chars: Vec = text.chars().collect(); + let mut pieces = Vec::new(); + let mut pos = 0; + + while pos < chars.len() { + if let Some(end) = try_contraction(&chars, pos) { + pieces.push(chars[pos..end].iter().collect()); + pos = end; + } else if let Some(end) = try_prefix_letters(&chars, pos) { + pieces.push(chars[pos..end].iter().collect()); + pos = end; + } else if chars[pos].is_numeric() { + pieces.push(chars[pos].to_string()); + pos += 1; + } else if let Some(end) = try_punctuation_run(&chars, pos) { + pieces.push(chars[pos..end].iter().collect()); + pos = end; + } else if let Some(end) = try_newline_run(&chars, pos) { + pieces.push(chars[pos..end].iter().collect()); + pos = end; + } else if let Some(end) = try_trailing_ws(&chars, pos) { + pieces.push(chars[pos..end].iter().collect()); + pos = end; + } else if chars[pos].is_whitespace() { + let start = pos; + while pos < chars.len() && chars[pos].is_whitespace() { + pos += 1; + } + pieces.push(chars[start..pos].iter().collect()); + } else { + pieces.push(chars[pos].to_string()); + pos += 1; + } + } + + pieces +} + +fn eq_ci(a: char, lower: char) -> bool { + a.to_ascii_lowercase() == lower +} + +/// `(?i:'s|'t|'re|'ve|'m|'ll|'d)` +fn try_contraction(chars: &[char], pos: usize) -> Option { + if chars.get(pos).copied() != Some('\'') { + return None; + } + let rest = &chars[pos + 1..]; + if rest.len() >= 2 && eq_ci(rest[0], 'l') && eq_ci(rest[1], 'l') { + return Some(pos + 3); + } + if rest.len() >= 2 && eq_ci(rest[0], 'r') && eq_ci(rest[1], 'e') { + return Some(pos + 3); + } + if rest.len() >= 2 && eq_ci(rest[0], 'v') && eq_ci(rest[1], 'e') { + return Some(pos + 3); + } + if !rest.is_empty() { + let c = rest[0].to_ascii_lowercase(); + if matches!(c, 's' | 't' | 'm' | 'd') { + return Some(pos + 2); + } + } + None +} + +/// `[^\r\n\p{L}\p{N}]?\p{L}+` +fn try_prefix_letters(chars: &[char], pos: usize) -> Option { + let mut i = pos; + if i < chars.len() + && !chars[i].is_alphabetic() + && !chars[i].is_numeric() + && chars[i] != '\r' + && chars[i] != '\n' + { + i += 1; + } + let start = i; + while i < chars.len() && chars[i].is_alphabetic() { + i += 1; + } + if i > start { Some(i) } else { None } +} + +/// ` ?[^\s\p{L}\p{N}]+[\r\n]*` +fn try_punctuation_run(chars: &[char], pos: usize) -> Option { + let mut i = pos; + if i < chars.len() && chars[i] == ' ' { + i += 1; + } + let start = i; + while i < chars.len() + && !chars[i].is_whitespace() + && !chars[i].is_alphabetic() + && !chars[i].is_numeric() + { + i += 1; + } + if i == start { + return None; + } + while i < chars.len() && (chars[i] == '\r' || chars[i] == '\n') { + i += 1; + } + Some(i) +} + +/// `\s*[\r\n]+` +fn try_newline_run(chars: &[char], pos: usize) -> Option { + let mut i = pos; + while i < chars.len() && chars[i].is_whitespace() && chars[i] != '\r' && chars[i] != '\n' { + i += 1; + } + let nl_start = i; + while i < chars.len() && (chars[i] == '\r' || chars[i] == '\n') { + i += 1; + } + if i > nl_start { Some(i) } else { None } +} + +/// `\s+(?!\S)` — whitespace not followed by non-whitespace (i.e. trailing) +fn try_trailing_ws(chars: &[char], pos: usize) -> Option { + if !chars.get(pos).is_some_and(|c| c.is_whitespace()) { + return None; + } + let mut i = pos; + while i < chars.len() && chars[i].is_whitespace() { + i += 1; + } + if i == chars.len() { Some(i) } else { None } +} + #[cfg(test)] mod tests { use super::*;