Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion swift/dataset/preprocessor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def _to_std_key(messages: List[Dict[str, str]], std_key: str, optional_keys: Lis
def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
if 'rejected_messages' in row:
row['rejected_messages'] = MessagesPreprocessor.preprocess(
self, {'messages': row['rejected_messages']})['messages']
self, {'messages': row['rejected_messages'], 'system': row.get('system')})['messages']
Comment on lines 515 to +517

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If MessagesPreprocessor.preprocess returns None (which happens if rejected_messages is empty, invalid, or cannot be repaired), attempting to access ['messages'] directly on the result will raise a TypeError: 'NoneType' object is not subscriptable.

To make this preprocessing step more robust and prevent potential crashes, we should check if the preprocessed rejected row is None and return None for the entire row to filter it out gracefully.

Suggested change
if 'rejected_messages' in row:
row['rejected_messages'] = MessagesPreprocessor.preprocess(
self, {'messages': row['rejected_messages']})['messages']
self, {'messages': row['rejected_messages'], 'system': row.get('system')})['messages']
if 'rejected_messages' in row:
rejected_row = MessagesPreprocessor.preprocess(
self, {'messages': row['rejected_messages'], 'system': row.get('system')})
if rejected_row is None:
return None
row['rejected_messages'] = rejected_row['messages']

messages = row['messages']
if self.inner_key is not None:
messages = messages[self.inner_key]
Expand Down
33 changes: 32 additions & 1 deletion tests/general/test_data_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from swift.dataset import EncodePreprocessor, PackingDataset, load_dataset
from swift.dataset import EncodePreprocessor, MessagesPreprocessor, PackingDataset, load_dataset
from swift.model import get_processor
from swift.template import get_template

Expand Down Expand Up @@ -147,5 +147,36 @@ def test_packing_dataset(self):
self.assertIn('labels', packed[0])


class TestMessagesPreprocessor(unittest.TestCase):
"""Unit tests for MessagesPreprocessor (no model required)."""

def test_system_propagated_to_rejected_messages(self):
"""system message must be prepended to rejected_messages, not only to messages.

Regression test for the bug where the recursive preprocess() call for
rejected_messages omitted the 'system' key, causing chosen and rejected
to have asymmetric conversation prefixes during DPO training.
"""
row = {
'messages': [
{'role': 'user', 'content': 'Q'},
{'role': 'assistant', 'content': 'good'},
],
'rejected_messages': [
{'role': 'user', 'content': 'Q'},
{'role': 'assistant', 'content': 'bad'},
],
'system': 'You are helpful.',
}
result = MessagesPreprocessor().preprocess(row)
self.assertIsNotNone(result)
self.assertEqual(result['messages'][0]['role'], 'system',
'system message should be first in chosen messages')
self.assertEqual(result['rejected_messages'][0]['role'], 'system',
'system message should also be first in rejected_messages')
self.assertEqual(result['messages'][0]['content'], 'You are helpful.')
self.assertEqual(result['rejected_messages'][0]['content'], 'You are helpful.')


if __name__ == '__main__':
unittest.main()