# data/dataset.py
import json
import os
import torch
from torch.utils.data import Dataset

class ConversationDataset(Dataset):
    """Conversation dataset.

    Supported formats:
    - 'raw': plain text, one conversation per line
    - 'sharegpt': ShareGPT format, one JSON object per line
    - 'alpaca': Alpaca format, with instruction, input, and output fields
    """
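    # Illustrative one-line samples for each format (the field names follow
    # the common ShareGPT/Alpaca conventions; the content is hypothetical):
    #   raw:      Hello, how are you? I'm fine, thanks.
    #   sharegpt: {"conversations": [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]}
    #   alpaca:   {"instruction": "Translate to French", "input": "Hello", "output": "Bonjour"}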
    def __init__(self, file_path, tokenizer, max_length=1024, data_format='raw', sharegpt_config=None):
        """
        Args:
            file_path: path to the data file
            tokenizer: a HuggingFace tokenizer
            max_length: maximum sequence length
            data_format: data format ('raw', 'sharegpt', 'alpaca')
            sharegpt_config: role-name configuration for the ShareGPT format

        Raises:
            FileNotFoundError: the data file does not exist
            ValueError: invalid arguments or malformed data
        """
        # Validate arguments
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Data file not found: {file_path}")

        if data_format not in ['raw', 'sharegpt', 'alpaca']:
            raise ValueError(f"Unsupported data_format: {data_format!r}; must be 'raw', 'sharegpt', or 'alpaca'")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_format = data_format
        self.sharegpt_config = sharegpt_config or {}

        # Load the data
        self.examples = self._load_data(file_path)

        # Ensure a pad token exists (some tokenizers, e.g. GPT-2, have none)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _load_data(self, file_path):
        """Load the data file."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                if self.data_format in ('sharegpt', 'alpaca'):
                    # One JSON object per line (JSONL)
                    examples = []
                    for line_num, line in enumerate(f, 1):
                        if line.strip():
                            try:
                                examples.append(json.loads(line))
                            except json.JSONDecodeError as e:
                                raise ValueError(f"Failed to parse JSON on line {line_num}: {e}") from e
                    return examples
                else:
                    # 'raw' format: one text per line
                    return [line.strip() for line in f if line.strip()]
        except UnicodeDecodeError as e:
            raise ValueError(f"File encoding error: {e}") from e
        except IOError as e:
            raise IOError(f"Failed to read file: {e}") from e

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Fetch and tokenize one example."""
        try:
            if self.data_format == 'sharegpt':
                return self._process_sharegpt(idx)
            elif self.data_format == 'alpaca':
                return self._process_alpaca(idx)
            else:
                return self._process_raw(idx)
        except Exception as e:
            raise ValueError(f"Error while processing example {idx}: {e}") from e

    def _process_raw(self, idx):
        """Process a 'raw' format example."""
        text = self.examples[idx]

        enc = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        input_ids = enc['input_ids'].squeeze(0)
        attention_mask = enc['attention_mask'].squeeze(0)

        # For the raw format, the whole sequence serves as labels.
        # Note: this trains the model to predict the entire sequence,
        # including any user turns; if the data has explicit user/assistant
        # boundaries, prefer the 'sharegpt' format.
        labels = input_ids.clone()
        # Never compute loss on padding positions
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

    def _process_sharegpt(self, idx):
        """Process a ShareGPT format example.

        Loss is computed only on the assistant replies; the labels for the
        user turns (and padding) are set to -100.
        """
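        # Expected record shape (illustrative; the role names default to the
        # common ShareGPT convention and can be overridden via sharegpt_config):
        # {"conversations": [{"from": "human", "value": "..."},
        #                    {"from": "gpt",   "value": "..."}]}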
        conv = self.examples[idx]['conversations']

        # Role names, configurable via sharegpt_config
        human_role = self.sharegpt_config.get('human_role', 'human')
        assistant_role = self.sharegpt_config.get('assistant_role', 'gpt')

        # Build the text and a parallel mask marking which parts receive loss
        text_parts = []
        label_parts = []

        for turn in conv:
            if turn['from'] == human_role:
                # User input: no loss
                user_text = f"User: {turn['value']}\n"
                text_parts.append(user_text)
                label_parts.append(False)
            elif turn['from'] == assistant_role:
                # Assistant reply: compute loss
                assistant_text = f"Assistant: {turn['value']}\n"
                text_parts.append(assistant_text)
                label_parts.append(True)

        # Append the final generation prompt
        text_parts.append("Assistant:")
        label_parts.append(False)  # no loss on the prompt itself

        # Join the parts into the full conversation text
        full_text = "".join(text_parts)

        # Tokenize the full conversation
        enc = self.tokenizer(
            full_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        input_ids = enc['input_ids'].squeeze(0)
        attention_mask = enc['attention_mask'].squeeze(0)

        # Build the labels: keep loss only on the assistant replies
        labels = input_ids.clone()

        # The full encoding above may prepend special tokens (e.g. BOS), while
        # the boundary computation below uses add_special_tokens=False; account
        # for that shift with a simple heuristic (check for a leading BOS).
        offset = 0
        bos_id = getattr(self.tokenizer, 'bos_token_id', None)
        if bos_id is not None and input_ids[0].item() == bos_id:
            offset = 1

        # Compute the token range covered by each part by re-tokenizing the
        # cumulative prefix. Caveat: if a part boundary falls inside a merged
        # subword, the ranges are approximate.
        if len(text_parts) > 1:
            cumulative_text = ""
            token_positions = []

            for part in text_parts:
                part_start = len(self.tokenizer(
                    cumulative_text,
                    add_special_tokens=False
                )['input_ids'])

                cumulative_text += part

                part_end = len(self.tokenizer(
                    cumulative_text,
                    add_special_tokens=False
                )['input_ids'])

                token_positions.append((part_start + offset, part_end + offset))

            # Set the parts that should be ignored to -100
            for (start_pos, end_pos), should_label in zip(token_positions, label_parts):
                if not should_label and start_pos < len(labels):
                    end_pos = min(end_pos, len(labels))
                    labels[start_pos:end_pos] = -100

        # Never compute loss on padding positions
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

    def _process_alpaca(self, idx):
        """Process an Alpaca format example.

        An Alpaca record contains:
        - instruction: the task description
        - input: optional additional context
        - output: the target response
        """
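        # Illustrative record (standard Alpaca field names; content hypothetical):
        # {"instruction": "Translate to French", "input": "Hello", "output": "Bonjour"}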
        example = self.examples[idx]

        # Validate required fields
        if 'instruction' not in example or 'output' not in example:
            raise ValueError("Alpaca example is missing the required 'instruction' or 'output' field")

        instruction = example['instruction']
        input_text = example.get('input', '')
        output = example['output']

        # Build the prompt (standard Alpaca template)
        if input_text:
            prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
"""
        else:
            prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""

        # Full text = prompt + output
        full_text = prompt + output

        # Tokenize the full text
        enc = self.tokenizer(
            full_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        input_ids = enc['input_ids'].squeeze(0)
        attention_mask = enc['attention_mask'].squeeze(0)

        # Build the labels: compute loss only on the output part
        labels = input_ids.clone()

        # Find where the prompt ends. The full encoding above may prepend
        # special tokens (e.g. BOS) while the prompt here is tokenized without
        # them, so shift the boundary if a leading BOS is present.
        prompt_len = len(self.tokenizer(prompt, add_special_tokens=False)['input_ids'])
        bos_id = getattr(self.tokenizer, 'bos_token_id', None)
        if bos_id is not None and input_ids[0].item() == bos_id:
            prompt_len += 1

        # Mask the prompt part (no loss)
        labels[:prompt_len] = -100
        # Never compute loss on padding positions
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
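
if __name__ == "__main__":
    # Minimal smoke-test sketch. The tokenizer name and the 'data.jsonl' path
    # are illustrative assumptions, not part of this module.
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    dataset = ConversationDataset("data.jsonl", tokenizer,
                                  max_length=512, data_format='alpaca')
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch['input_ids'].shape)          # torch.Size([4, 512])
    print((batch['labels'] != -100).sum())   # number of supervised tokens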