Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions swift/rlhf_trainers/gkd_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ def gkd_loss(
def build_opsd_teacher_data(inputs, strip_assistant=False):
"""Build teacher data for OPSD by replacing the last user message with teacher_prompt.

teacher_prompt supports two formats:
1. str: Used only as the user content for the teacher, replacing the last user message.
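2. list[dict]: A complete list of messages (system/user and, optionally, a trailing assistant message),
which is used directly as the teacher's messages. When strip_assistant is False and the list does not end with an assistant message, the last assistant message from the original data is appended as the response.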

Args:
inputs: list of data dicts, each may contain 'teacher_prompt'
strip_assistant: if True, remove trailing assistant message before replacement
Expand All @@ -277,13 +282,32 @@ def build_opsd_teacher_data(inputs, strip_assistant=False):
result = []
for data in inputs:
item = {k: v for k, v in data.items() if k != 'teacher_prompt'}
messages = [dict(m) for m in data.get('messages', [])]
if strip_assistant and messages and messages[-1]['role'] == 'assistant':
messages.pop()
for msg in reversed(messages):
if msg['role'] == 'user':
msg['content'] = data['teacher_prompt']
break
item['messages'] = messages
tp = data['teacher_prompt']

if isinstance(tp, list):
# teacher_prompt is already a complete list of messages (including system/user/optional assistant)
teacher_messages = [dict(m) for m in tp]
# Fallback: If teacher_prompt does not end with an assistant message, fetch it from the original messages
if (not teacher_messages) or teacher_messages[-1].get('role') != 'assistant':
orig = data.get('messages', [])
for m in reversed(orig):
if m.get('role') == 'assistant':
teacher_messages.append(dict(m))
break
if strip_assistant and teacher_messages and teacher_messages[-1].get('role') == 'assistant':
teacher_messages.pop()
Comment on lines +290 to +298

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When strip_assistant is True and teacher_prompt (as a list) does not end with an assistant message, the code currently fetches the last assistant message from the original messages, appends it, and then immediately pops it again, which is a net no-op. This work is redundant. We can avoid it by running the fallback logic only when strip_assistant is False; the resulting messages are unchanged.

Suggested change
# Fallback: If teacher_prompt does not end with an assistant message, fetch it from the original messages
if (not teacher_messages) or teacher_messages[-1].get('role') != 'assistant':
orig = data.get('messages', [])
for m in reversed(orig):
if m.get('role') == 'assistant':
teacher_messages.append(dict(m))
break
if strip_assistant and teacher_messages and teacher_messages[-1].get('role') == 'assistant':
teacher_messages.pop()
if strip_assistant:
if teacher_messages and teacher_messages[-1].get('role') == 'assistant':
teacher_messages.pop()
else:
# Fallback: If teacher_prompt does not end with an assistant message, fetch it from the original messages
if (not teacher_messages) or teacher_messages[-1].get('role') != 'assistant':
orig = data.get('messages', [])
for m in reversed(orig):
if m.get('role') == 'assistant':
teacher_messages.append(dict(m))
break

item['messages'] = teacher_messages
else:
# teacher_prompt is a string: Replace the content of the last user message in the original messages
messages = [dict(m) for m in data.get('messages', [])]
if strip_assistant and messages and messages[-1]['role'] == 'assistant':
messages.pop()
for msg in reversed(messages):
if msg['role'] == 'user':
msg['content'] = tp
break
item['messages'] = messages

result.append(item)
return result