Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions tests/utils/safety_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
from tunix.utils import safety


class SafetyTest(parameterized.TestCase):

@parameterized.parameters(
('', ''),
(None, None),
('hello world', 'hello world'),
('<start_of_turn>user<end_of_turn>', 'user'),
('<bos>Hello<eos>', 'Hello'),
('<|begin_of_text|>System message<|end_of_text|>', 'System message'),
('<|start_header_id|>assistant<|end_header_id|><|eot_id|>', 'assistant'),
('<|im_start|>user\nHello!<|im_end|>', 'user\nHello!'),
('Nested <start_of_turn><bos>tokens<eos><end_of_turn>', 'Nested tokens'),
('<start_<start_of_turn>of_turn>', ''),
('Recursive <bos<bos>>', 'Recursive '),
)
def test_sanitize_control_tokens(self, content, expected):
self.assertEqual(safety.sanitize_control_tokens(content), expected)

def test_sanitize_control_tokens_with_extra(self):
content = '[CUSTOM]user\nHello![/CUSTOM]'
extra_tokens = ['[CUSTOM]', '[/CUSTOM]']
expected = 'user\nHello!'
self.assertEqual(
safety.sanitize_control_tokens(content, extra_tokens=extra_tokens),
expected,
)

def test_sanitize_control_tokens_recursive_with_extra(self):
content = '[CUSTOM[CUSTOM]]nested[/CUSTOM]'
extra_tokens = ['[CUSTOM]', '[/CUSTOM]']
expected = 'nested'
self.assertEqual(
safety.sanitize_control_tokens(content, extra_tokens=extra_tokens),
expected,
)

def test_sanitize_control_tokens_with_empty_extra(self):
content = 'hello world'
extra_tokens = ['', None]
expected = 'hello world'
self.assertEqual(
safety.sanitize_control_tokens(content, extra_tokens=extra_tokens),
expected,
)


if __name__ == '__main__':
absltest.main()
7 changes: 5 additions & 2 deletions tunix/examples/data/math_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import datasets as hf_datasets
import grain
import tensorflow_datasets as tfds
from tunix.utils import safety

# For OSS usage
import tensorflow_datasets.text.gsm8k

Expand Down Expand Up @@ -77,9 +79,10 @@ def _process_element(x):
if isinstance(value, bytes):
item[key] = value.decode("utf-8")

question = safety.sanitize_control_tokens(item["question"])
return {
"prompts": _apply_template(item["question"]),
"question": item["question"],
"prompts": _apply_template(question),
"question": question,
"answer": extract_hash_answer(item["answer"]),
}

Expand Down
6 changes: 6 additions & 0 deletions tunix/generate/tokenizer_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from etils import epath
import numpy as np
from tunix.utils import safety

import sentencepiece as spm

Expand Down Expand Up @@ -168,6 +169,10 @@ def apply_chat_template(
Raises:
NotImplementedError: If chat templating is not supported by the tokenizer.
"""
messages = [
{**m, 'content': safety.sanitize_control_tokens(m['content'])}
for m in messages
]
if self._tokenizer_type == TokenizerType.HF:
return self._tokenizer.apply_chat_template(
messages,
Expand Down Expand Up @@ -276,6 +281,7 @@ def tokenize(
Returns:
Tokens corresponding to the input string.
"""
example = safety.sanitize_control_tokens(example)
int_list = []
if self.bos_id():
int_list.append(self.bos_id())
Expand Down
46 changes: 24 additions & 22 deletions tunix/rl/agentic/parser/chat_template_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
import dataclasses
from typing import Dict, List
from tunix.utils import safety


dataclass = dataclasses.dataclass
Expand Down Expand Up @@ -48,6 +49,11 @@ def __init__(self, tokenizer, enable_thinking: bool = True):
self.enable_thinking = enable_thinking
self.tokens = self._init_tokens()
self.generation_prompt = self._init_generation_prompt()
self._tokens_to_sanitize = {
v
for v in dataclasses.asdict(self.tokens).values()
if isinstance(v, str) and v
}

@abstractmethod
def _init_tokens(self) -> TokenConfig:
Expand Down Expand Up @@ -87,6 +93,9 @@ def _handle_first_message(self, messages: List[Dict[str, str]]) -> str:
def _parse_message(self, message: Dict[str, str]) -> str:
"""Parse a single message based on its role."""
role = message["role"]
content = safety.sanitize_control_tokens(
message["content"], extra_tokens=self._tokens_to_sanitize
)

parser_map = {
"system": self._parse_system,
Expand All @@ -98,24 +107,22 @@ def _parse_message(self, message: Dict[str, str]) -> str:
if role not in parser_map:
raise NotImplementedError(f"Unsupported message role: {role}")

return parser_map[role](message)
return parser_map[role](content)

def _parse_system(self, message: Dict[str, str]) -> str:
return self.tokens.system_token + message["content"] + self.tokens.eot_token
def _parse_system(self, content: str) -> str:
return self.tokens.system_token + content + self.tokens.eot_token

def _parse_user(self, message: Dict[str, str]) -> str:
return self.tokens.user_token + message["content"] + self.tokens.eot_token
def _parse_user(self, content: str) -> str:
return self.tokens.user_token + content + self.tokens.eot_token

def _parse_assistant(self, message: Dict[str, str]) -> str:
return (
self.tokens.assistant_token + message["content"] + self.tokens.eot_token
)
def _parse_assistant(self, content: str) -> str:
return self.tokens.assistant_token + content + self.tokens.eot_token

def _parse_tool(self, message: Dict[str, str]) -> str:
def _parse_tool(self, content: str) -> str:
return (
self.tokens.user_token
+ self.tokens.tool_response_start_token
+ message["content"]
+ content
+ self.tokens.tool_response_end_token
+ self.tokens.eot_token
)
Expand Down Expand Up @@ -170,13 +177,9 @@ def _init_generation_prompt(self) -> str:
def _handle_first_message(self, messages: List[Dict[str, str]]) -> str:
"""Add default system message if first message is not system."""
if messages[0]["role"] != "system":
return self._parse_system({
"role": "system",
"content": (
"You are Qwen, created by Alibaba Cloud. You are a helpful"
" assistant."
),
})
return self._parse_system(
"You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
)
return ""


Expand Down Expand Up @@ -213,13 +216,12 @@ def _init_tokens(self) -> TokenConfig:
assistant_token="<start_of_turn>model\n",
)

def _parse_assistant(self, message: Dict[str, str]) -> str:
return self.tokens.assistant_token + message["content"]
def _parse_assistant(self, content: str) -> str:
return self.tokens.assistant_token + content

def _parse_system(self, message: Dict[str, str]) -> str:
def _parse_system(self, content: str) -> str:
# This should not be called if parse() is used, as it handles the system
# prompt by merging it. Raise error for unexpected system messages.
del message # Unused.
raise ValueError(
"Gemma models do not support system messages directly. The system"
" prompt should be the first message and is handled by merging with"
Expand Down
68 changes: 68 additions & 0 deletions tunix/utils/safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Safety utilities for RL agents."""

import re
from typing import Iterable

_CONTROL_TOKENS = [
'<start_of_turn>',
'<end_of_turn>',
'<bos>',
'<eos>',
'<|begin_of_text|>',
'<|end_of_text|>',
'<|start_header_id|>',
'<|end_header_id|>',
'<|eot_id|>',
'<|im_start|>',
'<|im_end|>',
]

_CONTROL_TOKENS_RE = re.compile('|'.join(map(re.escape, _CONTROL_TOKENS)))


def sanitize_control_tokens(
content: str, extra_tokens: Iterable[str] | None = None
) -> str:
"""Sanitize control tokens from the content.

Args:
content: The content to sanitize.
extra_tokens: Additional tokens to sanitize.

Returns:
The sanitized content.
"""
if not content:
return content

if extra_tokens:
# Combine with default tokens and create a one-off regex.
# Ignore empty strings to prevent matching everywhere.
all_tokens = set(_CONTROL_TOKENS) | {t for t in extra_tokens if t}
regex = re.compile('|'.join(map(re.escape, all_tokens)))
else:
regex = _CONTROL_TOKENS_RE

# Strip known model control tokens to prevent prompt injection.
# We use a loop to handle cases where tokens are nested to bypass sanitization.
sanitized = content
while True:
new_content = regex.sub('', sanitized)
if new_content == sanitized:
break
sanitized = new_content
return sanitized
Loading