44import json
55import yaml
66import argparse
7+ import shutil
78
89
910class TensorInfo :
@@ -67,16 +68,17 @@ def export(self):
6768
6869 # 保存全局配置
6970 self ._save_config ()
71+ self ._copy_tokenizer_files ()
7072
7173 def _save_tensor (self , name , tensor ):
7274 """保存单个张量的元数据和二进制数据"""
73- base_path = os .path .join (self .output_dir , * name .split ('.' ))
75+ # 将名称中的点转换为下划线,并创建统一路径
76+ base_path = os .path .join (self .output_dir , "tensors" , name )
7477 os .makedirs (os .path .dirname (base_path ), exist_ok = True )
7578
7679 # 处理bfloat16类型
7780 dtype_str = str (tensor .dtype ).replace ("torch." , "" )
7881 if dtype_str == "bfloat16" :
79- # 转换为numpy支持的float32格式
8082 tensor = tensor .float ()
8183 dtype_str = "float32"
8284
@@ -103,6 +105,22 @@ def _save_config(self):
103105 'format_version' : 'deepx'
104106 }, f , default_flow_style = False )
105107
108+ def _copy_tokenizer_files (self ):
109+ """复制tokenizer相关文件到输出目录"""
110+ required_files = [
111+ "tokenizer.json" ,
112+ "tokenizer_config.json" ,
113+ "special_tokens_map.json" ,
114+ "vocab.json" ,
115+ "merges.txt" ,
116+ "added_tokens.json"
117+ ]
118+
119+ for filename in required_files :
120+ src = os .path .join (self .model_dir , filename )
121+ if os .path .exists (src ):
122+ shutil .copy2 (src , os .path .join (self .output_dir , filename ))
123+
106124
107125class SafeTensorLoader :
108126 def __init__ (self , model_dir ):
0 commit comments