Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ name = "openai_harmony"
crate-type = ["rlib", "cdylib"]

[features]
default = []
default = ["network"]
python-binding = ["pyo3"]
wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"]
network = ["reqwest"]

[dependencies]
anyhow = "1.0.98"
base64 = "0.22.1"
image = "0.25.6"
image = { version = "0.25.6", optional = true }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = { version = "1.0.140", features = ["preserve_order"] }
serde_with = "3.12.0"
Expand All @@ -33,7 +34,7 @@ sha2 = "0.10.9"
# installation on the CI runners. We disable the default features (which
# include `platform-native-tls`) and explicitly enable only the capabilities
# we need.
reqwest = { version = "0.12.5", default-features = false, features = [
reqwest = { version = "0.12.5", optional = true, default-features = false, features = [
"blocking",
"json",
"multipart",
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ mod tiktoken;
pub mod tiktoken_ext;

pub use encoding::{HarmonyEncoding, ParseOptions, StreamableParser};
#[cfg(feature = "network")]
pub use registry::load_harmony_encoding;
pub use registry::load_harmony_encoding_from_vocab_bytes;
pub use registry::HarmonyEncodingName;

#[cfg(test)]
Expand Down
63 changes: 60 additions & 3 deletions src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl std::fmt::Debug for HarmonyEncodingName {
}
}

#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(not(target_arch = "wasm32"), feature = "network"))]
pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<HarmonyEncoding> {
match name {
HarmonyEncodingName::HarmonyGptOss => {
Expand Down Expand Up @@ -81,7 +81,7 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<Harmon
}
}

#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", feature = "network"))]
pub async fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<HarmonyEncoding> {
match name {
HarmonyEncodingName::HarmonyGptOss => {
Expand Down Expand Up @@ -116,7 +116,64 @@ pub async fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<
FormattingToken::EndMessageDoneSampling,
FormattingToken::EndMessageAssistantToTool,
]),
conversation_has_function_tools: Arc::new(AtomicBool::new(false)),
})
}
}
}

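/// Load a [`HarmonyEncoding`] from raw tiktoken vocab bytes supplied by the
/// caller.
///
/// This is useful in environments where filesystem access and async HTTP
/// are unavailable (e.g. `wasm32-unknown-unknown` without JS bindings).
/// No expected hash is passed along with `vocab_bytes`, so the caller is
/// responsible for providing the correct vocab file contents.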
pub fn load_harmony_encoding_from_vocab_bytes(
name: HarmonyEncodingName,
vocab_bytes: &[u8],
) -> anyhow::Result<HarmonyEncoding> {
match name {
HarmonyEncodingName::HarmonyGptOss => {
let n_ctx = 1_048_576;
let max_action_length = 524_288;
let encoding_ext = tiktoken_ext::Encoding::O200kHarmony;
let mut specials: Vec<(String, u32)> = encoding_ext
.special_tokens()
.iter()
.map(|(s, r)| ((*s).to_string(), *r))
.collect();
specials.extend((200014..=201088).map(|id| (format!("<|reserved_{id}|>"), id)));
let tokenizer = tiktoken_ext::load_encoding_from_bytes(
vocab_bytes,
None,
specials,
&encoding_ext.pattern(),
)?;
Ok(HarmonyEncoding {
name: name.to_string(),
n_ctx,
tokenizer: Arc::new(tokenizer),
tokenizer_name: encoding_ext.name().to_owned(),
max_message_tokens: n_ctx - max_action_length,
max_action_length,
format_token_mapping: make_mapping([
(FormattingToken::Start, "<|start|>"),
(FormattingToken::Message, "<|message|>"),
(FormattingToken::EndMessage, "<|end|>"),
(FormattingToken::EndMessageDoneSampling, "<|return|>"),
(FormattingToken::Refusal, "<|refusal|>"),
(FormattingToken::ConstrainedFormat, "<|constrain|>"),
(FormattingToken::Channel, "<|channel|>"),
(FormattingToken::EndMessageAssistantToTool, "<|call|>"),
(FormattingToken::BeginUntrusted, "<|untrusted|>"),
(FormattingToken::EndUntrusted, "<|end_untrusted|>"),
]),
stop_formatting_tokens: HashSet::from([
FormattingToken::EndMessageDoneSampling,
FormattingToken::EndMessageAssistantToTool,
FormattingToken::EndMessage,
]),
stop_formatting_tokens_for_assistant_actions: HashSet::from([
FormattingToken::EndMessageDoneSampling,
FormattingToken::EndMessageAssistantToTool,
]),
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/tiktoken_ext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mod public_encodings;
pub use public_encodings::{set_tiktoken_base_url, Encoding};
pub use public_encodings::{load_encoding_from_bytes, set_tiktoken_base_url, Encoding};
43 changes: 32 additions & 11 deletions src/tiktoken_ext/public_encodings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ impl Encoding {
None
}

#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(not(target_arch = "wasm32"), feature = "network"))]
/// Resolves `name` to a known encoding and loads its tokenizer.
///
/// # Errors
/// Returns `LoadError::UnknownEncodingName` when `name` does not match any
/// encoding known to `from_name`; otherwise returns whatever `load` yields.
pub fn load_from_name(name: impl AsRef<str>) -> Result<CoreBPE, LoadError> {
    let requested = name.as_ref();
    match Self::from_name(requested) {
        Some(encoding) => encoding.load(),
        None => Err(LoadError::UnknownEncodingName(requested.to_string())),
    }
}

#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", feature = "network"))]
pub async fn load_from_name(name: impl AsRef<str>) -> Result<CoreBPE, LoadError> {
let name = name.as_ref();
Self::from_name(name)
Expand All @@ -121,7 +121,7 @@ impl Encoding {
}
}

#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(not(target_arch = "wasm32"), feature = "network"))]
pub fn load(&self) -> Result<CoreBPE, LoadError> {
#[cfg(not(target_arch = "wasm32"))]
let (vocab_file_path, check_hash) =
Expand Down Expand Up @@ -202,7 +202,7 @@ impl Encoding {
}
}

#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", feature = "network"))]
pub async fn load(&self) -> Result<CoreBPE, LoadError> {
let url = self.public_vocab_file_url();
let vocab_bytes = download_or_find_cached_file_bytes(&url, Some(self.expected_hash()))
Expand Down Expand Up @@ -264,7 +264,7 @@ impl Encoding {
}
}

fn special_tokens(&self) -> &'static [(&'static str, Rank)] {
pub fn special_tokens(&self) -> &'static [(&'static str, Rank)] {
match self {
Self::O200kBase => &[],
Self::O200kHarmony => &[
Expand Down Expand Up @@ -295,7 +295,7 @@ impl Encoding {
}
}

fn pattern(&self) -> String {
pub fn pattern(&self) -> String {
match self {
Self::O200kBase => {
[
Expand Down Expand Up @@ -411,9 +411,30 @@ where
.map_err(LoadError::CoreBPECreationFailed)
}

/// Builds a `CoreBPE` tokenizer from an in-memory tiktoken vocab file.
///
/// * `data` - raw contents of the vocab file.
/// * `expected_hash` - forwarded unchanged to `load_tiktoken_vocab`.
/// * `special_tokens` - `(token, rank)` pairs registered as special tokens.
/// * `pattern` - split pattern handed to `CoreBPE::new`.
///
/// # Errors
/// * `LoadError::InvalidTiktokenVocabFile` when `load_tiktoken_vocab` fails.
/// * `LoadError::CoreBPECreationFailed` when `CoreBPE::new` fails.
pub fn load_encoding_from_bytes<S, TS>(
    data: &[u8],
    expected_hash: Option<&str>,
    special_tokens: S,
    pattern: &str,
) -> Result<CoreBPE, LoadError>
where
    S: IntoIterator<Item = (TS, Rank)>,
    TS: Into<String>,
{
    let vocab_reader = std::io::BufReader::new(data);
    let ranks = match load_tiktoken_vocab(vocab_reader, expected_hash) {
        Ok(ranks) => ranks,
        Err(err) => return Err(LoadError::InvalidTiktokenVocabFile(err)),
    };

    // Convert the caller's token type into owned strings lazily, as the
    // tokenizer constructor consumes the iterator.
    let specials = special_tokens
        .into_iter()
        .map(|(token, rank)| (token.into(), rank));

    CoreBPE::new(ranks, specials, pattern).map_err(LoadError::CoreBPECreationFailed)
}

/// This returns the path to a file containing the data at `url`. If the file is
/// cached, it is used. Otherwise, the file is downloaded and cached.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(not(target_arch = "wasm32"), feature = "network"))]
fn download_or_find_cached_file(
url: &str,
expected_hash: Option<&str>,
Expand All @@ -440,7 +461,7 @@ fn download_or_find_cached_file(
Ok(cache_path)
}

#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", feature = "network"))]
async fn download_or_find_cached_file_bytes(
url: &str,
expected_hash: Option<&str>,
Expand Down Expand Up @@ -505,7 +526,7 @@ fn verify_file_hash(

/// Loads a remote file to `destination` and returns the computed hash of the
/// file contents.
#[cfg(not(target_arch = "wasm32"))]
#[cfg(all(not(target_arch = "wasm32"), feature = "network"))]
fn load_remote_file(url: &str, destination: &Path) -> Result<String, RemoteVocabFileError> {
let client = reqwest::blocking::Client::new();
let mut response = client
Expand Down Expand Up @@ -534,7 +555,7 @@ fn load_remote_file(url: &str, destination: &Path) -> Result<String, RemoteVocab
Ok(format!("{:x}", hasher.finalize()))
}

#[cfg(target_arch = "wasm32")]
#[cfg(any(target_arch = "wasm32", not(feature = "network")))]
fn load_remote_file(_url: &str, _destination: &Path) -> Result<String, RemoteVocabFileError> {
Err(RemoteVocabFileError::FailedToDownloadOrLoadVocabFile(
Box::new(std::io::Error::new(
Expand All @@ -544,7 +565,7 @@ fn load_remote_file(_url: &str, _destination: &Path) -> Result<String, RemoteVoc
))
}

#[cfg(target_arch = "wasm32")]
#[cfg(all(target_arch = "wasm32", feature = "network"))]
async fn load_remote_file_bytes(url: &str) -> Result<Vec<u8>, RemoteVocabFileError> {
use reqwest::Client;

Expand Down