From ead89480662039882c41f02fb433722af1a2414d Mon Sep 17 00:00:00 2001 From: Guillaume Noale Date: Tue, 31 Mar 2026 11:31:48 +0200 Subject: [PATCH] fix: release Python GIL during CPU intensive operations --- src/py_module.rs | 46 ++++++--- tests/test_harmony.py | 222 ++++++++++++++++++++++++++++++++++++++---- 2 files changed, 235 insertions(+), 33 deletions(-) diff --git a/src/py_module.rs b/src/py_module.rs index 345a887..6d7684f 100644 --- a/src/py_module.rs +++ b/src/py_module.rs @@ -86,6 +86,7 @@ impl PyHarmonyEncoding { /// The encoded token sequence. fn render_conversation_for_completion( &self, + py: Python<'_>, conversation_json: &str, next_turn_role: &str, config: Option>, @@ -116,14 +117,18 @@ impl PyHarmonyEncoding { None }; - self.inner - .render_conversation_for_completion(&conversation, role, rust_config.as_ref()) - .map_err(|e| PyErr::new::(e.to_string())) + // Release GIL during CPU-intensive rendering to allow other Python threads to run + py.allow_threads(|| { + self.inner + .render_conversation_for_completion(&conversation, role, rust_config.as_ref()) + }) + .map_err(|e| PyErr::new::(e.to_string())) } /// Render a conversation without appending a new role. fn render_conversation( &self, + py: Python<'_>, conversation_json: &str, config: Option>, ) -> PyResult> { @@ -144,14 +149,18 @@ impl PyHarmonyEncoding { None }; - self.inner - .render_conversation(&conversation, rust_config.as_ref()) - .map_err(|e| PyErr::new::(e.to_string())) + // Release GIL during CPU-intensive rendering to allow other Python threads to run + py.allow_threads(|| { + self.inner + .render_conversation(&conversation, rust_config.as_ref()) + }) + .map_err(|e| PyErr::new::(e.to_string())) } /// Render a conversation for training. fn render_conversation_for_training( &self, + py: Python<'_>, conversation_json: &str, config: Option>, ) -> PyResult> { @@ -172,14 +181,18 @@ impl PyHarmonyEncoding { None }; - self.inner - .render_conversation_for_training(&conversation, rust_config.as_ref()) - .map_err(|e| PyErr::new::(e.to_string())) + // Release GIL during CPU-intensive rendering to allow other Python threads to run + py.allow_threads(|| { + self.inner + .render_conversation_for_training(&conversation, rust_config.as_ref()) + }) + .map_err(|e| PyErr::new::(e.to_string())) } /// Render a single message into tokens. fn render( &self, + py: Python<'_>, message_json: &str, render_options: Option>, ) -> PyResult> { @@ -199,8 +212,8 @@ impl PyHarmonyEncoding { None }; - self.inner - .render(&message, rust_options.as_ref()) + // Release GIL during CPU-intensive rendering to allow other Python threads to run + py.allow_threads(|| self.inner.render(&message, rust_options.as_ref())) .map_err(|e| PyErr::new::(e.to_string())) } @@ -253,7 +266,12 @@ impl PyHarmonyEncoding { } /// Encode text into tokens using the underlying tokenizer with a set of allowed special tokens. - fn encode(&self, text: &str, allowed_special: Option>) -> PyResult> { + fn encode( + &self, + py: Python<'_>, + text: &str, + allowed_special: Option>, + ) -> PyResult> { let allowed_vec: Vec = match allowed_special { Some(obj) => obj.extract::>().map_err(|e| { PyErr::new::(format!( @@ -264,7 +282,9 @@ impl PyHarmonyEncoding { }; let allowed_set: std::collections::HashSet<&str> = allowed_vec.iter().map(|s| s.as_str()).collect(); - Ok(self.inner.tokenizer().encode(text, &allowed_set).0) + + // Release GIL during CPU-intensive encoding to allow other Python threads to run + Ok(py.allow_threads(|| self.inner.tokenizer().encode(text, &allowed_set).0)) } /// Return the list of special tokens for this tokenizer. diff --git a/tests/test_harmony.py b/tests/test_harmony.py index dbb9925..cefbf88 100644 --- a/tests/test_harmony.py +++ b/tests/test_harmony.py @@ -996,7 +996,9 @@ def test_streamable_parser_missing_message_token(strict: bool, expect_error: boo parser = StreamableParser(encoding, Role.ASSISTANT, strict=strict) if expect_error: - with pytest.raises(HarmonyError, match="unexpected tokens remaining in message header"): + with pytest.raises( + HarmonyError, match="unexpected tokens remaining in message header" + ): for token in tokens: parser.process(token) return @@ -1031,7 +1033,9 @@ def test_streamable_parser_missing_message_token_other_initial_headers( parser = StreamableParser(encoding, Role.ASSISTANT, strict=strict) if expect_error: - with pytest.raises(HarmonyError, match="unexpected tokens remaining in message header"): + with pytest.raises( + HarmonyError, match="unexpected tokens remaining in message header" + ): for token in tokens: parser.process(token) return @@ -1068,7 +1072,9 @@ def test_streamable_parser_missing_message_token_tool_call( parser = StreamableParser(encoding, Role.ASSISTANT, strict=strict) if expect_error: - with pytest.raises(HarmonyError, match="unexpected tokens remaining in message header"): + with pytest.raises( + HarmonyError, match="unexpected tokens remaining in message header" + ): for token in tokens: parser.process(token) return @@ -1077,12 +1083,8 @@ def test_streamable_parser_missing_message_token_tool_call( parser.process(token) expected = [ - Message.from_role_and_content( - Role.ASSISTANT, "... Let's use the tool." - ), - Message.from_role_and_content( - Role.ASSISTANT, '{"location": "Tokyo"}' - ) + Message.from_role_and_content(Role.ASSISTANT, "... Let's use the tool."), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') .with_channel("commentary") .with_recipient("functions.get_weather") .with_content_type("json"), @@ -1100,7 +1102,9 @@ def test_streamable_parser_invalid_utf8_decoding(): with pytest.raises(HarmonyError): encoding.decode_utf8(invalid_token_sequence) - prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all") + prefix_tokens = encoding.encode( + "<|start|>assistant<|message|>", allowed_special="all" + ) suffix_tokens = encoding.encode("worked<|end|>", allowed_special="all") tokens = prefix_tokens + invalid_token_sequence + suffix_tokens parser = StreamableParser(encoding, None) @@ -1110,7 +1114,7 @@ def test_streamable_parser_invalid_utf8_decoding(): expected = [ # Confirm we got the utf-8 replacement characters for the invalid sequences # and the remaining valid utf-8 sequence - Message.from_role_and_content(Role.ASSISTANT, " \uFFFD \uFFFDworked"), + Message.from_role_and_content(Role.ASSISTANT, " \ufffd \ufffdworked"), ] assert parser.messages == expected @@ -1129,7 +1133,9 @@ def test_streamable_parser_invalid_utf8_decoding_split_across_tokens(): with pytest.raises(HarmonyError): encoding.decode_utf8(invalid_token_sequence) - prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all") + prefix_tokens = encoding.encode( + "<|start|>assistant<|message|>", allowed_special="all" + ) suffix_tokens = encoding.encode("<|end|>", allowed_special="all") tokens = prefix_tokens + invalid_token_sequence + suffix_tokens parser = StreamableParser(encoding, None) @@ -1139,7 +1145,7 @@ def test_streamable_parser_invalid_utf8_decoding_split_across_tokens(): expected = [ # One utf-8 replacement character but otherwise kept our space # (from token 9552) and "X" and "Y" tokens - Message.from_role_and_content(Role.ASSISTANT, " \uFFFDXY"), + Message.from_role_and_content(Role.ASSISTANT, " \ufffdXY"), ] assert parser.messages == expected @@ -1159,7 +1165,9 @@ def test_streamable_parser_invalid_utf8_decoding_multi_byte_token(): with pytest.raises(HarmonyError): encoding.decode_utf8(invalid_token_sequence) - prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all") + prefix_tokens = encoding.encode( + "<|start|>assistant<|message|>", allowed_special="all" + ) suffix_tokens = encoding.encode("<|end|>", allowed_special="all") tokens = prefix_tokens + invalid_token_sequence + suffix_tokens parser = StreamableParser(encoding, None) @@ -1169,7 +1177,7 @@ def test_streamable_parser_invalid_utf8_decoding_multi_byte_token(): expected = [ # One utf-8 replacement character and the contents of our second token, # which maps to the text " interesting" - Message.from_role_and_content(Role.ASSISTANT, " \uFFFD interesting"), + Message.from_role_and_content(Role.ASSISTANT, " \ufffd interesting"), ] assert parser.messages == expected @@ -1190,7 +1198,9 @@ def test_streamable_parser_invalid_utf8_decoding_multi_byte_token_no_eos_marker( with pytest.raises(HarmonyError): encoding.decode_utf8(invalid_token_sequence) - prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all") + prefix_tokens = encoding.encode( + "<|start|>assistant<|message|>", allowed_special="all" + ) suffix_tokens = encoding.encode(" story") tokens = prefix_tokens + invalid_token_sequence + suffix_tokens parser = StreamableParser(encoding, None) @@ -1202,16 +1212,16 @@ def test_streamable_parser_invalid_utf8_decoding_multi_byte_token_no_eos_marker( content_deltas.append(parser.last_content_delta) # No EOS, so no full message, but make sure we have the current content - assert parser.current_content == " \uFFFD interesting story" + assert parser.current_content == " \ufffd interesting story" # Ensure all the deltas combine to form our expected content - assert "".join(content_deltas) == " \uFFFD interesting story" + assert "".join(content_deltas) == " \ufffd interesting story" # Confirm we can keep accumulating content delta and content one_more_token = encoding.encode("Y")[0] parser.process(one_more_token) assert parser.last_content_delta == "Y" - assert parser.current_content == " \uFFFD interesting storyY" + assert parser.current_content == " \ufffd interesting storyY" def test_streamable_parser_tricky_utf8_decoding(): @@ -1225,7 +1235,9 @@ def test_streamable_parser_tricky_utf8_decoding(): ) valid_token_sequence = encoding.encode(tricky_utf8_text) - prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all") + prefix_tokens = encoding.encode( + "<|start|>assistant<|message|>", allowed_special="all" + ) suffix_tokens = encoding.encode("<|end|>", allowed_special="all") tokens = prefix_tokens + valid_token_sequence + suffix_tokens parser = StreamableParser(encoding, None) @@ -1244,3 +1256,173 @@ def test_streamable_parser_tricky_utf8_decoding(): # Ensure if we're accumulating content deltas we still get the full utf-8 text assert "".join(content_deltas) == tricky_utf8_text + + +# --------------------------------------------------------------------------- +# Asyncio concurrency tests - verify GIL release doesn't block event loop +# --------------------------------------------------------------------------- + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial + + +def test_asyncio_loop_stays_responsive_during_rendering(): + """ + Verify that rendering tokens in a ThreadPoolExecutor doesn't block + the asyncio event loop. This tests the GIL release fix. + """ + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + # Create a simple conversation + system_msg = Message.from_role_and_content( + Role.SYSTEM, "You are a helpful assistant." + ) + user_msg = Message.from_role_and_content(Role.USER, "Hello, how are you?") + conversation = Conversation.from_messages([system_msg, user_msg]) + + # Track execution order + execution_log = [] + + def render_for_completion(messages): + """Blocking render function""" + execution_log.append("render_start") + import time + + time.sleep(0.1) # Small delay to simulate work + token_ids = encoding.render_conversation_for_completion( + conversation, Role.ASSISTANT + ) + execution_log.append("render_end") + return token_ids + + async def other_task(): + """Other async task that should run concurrently""" + execution_log.append("other_task_start") + await asyncio.sleep(0.05) + execution_log.append("other_task_end") + return "done" + + async def main(): + # Use ThreadPoolExecutor to run blocking render in thread + loop = asyncio.get_event_loop() + executor = ThreadPoolExecutor(max_workers=1) + + # Schedule both tasks + render_future = loop.run_in_executor( + executor, partial(render_for_completion, [system_msg, user_msg]) + ) + other_task_coro = asyncio.create_task(other_task()) + + # Wait for both to complete + render_result, other_result = await asyncio.gather( + asyncio.wrap_future(render_future), other_task_coro + ) + + return render_result, other_result + + # Run the async test + render_tokens, other_result = asyncio.run(main()) + + # Verify both tasks completed + assert len(render_tokens) > 0 + assert other_result == "done" + + # Verify interleaving occurred (other task ran while render was in progress) + # This confirms the GIL was released during rendering + assert "other_task_start" in execution_log + assert "other_task_end" in execution_log + assert "render_start" in execution_log + assert "render_end" in execution_log + + # The other task should have started before render ended + # (proving concurrent execution) + other_start_idx = execution_log.index("other_task_start") + render_end_idx = execution_log.index("render_end") + assert other_start_idx < render_end_idx, ( + "Other task should have run concurrently with rendering" + ) + + +def test_concurrent_rendering_from_multiple_threads(): + """ + Verify that multiple threads can render conversations concurrently + without blocking each other. + """ + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + # Create test conversations + conversations = [] + for i in range(5): + system_msg = Message.from_role_and_content( + Role.SYSTEM, f"You are a helpful assistant. Test {i}" + ) + user_msg = Message.from_role_and_content(Role.USER, f"Hello {i}, how are you?") + conversations.append(Conversation.from_messages([system_msg, user_msg])) + + results = [] + + def render_single(conversation): + return encoding.render_conversation_for_completion(conversation, Role.ASSISTANT) + + # Use ThreadPoolExecutor with multiple workers + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(render_single, conv) for conv in conversations] + results = [f.result() for f in futures] + + # Verify all renders completed successfully + assert len(results) == 5 + assert all(len(tokens) > 0 for tokens in results) + + # Verify results are different (different inputs) + # (they should be at least somewhat different due to different content) + unique_results = set(tuple(tokens) for tokens in results) + assert len(unique_results) >= 4, "Most results should be unique" + + +def test_encode_releases_gil(): + """ + Verify that the encode method releases the GIL during tokenization. + """ + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + execution_log = [] + + def encode_text(text): + execution_log.append("encode_start") + import time + + time.sleep(0.05) + result = encoding.encode(text) + execution_log.append("encode_end") + return result + + async def other_task(): + execution_log.append("other_start") + await asyncio.sleep(0.02) + execution_log.append("other_end") + return "done" + + async def main(): + loop = asyncio.get_event_loop() + executor = ThreadPoolExecutor(max_workers=1) + + text = "This is a longer text that will take some time to encode. " * 100 + + encode_future = loop.run_in_executor(executor, partial(encode_text, text)) + other_task_coro = asyncio.create_task(other_task()) + + encode_result, other_result = await asyncio.gather( + asyncio.wrap_future(encode_future), other_task_coro + ) + + return encode_result, other_result + + tokens, other_result = asyncio.run(main()) + + assert len(tokens) > 0 + assert other_result == "done" + assert "encode_start" in execution_log + assert "encode_end" in execution_log + assert "other_start" in execution_log + assert "other_end" in execution_log