|
5 | 5 | import yaml |
6 | 6 | import argparse |
7 | 7 | import shutil |
| 8 | +import glob |
| 9 | +import re |
8 | 10 |
|
9 | 11 |
|
10 | 12 | class TensorInfo: |
@@ -54,19 +56,37 @@ def _load_config(self): |
54 | 56 | return json.load(f) |
55 | 57 | return {} |
56 | 58 |
|
| 59 | + def _find_model_files(self): |
| 60 | + """查找所有分片模型文件""" |
| 61 | + single_file = os.path.join(self.model_dir, "model.safetensors") |
| 62 | + shard_files = glob.glob(os.path.join(self.model_dir, "model-*-of-*.safetensors")) |
| 63 | + |
| 64 | + # 使用正则表达式提取分片编号 |
| 65 | + pattern = re.compile(r"model-(\d+)-of-(\d+)\.safetensors") |
| 66 | + filtered_shards = [] |
| 67 | + for f in shard_files: |
| 68 | + match = pattern.search(os.path.basename(f)) |
| 69 | + if match: |
| 70 | + filtered_shards.append( (int(match.group(1)), f) ) |
| 71 | + |
| 72 | + if os.path.exists(single_file): |
| 73 | + return [single_file] |
| 74 | + elif filtered_shards: |
| 75 | + # 按分片编号排序后返回路径 |
| 76 | + filtered_shards.sort(key=lambda x: x[0]) |
| 77 | + return [f[1] for f in filtered_shards] |
| 78 | + raise FileNotFoundError(f"No model files found in {self.model_dir}") |
| 79 | + |
def export(self):
    """Export every tensor from the safetensors model file(s) to the target directory."""
    for model_path in self._find_model_files():
        # PyTorch framework is used for loading so tensor types round-trip.
        with safe_open(model_path, framework="pt") as shard:
            for tensor_name in shard.keys():
                self._save_tensor(tensor_name, shard.get_tensor(tensor_name))

    self._save_config()
    self._copy_tokenizer_files()
|
@@ -135,34 +155,57 @@ def _load_config(self): |
135 | 155 | return json.load(f) |
136 | 156 | return {} |
137 | 157 |
|
| 158 | + def _find_model_files(self): |
| 159 | + """查找所有分片模型文件""" |
| 160 | + single_file = os.path.join(self.model_dir, "model.safetensors") |
| 161 | + shard_files = glob.glob(os.path.join(self.model_dir, "model-*-of-*.safetensors")) |
| 162 | + |
| 163 | + # 统一使用正则表达式匹配 |
| 164 | + pattern = re.compile(r"model-(\d+)-of-(\d+)\.safetensors") |
| 165 | + filtered_shards = [] |
| 166 | + for f in shard_files: |
| 167 | + match = pattern.search(os.path.basename(f)) |
| 168 | + if match: |
| 169 | + filtered_shards.append( (int(match.group(1)), f) ) |
| 170 | + |
| 171 | + if os.path.exists(single_file): |
| 172 | + return [single_file] |
| 173 | + elif filtered_shards: |
| 174 | + filtered_shards.sort(key=lambda x: x[0]) |
| 175 | + return [f[1] for f in filtered_shards] |
| 176 | + else: |
| 177 | + raise FileNotFoundError(f"No model files found in {self.model_dir}") |
| 178 | + |
def load(self):
    """Load all tensors from the safetensors model file(s).

    Returns:
        tuple[dict, dict]: ``(tensors, metadata)`` where ``tensors`` maps
        each key to a ``Tensor`` wrapping the raw byte buffer and its
        ``TensorInfo`` descriptor, and ``metadata`` merges the per-file
        safetensors metadata plus the model config under ``"model_config"``.
    """
    tensors = {}
    metadata = {}

    for model_path in self._find_model_files():
        with safe_open(model_path, framework="pt") as f:
            # Merge per-shard metadata; later shards win on key collisions.
            # safetensors' metadata() returns None when the file carries
            # no metadata, so guard before update() to avoid a TypeError.
            file_metadata = (f.metadata() or {}) if hasattr(f, 'metadata') else {}
            metadata.update(file_metadata)

            for key in f.keys():
                # Detach and move to host memory once, up front.
                pt_tensor = f.get_tensor(key).cpu().detach()

                tensor_info = TensorInfo(
                    dtype=str(pt_tensor.dtype).replace("torch.", ""),
                    ndim=pt_tensor.ndim,
                    shape=tuple(pt_tensor.shape),
                    size=pt_tensor.numel(),
                    # NOTE(review): strides are recorded only for contiguous
                    # tensors — presumably consumers recompute them otherwise;
                    # confirm against the TensorInfo consumer.
                    strides=pt_tensor.stride() if pt_tensor.is_contiguous() else None
                )

                # pt_tensor is already on the CPU (forced by .cpu() above),
                # so the previous device check against the string "cpu" was
                # dead code and fragile (torch.device vs str comparison).
                byte_buffer = pt_tensor.numpy().tobytes()

                tensors[key] = Tensor(byte_buffer, tensor_info)

    metadata["model_config"] = self.config
    return tensors, metadata
|
0 commit comments