Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions src/py_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl PyHarmonyEncoding {
/// The encoded token sequence.
fn render_conversation_for_completion(
&self,
py: Python<'_>,
conversation_json: &str,
next_turn_role: &str,
config: Option<Bound<'_, PyDict>>,
Expand Down Expand Up @@ -116,14 +117,18 @@ impl PyHarmonyEncoding {
None
};

self.inner
.render_conversation_for_completion(&conversation, role, rust_config.as_ref())
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
// Release GIL during CPU-intensive rendering to allow other Python threads to run
py.allow_threads(|| {
self.inner
.render_conversation_for_completion(&conversation, role, rust_config.as_ref())
})
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))
}

/// Render a conversation without appending a new role.
fn render_conversation(
&self,
py: Python<'_>,
conversation_json: &str,
config: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
Expand All @@ -144,14 +149,18 @@ impl PyHarmonyEncoding {
None
};

self.inner
.render_conversation(&conversation, rust_config.as_ref())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
// Release GIL during CPU-intensive rendering to allow other Python threads to run
py.allow_threads(|| {
self.inner
.render_conversation(&conversation, rust_config.as_ref())
})
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}

/// Render a conversation for training.
fn render_conversation_for_training(
&self,
py: Python<'_>,
conversation_json: &str,
config: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
Expand All @@ -172,14 +181,18 @@ impl PyHarmonyEncoding {
None
};

self.inner
.render_conversation_for_training(&conversation, rust_config.as_ref())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
// Release GIL during CPU-intensive rendering to allow other Python threads to run
py.allow_threads(|| {
self.inner
.render_conversation_for_training(&conversation, rust_config.as_ref())
})
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}

/// Render a single message into tokens.
fn render(
&self,
py: Python<'_>,
message_json: &str,
render_options: Option<Bound<'_, PyDict>>,
) -> PyResult<Vec<u32>> {
Expand All @@ -199,8 +212,8 @@ impl PyHarmonyEncoding {
None
};

self.inner
.render(&message, rust_options.as_ref())
// Release GIL during CPU-intensive rendering to allow other Python threads to run
py.allow_threads(|| self.inner.render(&message, rust_options.as_ref()))
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}

Expand Down Expand Up @@ -253,7 +266,12 @@ impl PyHarmonyEncoding {
}

/// Encode text into tokens using the underlying tokenizer with a set of allowed special tokens.
fn encode(&self, text: &str, allowed_special: Option<Bound<'_, PyAny>>) -> PyResult<Vec<u32>> {
fn encode(
&self,
py: Python<'_>,
text: &str,
allowed_special: Option<Bound<'_, PyAny>>,
) -> PyResult<Vec<u32>> {
let allowed_vec: Vec<String> = match allowed_special {
Some(obj) => obj.extract::<Vec<String>>().map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
Expand All @@ -264,7 +282,9 @@ impl PyHarmonyEncoding {
};
let allowed_set: std::collections::HashSet<&str> =
allowed_vec.iter().map(|s| s.as_str()).collect();
Ok(self.inner.tokenizer().encode(text, &allowed_set).0)

// Release GIL during CPU-intensive encoding to allow other Python threads to run
Ok(py.allow_threads(|| self.inner.tokenizer().encode(text, &allowed_set).0))
}

/// Return the list of special tokens for this tokenizer.
Expand Down
Loading