diff --git a/Cargo.toml b/Cargo.toml index 3573719..fd93a6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,15 @@ name = "openai_harmony" crate-type = ["rlib", "cdylib"] [features] -default = [] +default = ["network"] python-binding = ["pyo3"] wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"] +network = ["reqwest"] [dependencies] anyhow = "1.0.98" base64 = "0.22.1" -image = "0.25.6" +image = { version = "0.25.6", optional = true } serde = { version = "1.0.219", features = ["derive"] } serde_json = { version = "1.0.140", features = ["preserve_order"] } serde_with = "3.12.0" @@ -33,7 +34,7 @@ sha2 = "0.10.9" # installation on the CI runners. We disable the default features (which # include `platform-native-tls`) and explicitly enable only the capabilities # we need. -reqwest = { version = "0.12.5", default-features = false, features = [ +reqwest = { version = "0.12.5", optional = true, default-features = false, features = [ "blocking", "json", "multipart", diff --git a/src/lib.rs b/src/lib.rs index 6c2fcfb..96d950e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,9 @@ mod tiktoken; pub mod tiktoken_ext; pub use encoding::{HarmonyEncoding, ParseOptions, StreamableParser}; +#[cfg(feature = "network")] pub use registry::load_harmony_encoding; +pub use registry::load_harmony_encoding_from_vocab_bytes; pub use registry::HarmonyEncodingName; #[cfg(test)] diff --git a/src/registry.rs b/src/registry.rs index d1ffd2e..4722963 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -41,7 +41,7 @@ impl std::fmt::Debug for HarmonyEncodingName { } } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(not(target_arch = "wasm32"), feature = "network"))] pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result { match name { HarmonyEncodingName::HarmonyGptOss => { @@ -81,7 +81,7 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result anyhow::Result { match name { HarmonyEncodingName::HarmonyGptOss => { @@ -116,7 +116,64 @@ pub async fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result< FormattingToken::EndMessageDoneSampling, FormattingToken::EndMessageAssistantToTool, ]), - conversation_has_function_tools: Arc::new(AtomicBool::new(false)), + }) + } + } +} + +/// Load a [`HarmonyEncoding`] from raw tiktoken vocab bytes. +/// +/// This is useful in environments where filesystem access and async HTTP +/// are unavailable (e.g. `wasm32-unknown-unknown` without JS bindings). +pub fn load_harmony_encoding_from_vocab_bytes( + name: HarmonyEncodingName, + vocab_bytes: &[u8], +) -> anyhow::Result { + match name { + HarmonyEncodingName::HarmonyGptOss => { + let n_ctx = 1_048_576; + let max_action_length = 524_288; + let encoding_ext = tiktoken_ext::Encoding::O200kHarmony; + let mut specials: Vec<(String, u32)> = encoding_ext + .special_tokens() + .iter() + .map(|(s, r)| ((*s).to_string(), *r)) + .collect(); + specials.extend((200014..=201088).map(|id| (format!("<|reserved_{id}|>"), id))); + let tokenizer = tiktoken_ext::load_encoding_from_bytes( + vocab_bytes, + None, + specials, + &encoding_ext.pattern(), + )?; + Ok(HarmonyEncoding { + name: name.to_string(), + n_ctx, + tokenizer: Arc::new(tokenizer), + tokenizer_name: encoding_ext.name().to_owned(), + max_message_tokens: n_ctx - max_action_length, + max_action_length, + format_token_mapping: make_mapping([ + (FormattingToken::Start, "<|start|>"), + (FormattingToken::Message, "<|message|>"), + (FormattingToken::EndMessage, "<|end|>"), + (FormattingToken::EndMessageDoneSampling, "<|return|>"), + (FormattingToken::Refusal, "<|refusal|>"), + (FormattingToken::ConstrainedFormat, "<|constrain|>"), + (FormattingToken::Channel, "<|channel|>"), + (FormattingToken::EndMessageAssistantToTool, "<|call|>"), + (FormattingToken::BeginUntrusted, "<|untrusted|>"), + (FormattingToken::EndUntrusted, "<|end_untrusted|>"), + ]), + stop_formatting_tokens: HashSet::from([ + FormattingToken::EndMessageDoneSampling, + FormattingToken::EndMessageAssistantToTool, + FormattingToken::EndMessage, + ]), + stop_formatting_tokens_for_assistant_actions: HashSet::from([ + FormattingToken::EndMessageDoneSampling, + FormattingToken::EndMessageAssistantToTool, + ]), }) } } diff --git a/src/tiktoken_ext/mod.rs b/src/tiktoken_ext/mod.rs index 5ad31ae..69955c5 100644 --- a/src/tiktoken_ext/mod.rs +++ b/src/tiktoken_ext/mod.rs @@ -1,2 +1,2 @@ mod public_encodings; -pub use public_encodings::{set_tiktoken_base_url, Encoding}; +pub use public_encodings::{load_encoding_from_bytes, set_tiktoken_base_url, Encoding}; diff --git a/src/tiktoken_ext/public_encodings.rs b/src/tiktoken_ext/public_encodings.rs index ab9c435..85be660 100644 --- a/src/tiktoken_ext/public_encodings.rs +++ b/src/tiktoken_ext/public_encodings.rs @@ -96,7 +96,7 @@ impl Encoding { None } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(target_arch = "wasm32"), feature = "network"))] pub fn load_from_name(name: impl AsRef) -> Result { let name = name.as_ref(); Self::from_name(name) @@ -104,7 +104,7 @@ impl Encoding { .load() } - #[cfg(target_arch = "wasm32")] + #[cfg(all(target_arch = "wasm32", feature = "network"))] pub async fn load_from_name(name: impl AsRef) -> Result { let name = name.as_ref(); Self::from_name(name) @@ -121,7 +121,7 @@ impl Encoding { } } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(not(target_arch = "wasm32"), feature = "network"))] pub fn load(&self) -> Result { #[cfg(not(target_arch = "wasm32"))] let (vocab_file_path, check_hash) = @@ -202,7 +202,7 @@ impl Encoding { } } - #[cfg(target_arch = "wasm32")] + #[cfg(all(target_arch = "wasm32", feature = "network"))] pub async fn load(&self) -> Result { let url = self.public_vocab_file_url(); let vocab_bytes = download_or_find_cached_file_bytes(&url, Some(self.expected_hash())) @@ -264,7 +264,7 @@ impl Encoding { } } - fn special_tokens(&self) -> &'static [(&'static str, Rank)] { + pub fn special_tokens(&self) -> &'static [(&'static str, Rank)] { match self { Self::O200kBase => &[], Self::O200kHarmony => &[ @@ -295,7 +295,7 @@ impl Encoding { } } - fn pattern(&self) -> String { + pub fn pattern(&self) -> String { match self { Self::O200kBase => { [ @@ -411,9 +411,30 @@ where .map_err(LoadError::CoreBPECreationFailed) } +pub fn load_encoding_from_bytes( + data: &[u8], + expected_hash: Option<&str>, + special_tokens: S, + pattern: &str, +) -> Result +where + S: IntoIterator, + TS: Into, +{ + let reader = std::io::BufReader::new(data); + let encoder = + load_tiktoken_vocab(reader, expected_hash).map_err(LoadError::InvalidTiktokenVocabFile)?; + CoreBPE::new( + encoder, + special_tokens.into_iter().map(|(k, v)| (k.into(), v)), + pattern, + ) + .map_err(LoadError::CoreBPECreationFailed) +} + /// This returns the path to a file containing the data at `url`. If the file is /// cached, it is used. Otherwise, the file is downloaded and cached. -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(not(target_arch = "wasm32"), feature = "network"))] fn download_or_find_cached_file( url: &str, expected_hash: Option<&str>, @@ -440,7 +461,7 @@ fn download_or_find_cached_file( Ok(cache_path) } -#[cfg(target_arch = "wasm32")] +#[cfg(all(target_arch = "wasm32", feature = "network"))] async fn download_or_find_cached_file_bytes( url: &str, expected_hash: Option<&str>, @@ -505,7 +526,7 @@ fn verify_file_hash( /// Loads a remote file to `destination` and returns the computed hash of the /// file contents. -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(not(target_arch = "wasm32"), feature = "network"))] fn load_remote_file(url: &str, destination: &Path) -> Result { let client = reqwest::blocking::Client::new(); let mut response = client @@ -534,7 +555,7 @@ fn load_remote_file(url: &str, destination: &Path) -> Result Result { Err(RemoteVocabFileError::FailedToDownloadOrLoadVocabFile( Box::new(std::io::Error::new( @@ -544,7 +565,7 @@ fn load_remote_file(_url: &str, _destination: &Path) -> Result Result, RemoteVocabFileError> { use reqwest::Client;