diff --git a/swift/rlhf_trainers/gkd_loss.py b/swift/rlhf_trainers/gkd_loss.py index 2e749e691c..6d67d41bbb 100644 --- a/swift/rlhf_trainers/gkd_loss.py +++ b/swift/rlhf_trainers/gkd_loss.py @@ -265,6 +265,11 @@ def gkd_loss( def build_opsd_teacher_data(inputs, strip_assistant=False): """Build teacher data for OPSD by replacing the last user message with teacher_prompt. + teacher_prompt supports two formats: + 1. str: Used only as the user content for the teacher, replacing the last user message. + 2. list[dict]: A complete list of messages (including system/user/assistant), + which directly replaces the teacher's messages (retaining the assistant from the original data as the response). + Args: inputs: list of data dicts, each may contain 'teacher_prompt' strip_assistant: if True, remove trailing assistant message before replacement @@ -277,13 +282,32 @@ def build_opsd_teacher_data(inputs, strip_assistant=False): result = [] for data in inputs: item = {k: v for k, v in data.items() if k != 'teacher_prompt'} - messages = [dict(m) for m in data.get('messages', [])] - if strip_assistant and messages and messages[-1]['role'] == 'assistant': - messages.pop() - for msg in reversed(messages): - if msg['role'] == 'user': - msg['content'] = data['teacher_prompt'] - break - item['messages'] = messages + tp = data['teacher_prompt'] + + if isinstance(tp, list): + # teacher_prompt is already a complete list of messages (including system/user/optional assistant) + teacher_messages = [dict(m) for m in tp] + # Fallback: If teacher_prompt does not end with an assistant message, fetch it from the original messages + if (not teacher_messages) or teacher_messages[-1].get('role') != 'assistant': + orig = data.get('messages', []) + for m in reversed(orig): + if m.get('role') == 'assistant': + teacher_messages.append(dict(m)) + break + if strip_assistant and teacher_messages and teacher_messages[-1].get('role') == 'assistant': + teacher_messages.pop() + item['messages'] = teacher_messages + else: + # teacher_prompt is a string: Replace the content of the last user message in the original messages + messages = [dict(m) for m in data.get('messages', [])] + if strip_assistant and messages and messages[-1]['role'] == 'assistant': + messages.pop() + for msg in reversed(messages): + if msg['role'] == 'user': + msg['content'] = tp + break + item['messages'] = messages + result.append(item) return result +