diff --git a/tests/utils/safety_test.py b/tests/utils/safety_test.py
new file mode 100644
index 000000000..b30a65a46
--- /dev/null
+++ b/tests/utils/safety_test.py
@@ -0,0 +1,67 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from tunix.utils import safety
+
+
+class SafetyTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      ('', ''),
+      (None, None),
+      ('hello world', 'hello world'),
+      ('<start_of_turn>user<end_of_turn>', 'user'),
+      ('<bos>Hello<eos>', 'Hello'),
+      ('<|begin_of_text|>System message<|end_of_text|>', 'System message'),
+      ('<|start_header_id|>assistant<|end_header_id|><|eot_id|>', 'assistant'),
+      ('<|im_start|>user\nHello!<|im_end|>', 'user\nHello!'),
+      ('Nested <start_of_turn>tokens', 'Nested tokens'),
+      ('<start_<end_of_turn>of_turn>', ''),
+      ('Recursive <eos<bos>>', 'Recursive '),
+  )
+  def test_sanitize_control_tokens(self, content, expected):
+    self.assertEqual(safety.sanitize_control_tokens(content), expected)
+
+  def test_sanitize_control_tokens_with_extra(self):
+    content = '[CUSTOM]user\nHello![/CUSTOM]'
+    extra_tokens = ['[CUSTOM]', '[/CUSTOM]']
+    expected = 'user\nHello!'
+    self.assertEqual(
+        safety.sanitize_control_tokens(content, extra_tokens=extra_tokens),
+        expected,
+    )
+
+  def test_sanitize_control_tokens_recursive_with_extra(self):
+    content = '[CUSTOM[CUSTOM]]nested[/CUSTOM]'
+    extra_tokens = ['[CUSTOM]', '[/CUSTOM]']
+    expected = 'nested'
+    self.assertEqual(
+        safety.sanitize_control_tokens(content, extra_tokens=extra_tokens),
+        expected,
+    )
+
+  def test_sanitize_control_tokens_with_empty_extra(self):
+    content = 'hello world'
+    extra_tokens = ['', None]
+    expected = 'hello world'
+    self.assertEqual(
+        safety.sanitize_control_tokens(content, extra_tokens=extra_tokens),
+        expected,
+    )
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/tunix/examples/data/math_dataset.py b/tunix/examples/data/math_dataset.py
index 43d161cb7..19b97d902 100644
--- a/tunix/examples/data/math_dataset.py
+++ b/tunix/examples/data/math_dataset.py
@@ -18,6 +18,8 @@
 import datasets as hf_datasets
 import grain
 import tensorflow_datasets as tfds
+from tunix.utils import safety
+
 # For OSS usage
 import tensorflow_datasets.text.gsm8k
@@ -77,9 +79,10 @@ def _process_element(x):
       if isinstance(value, bytes):
         item[key] = value.decode("utf-8")
 
+    question = safety.sanitize_control_tokens(item["question"])
     return {
-        "prompts": _apply_template(item["question"]),
-        "question": item["question"],
+        "prompts": _apply_template(question),
+        "question": question,
         "answer": extract_hash_answer(item["answer"]),
     }
 
diff --git a/tunix/generate/tokenizer_adapter.py b/tunix/generate/tokenizer_adapter.py
index ecdbdbb2d..86a368128 100644
--- a/tunix/generate/tokenizer_adapter.py
+++ b/tunix/generate/tokenizer_adapter.py
@@ -20,6 +20,7 @@
 from etils import epath
 import numpy as np
+from tunix.utils import safety
 import sentencepiece as spm
@@ -168,6 +169,10 @@ def apply_chat_template(
     Raises:
       NotImplementedError: If chat templating is not supported by the
tokenizer. """ + messages = [ + {**m, 'content': safety.sanitize_control_tokens(m['content'])} + for m in messages + ] if self._tokenizer_type == TokenizerType.HF: return self._tokenizer.apply_chat_template( messages, @@ -276,6 +281,7 @@ def tokenize( Returns: Tokens corresponding to the input string. """ + example = safety.sanitize_control_tokens(example) int_list = [] if self.bos_id(): int_list.append(self.bos_id()) diff --git a/tunix/rl/agentic/parser/chat_template_parser/parser.py b/tunix/rl/agentic/parser/chat_template_parser/parser.py index 24a42f7a6..da3d529a7 100644 --- a/tunix/rl/agentic/parser/chat_template_parser/parser.py +++ b/tunix/rl/agentic/parser/chat_template_parser/parser.py @@ -17,6 +17,7 @@ import abc import dataclasses from typing import Dict, List +from tunix.utils import safety dataclass = dataclasses.dataclass @@ -48,6 +49,11 @@ def __init__(self, tokenizer, enable_thinking: bool = True): self.enable_thinking = enable_thinking self.tokens = self._init_tokens() self.generation_prompt = self._init_generation_prompt() + self._tokens_to_sanitize = { + v + for v in dataclasses.asdict(self.tokens).values() + if isinstance(v, str) and v + } @abstractmethod def _init_tokens(self) -> TokenConfig: @@ -87,6 +93,9 @@ def _handle_first_message(self, messages: List[Dict[str, str]]) -> str: def _parse_message(self, message: Dict[str, str]) -> str: """Parse a single message based on its role.""" role = message["role"] + content = safety.sanitize_control_tokens( + message["content"], extra_tokens=self._tokens_to_sanitize + ) parser_map = { "system": self._parse_system, @@ -98,24 +107,22 @@ def _parse_message(self, message: Dict[str, str]) -> str: if role not in parser_map: raise NotImplementedError(f"Unsupported message role: {role}") - return parser_map[role](message) + return parser_map[role](content) - def _parse_system(self, message: Dict[str, str]) -> str: - return self.tokens.system_token + message["content"] + self.tokens.eot_token + def _parse_system(self, content: str) -> str: + return self.tokens.system_token + content + self.tokens.eot_token - def _parse_user(self, message: Dict[str, str]) -> str: - return self.tokens.user_token + message["content"] + self.tokens.eot_token + def _parse_user(self, content: str) -> str: + return self.tokens.user_token + content + self.tokens.eot_token - def _parse_assistant(self, message: Dict[str, str]) -> str: - return ( - self.tokens.assistant_token + message["content"] + self.tokens.eot_token - ) + def _parse_assistant(self, content: str) -> str: + return self.tokens.assistant_token + content + self.tokens.eot_token - def _parse_tool(self, message: Dict[str, str]) -> str: + def _parse_tool(self, content: str) -> str: return ( self.tokens.user_token + self.tokens.tool_response_start_token - + message["content"] + + content + self.tokens.tool_response_end_token + self.tokens.eot_token ) @@ -170,13 +177,9 @@ def _init_generation_prompt(self) -> str: def _handle_first_message(self, messages: List[Dict[str, str]]) -> str: """Add default system message if first message is not system.""" if messages[0]["role"] != "system": - return self._parse_system({ - "role": "system", - "content": ( - "You are Qwen, created by Alibaba Cloud. You are a helpful" - " assistant." - ), - }) + return self._parse_system( + "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." 
+      )
     return ""
 
@@ -213,13 +216,12 @@ def _init_tokens(self) -> TokenConfig:
         assistant_token="<start_of_turn>model\n",
     )
 
-  def _parse_assistant(self, message: Dict[str, str]) -> str:
-    return self.tokens.assistant_token + message["content"]
+  def _parse_assistant(self, content: str) -> str:
+    return self.tokens.assistant_token + content
 
-  def _parse_system(self, message: Dict[str, str]) -> str:
+  def _parse_system(self, content: str) -> str:
     # This should not be called if parse() is used, as it handles the system
     # prompt by merging it. Raise error for unexpected system messages.
-    del message  # Unused.
     raise ValueError(
         "Gemma models do not support system messages directly. The system"
         " prompt should be the first message and is handled by merging with"
diff --git a/tunix/utils/safety.py b/tunix/utils/safety.py
new file mode 100644
index 000000000..924cf5a06
--- /dev/null
+++ b/tunix/utils/safety.py
@@ -0,0 +1,68 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Safety utilities for RL agents."""
+
+import re
+from typing import Iterable
+
+_CONTROL_TOKENS = [
+    '<bos>',
+    '<eos>',
+    '<start_of_turn>',
+    '<end_of_turn>',
+    '<|begin_of_text|>',
+    '<|end_of_text|>',
+    '<|start_header_id|>',
+    '<|end_header_id|>',
+    '<|eot_id|>',
+    '<|im_start|>',
+    '<|im_end|>',
+]
+
+_CONTROL_TOKENS_RE = re.compile('|'.join(map(re.escape, _CONTROL_TOKENS)))
+
+
+def sanitize_control_tokens(
+    content: str, extra_tokens: Iterable[str] | None = None
+) -> str:
+  """Sanitizes control tokens from the content.
+
+  Args:
+    content: The content to sanitize.
+    extra_tokens: Additional tokens to sanitize.
+
+  Returns:
+    The sanitized content.
+  """
+  if not content:
+    return content
+
+  if extra_tokens:
+    # Combine with the default tokens and compile a one-off regex. Empty
+    # strings are ignored to prevent the pattern from matching everywhere.
+    all_tokens = set(_CONTROL_TOKENS) | {t for t in extra_tokens if t}
+    regex = re.compile('|'.join(map(re.escape, all_tokens)))
+  else:
+    regex = _CONTROL_TOKENS_RE
+
+  # Strip known model control tokens to prevent prompt injection. We loop
+  # because removing one token can expose another token nested inside it.
+  sanitized = content
+  while True:
+    new_content = regex.sub('', sanitized)
+    if new_content == sanitized:
+      break
+    sanitized = new_content
+  return sanitized
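
Usage sketch for the new helper (illustrative, not part of the diff; the
'[INST]' / '[/INST]' extras below are hypothetical stand-ins for a parser's
TokenConfig values, while sanitize_control_tokens and extra_tokens are as
defined in tunix/utils/safety.py above):

    from tunix.utils import safety

    # The default token list covers ChatML- and Llama-style control tokens.
    print(safety.sanitize_control_tokens('<|im_start|>user\nHi!<|im_end|>'))
    # -> 'user\nHi!'

    # Removal loops to a fixed point, so a token hidden inside another token
    # is still stripped once the outer removal exposes it.
    print(safety.sanitize_control_tokens('<|im_<|eot_id|>start|>payload'))
    # -> 'payload'

    # Per-call extras mirror how ChatTemplateParser passes its TokenConfig
    # strings via _tokens_to_sanitize.
    print(safety.sanitize_control_tokens(
        '[INST]hello[/INST]', extra_tokens=['[INST]', '[/INST]']
    ))
    # -> 'hello'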