Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 227 additions & 2 deletions crates/inference/src/tokenizer/bpe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::error::InferenceError;
use crate::tokenizer::common::{
JsonValue, ThreadSafeLruCache, TokenizedInput, Tokenizer, invert_vocab, json_object_to_vocab,
json_path, known_special_id, pad_ids, parse_added_tokens, parse_json,
push_eos_preserving_limit, vocab_txt_to_map,
parse_post_processor_flags, push_eos_preserving_limit, vocab_txt_to_map,
};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
Expand All @@ -20,6 +20,12 @@ use tracing::warn;
const DEFAULT_BPE_CACHE_CAPACITY: usize = 8_192;
const DEFAULT_BPE_MAX_SEQ_LEN: usize = 4_096;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PreTokenizeMode {
ByteLevel,
Gpt4Regex,
}

/// **Unstable**: byte-level BPE tokenizer for Qwen-family models; API evolving.
#[derive(Debug, Clone)]
pub struct BpeTokenizer {
Expand All @@ -40,6 +46,7 @@ struct BpeInner {
eos_id: Option<u32>,
add_bos: bool,
add_eos: bool,
pre_tokenize_mode: PreTokenizeMode,
max_seq_len: usize,
cache: ThreadSafeLruCache<String, Vec<u32>>,
}
Expand Down Expand Up @@ -160,6 +167,18 @@ impl BpeTokenizer {
tokenizer = tokenizer.with_unk_token(unk_token);
}

let pp = parse_post_processor_flags(&root);
if pp.add_eos {
tokenizer = tokenizer.with_add_eos();
if let Some(eos_id) = pp.eos_id {
tokenizer = tokenizer.with_eos_id(eos_id);
}
}

if detect_gpt4_regex_pretokenizer(&root) {
tokenizer = tokenizer.with_pre_tokenize_mode(PreTokenizeMode::Gpt4Regex);
}

Ok(tokenizer)
}

Expand Down Expand Up @@ -230,6 +249,7 @@ impl BpeTokenizer {
eos_id,
add_bos: false,
add_eos: false,
pre_tokenize_mode: PreTokenizeMode::ByteLevel,
max_seq_len,
cache: ThreadSafeLruCache::new(cache_capacity),
};
Expand All @@ -254,6 +274,7 @@ impl BpeTokenizer {
eos_id: self.inner.eos_id,
add_bos: self.inner.add_bos,
add_eos: self.inner.add_eos,
pre_tokenize_mode: self.inner.pre_tokenize_mode,
max_seq_len,
cache: self.inner.cache.clone(),
};
Expand All @@ -278,6 +299,7 @@ impl BpeTokenizer {
eos_id: self.inner.eos_id,
add_bos: self.inner.add_bos,
add_eos: self.inner.add_eos,
pre_tokenize_mode: self.inner.pre_tokenize_mode,
max_seq_len: self.inner.max_seq_len,
cache: self.inner.cache.clone(),
};
Expand Down Expand Up @@ -306,6 +328,53 @@ impl BpeTokenizer {
eos_id: self.inner.eos_id,
add_bos: self.inner.add_bos,
add_eos: true,
pre_tokenize_mode: self.inner.pre_tokenize_mode,
max_seq_len: self.inner.max_seq_len,
cache: self.inner.cache.clone(),
};
Self {
inner: Arc::new(inner),
}
}

fn with_eos_id(self, eos_id: u32) -> Self {
let inner = BpeInner {
vocab: self.inner.vocab.clone(),
id_to_token: self.inner.id_to_token.clone(),
merges: self.inner.merges.clone(),
byte_encoder: self.inner.byte_encoder.clone(),
special_tokens: self.inner.special_tokens.clone(),
special_tokens_sorted: self.inner.special_tokens_sorted.clone(),
pad_id: self.inner.pad_id,
unk_id: self.inner.unk_id,
bos_id: self.inner.bos_id,
eos_id: Some(eos_id),
add_bos: self.inner.add_bos,
add_eos: self.inner.add_eos,
pre_tokenize_mode: self.inner.pre_tokenize_mode,
max_seq_len: self.inner.max_seq_len,
cache: self.inner.cache.clone(),
};
Self {
inner: Arc::new(inner),
}
}

fn with_pre_tokenize_mode(self, mode: PreTokenizeMode) -> Self {
let inner = BpeInner {
vocab: self.inner.vocab.clone(),
id_to_token: self.inner.id_to_token.clone(),
merges: self.inner.merges.clone(),
byte_encoder: self.inner.byte_encoder.clone(),
special_tokens: self.inner.special_tokens.clone(),
special_tokens_sorted: self.inner.special_tokens_sorted.clone(),
pad_id: self.inner.pad_id,
unk_id: self.inner.unk_id,
bos_id: self.inner.bos_id,
eos_id: self.inner.eos_id,
add_bos: self.inner.add_bos,
add_eos: self.inner.add_eos,
pre_tokenize_mode: mode,
max_seq_len: self.inner.max_seq_len,
cache: self.inner.cache.clone(),
};
Expand Down Expand Up @@ -427,7 +496,11 @@ impl BpeTokenizer {
}

fn tokenize_regular_segment_into(&self, text: &str, scratch: &mut TokenizeScratch) {
for piece in byte_level_pretokenize(text) {
let pieces = match self.inner.pre_tokenize_mode {
PreTokenizeMode::ByteLevel => byte_level_pretokenize(text),
PreTokenizeMode::Gpt4Regex => gpt4_regex_pretokenize(text),
};
for piece in pieces {
if let Some(cached) = self.inner.cache.get(&piece) {
scratch.ids.extend(cached.iter().copied());
continue;
Expand Down Expand Up @@ -858,6 +931,158 @@ pub fn byte_decode_token(token_str: &str) -> String {
String::from_utf8_lossy(&bytes).to_string()
}

fn detect_gpt4_regex_pretokenizer(root: &JsonValue) -> bool {
let Some(pt) = root.get("pre_tokenizer") else {
return false;
};
has_regex_split(pt)
}

fn has_regex_split(pt: &JsonValue) -> bool {
let pt_type = pt.get("type").and_then(JsonValue::as_str).unwrap_or("");
match pt_type {
"Split" => pt.get("pattern").and_then(|p| p.get("Regex")).is_some(),
"Sequence" => pt
.get("pretokenizers")
.and_then(JsonValue::as_array)
.is_some_and(|arr| arr.iter().any(has_regex_split)),
_ => false,
}
}

fn gpt4_regex_pretokenize(text: &str) -> Vec<String> {
let chars: Vec<char> = text.chars().collect();
let mut pieces = Vec::new();
let mut pos = 0;

while pos < chars.len() {
if let Some(end) = try_contraction(&chars, pos) {
pieces.push(chars[pos..end].iter().collect());
pos = end;
} else if let Some(end) = try_prefix_letters(&chars, pos) {
pieces.push(chars[pos..end].iter().collect());
pos = end;
} else if chars[pos].is_numeric() {
pieces.push(chars[pos].to_string());
pos += 1;
} else if let Some(end) = try_punctuation_run(&chars, pos) {
pieces.push(chars[pos..end].iter().collect());
pos = end;
} else if let Some(end) = try_newline_run(&chars, pos) {
pieces.push(chars[pos..end].iter().collect());
pos = end;
} else if let Some(end) = try_trailing_ws(&chars, pos) {
pieces.push(chars[pos..end].iter().collect());
pos = end;
} else if chars[pos].is_whitespace() {
let start = pos;
while pos < chars.len() && chars[pos].is_whitespace() {
pos += 1;
}
pieces.push(chars[start..pos].iter().collect());
} else {
pieces.push(chars[pos].to_string());
pos += 1;
}
}

pieces
}

fn eq_ci(a: char, lower: char) -> bool {
a.to_ascii_lowercase() == lower
}

/// `(?i:'s|'t|'re|'ve|'m|'ll|'d)`
fn try_contraction(chars: &[char], pos: usize) -> Option<usize> {
if chars.get(pos).copied() != Some('\'') {
return None;
}
let rest = &chars[pos + 1..];
if rest.len() >= 2 && eq_ci(rest[0], 'l') && eq_ci(rest[1], 'l') {
return Some(pos + 3);
}
if rest.len() >= 2 && eq_ci(rest[0], 'r') && eq_ci(rest[1], 'e') {
return Some(pos + 3);
}
if rest.len() >= 2 && eq_ci(rest[0], 'v') && eq_ci(rest[1], 'e') {
return Some(pos + 3);
}
if !rest.is_empty() {
let c = rest[0].to_ascii_lowercase();
if matches!(c, 's' | 't' | 'm' | 'd') {
return Some(pos + 2);
}
}
None
}

/// `[^\r\n\p{L}\p{N}]?\p{L}+`
fn try_prefix_letters(chars: &[char], pos: usize) -> Option<usize> {
let mut i = pos;
if i < chars.len()
&& !chars[i].is_alphabetic()
&& !chars[i].is_numeric()
&& chars[i] != '\r'
&& chars[i] != '\n'
{
i += 1;
}
let start = i;
while i < chars.len() && chars[i].is_alphabetic() {
i += 1;
}
if i > start { Some(i) } else { None }
}

/// ` ?[^\s\p{L}\p{N}]+[\r\n]*`
fn try_punctuation_run(chars: &[char], pos: usize) -> Option<usize> {
let mut i = pos;
if i < chars.len() && chars[i] == ' ' {
i += 1;
}
let start = i;
while i < chars.len()
&& !chars[i].is_whitespace()
&& !chars[i].is_alphabetic()
&& !chars[i].is_numeric()
{
i += 1;
}
if i == start {
return None;
}
while i < chars.len() && (chars[i] == '\r' || chars[i] == '\n') {
i += 1;
}
Some(i)
}

/// `\s*[\r\n]+`
fn try_newline_run(chars: &[char], pos: usize) -> Option<usize> {
let mut i = pos;
while i < chars.len() && chars[i].is_whitespace() && chars[i] != '\r' && chars[i] != '\n' {
i += 1;
}
let nl_start = i;
while i < chars.len() && (chars[i] == '\r' || chars[i] == '\n') {
i += 1;
}
if i > nl_start { Some(i) } else { None }
}

/// `\s+(?!\S)` — whitespace not followed by non-whitespace (i.e. trailing)
fn try_trailing_ws(chars: &[char], pos: usize) -> Option<usize> {
if !chars.get(pos).is_some_and(|c| c.is_whitespace()) {
return None;
}
let mut i = pos;
while i < chars.len() && chars[i].is_whitespace() {
i += 1;
}
if i == chars.len() { Some(i) } else { None }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down