Commit 99b894a

qwen2:infer initial inference implementation; open issues to analyze
1 parent 80a02ff commit 99b894a

3 files changed

Lines changed: 206 additions & 35 deletions

model/safetensor_deepx/safetensor_deepx/loader.py

Lines changed: 77 additions & 34 deletions
@@ -5,6 +5,8 @@
 import yaml
 import argparse
 import shutil
+import glob
+import re


 class TensorInfo:
@@ -54,19 +56,37 @@ def _load_config(self):
             return json.load(f)
         return {}

+    def _find_model_files(self):
+        """Find all sharded model files"""
+        single_file = os.path.join(self.model_dir, "model.safetensors")
+        shard_files = glob.glob(os.path.join(self.model_dir, "model-*-of-*.safetensors"))
+
+        # Extract the shard index with a regular expression
+        pattern = re.compile(r"model-(\d+)-of-(\d+)\.safetensors")
+        filtered_shards = []
+        for f in shard_files:
+            match = pattern.search(os.path.basename(f))
+            if match:
+                filtered_shards.append((int(match.group(1)), f))
+
+        if os.path.exists(single_file):
+            return [single_file]
+        elif filtered_shards:
+            # Sort by shard index, then return the paths
+            filtered_shards.sort(key=lambda x: x[0])
+            return [f[1] for f in filtered_shards]
+        raise FileNotFoundError(f"No model files found in {self.model_dir}")
+
     def export(self):
         """Export the safetensors model to the target directory"""
-        model_path = os.path.join(self.model_dir, "model.safetensors")
-        if not os.path.exists(model_path):
-            raise FileNotFoundError(f"Model file not found: {model_path}")
-
-        # Changed to load with the PyTorch framework
-        with safe_open(model_path, framework="pt") as f:  # switched to the pt framework
-            for key in f.keys():
-                tensor = f.get_tensor(key)
-                self._save_tensor(key, tensor)
+        model_files = self._find_model_files()
+
+        for model_path in model_files:
+            with safe_open(model_path, framework="pt") as f:
+                for key in f.keys():
+                    tensor = f.get_tensor(key)
+                    self._save_tensor(key, tensor)

-        # Save the global config
         self._save_config()
         self._copy_tokenizer_files()

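A note on the resolution rule added above: a single model.safetensors wins even when shards are also present; otherwise the shards are ordered by the first index in their file name. A minimal standalone sketch of that ordering logic, run against hypothetical file names rather than anything from this commit:

import re

# Same pattern as _find_model_files; these file names are made up.
pattern = re.compile(r"model-(\d+)-of-(\d+)\.safetensors")
names = [
    "model-00002-of-00003.safetensors",
    "model-00001-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]

# Keep (index, name) pairs for names the pattern matches, then sort by index.
shards = sorted(
    (int(m.group(1)), n)
    for n in names
    if (m := pattern.search(n)) is not None
)
print([n for _, n in shards])  # shards come out in 00001, 00002, 00003 order
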
@@ -135,34 +155,57 @@ def _load_config(self):
             return json.load(f)
         return {}

+    def _find_model_files(self):
+        """Find all sharded model files"""
+        single_file = os.path.join(self.model_dir, "model.safetensors")
+        shard_files = glob.glob(os.path.join(self.model_dir, "model-*-of-*.safetensors"))
+
+        # Match uniformly with a regular expression
+        pattern = re.compile(r"model-(\d+)-of-(\d+)\.safetensors")
+        filtered_shards = []
+        for f in shard_files:
+            match = pattern.search(os.path.basename(f))
+            if match:
+                filtered_shards.append((int(match.group(1)), f))
+
+        if os.path.exists(single_file):
+            return [single_file]
+        elif filtered_shards:
+            filtered_shards.sort(key=lambda x: x[0])
+            return [f[1] for f in filtered_shards]
+        else:
+            raise FileNotFoundError(f"No model files found in {self.model_dir}")
+
     def load(self):
         """Load the safetensors model files"""
         tensors = {}
         metadata = {}
-
-        model_path = os.path.join(self.model_dir, "model.safetensors")
-        if not os.path.exists(model_path):
-            raise FileNotFoundError(f"Model file not found: {model_path}")
-
-        with safe_open(model_path, framework="pt") as f:  # switched to the pt framework
-            metadata = f.metadata() if hasattr(f, 'metadata') else {}
-            for key in f.keys():
-                pt_tensor = f.get_tensor(key).cpu().detach()  # get the PyTorch tensor
-
-                # Build the TensorInfo
-                tensor_info = TensorInfo(
-                    dtype=str(pt_tensor.dtype).replace("torch.", ""),
-                    ndim=pt_tensor.ndim,
-                    shape=tuple(pt_tensor.shape),
-                    size=pt_tensor.numel(),
-                    strides=pt_tensor.stride() if pt_tensor.is_contiguous() else None
-                )
-
-                # Convert to a byte stream (preserving memory alignment)
-                byte_buffer = pt_tensor.numpy().tobytes() if pt_tensor.device == "cpu" \
-                    else pt_tensor.cpu().numpy().tobytes()
-
-                tensors[key] = Tensor(byte_buffer, tensor_info)
+
+        model_files = self._find_model_files()
+
+        for model_path in model_files:
+            with safe_open(model_path, framework="pt") as f:
+                # Merge metadata across shards
+                file_metadata = f.metadata() if hasattr(f, 'metadata') else {}
+                metadata.update(file_metadata)
+
+                for key in f.keys():
+                    pt_tensor = f.get_tensor(key).cpu().detach()
+
+                    # Build the TensorInfo
+                    tensor_info = TensorInfo(
+                        dtype=str(pt_tensor.dtype).replace("torch.", ""),
+                        ndim=pt_tensor.ndim,
+                        shape=tuple(pt_tensor.shape),
+                        size=pt_tensor.numel(),
+                        strides=pt_tensor.stride() if pt_tensor.is_contiguous() else None
+                    )
+
+                    # Convert to a byte stream (preserving memory alignment)
+                    byte_buffer = pt_tensor.numpy().tobytes() if pt_tensor.device == "cpu" \
+                        else pt_tensor.cpu().numpy().tobytes()
+
+                    tensors[key] = Tensor(byte_buffer, tensor_info)

         metadata["model_config"] = self.config
         return tensors, metadata
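One caveat worth flagging in the byte-stream conversion above: torch.Tensor.numpy() raises a TypeError for dtypes NumPy does not support, bfloat16 among them, and Qwen2-family checkpoints commonly ship bfloat16 weights. A hedged sketch of a workaround that preserves the raw bits by reinterpreting the tensor as a same-width integer type before serializing; tensor_bytes is a hypothetical helper, not part of the commit:

import torch

def tensor_bytes(pt_tensor: torch.Tensor) -> bytes:
    """Serialize a tensor's raw bytes, tolerating NumPy-unsupported 2-byte dtypes."""
    pt_tensor = pt_tensor.detach().cpu().contiguous()
    try:
        return pt_tensor.numpy().tobytes()
    except TypeError:
        # Reinterpret bfloat16 (2 bytes per element) as int16: the byte
        # pattern is unchanged, only the dtype label differs.
        return pt_tensor.view(torch.int16).numpy().tobytes()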

todo/infer.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+import sys
+import threading
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+import torch
+
+def init_model():
+    model_path = "/home/lipeng/model/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            # use_flash_attention_2=True  # enable Flash Attention
+        ).eval()
+
+        return model, tokenizer
+    except Exception as e:
+        raise RuntimeError(f"Model initialization failed: {str(e)}")
+
+class StdoutStreamer(TextStreamer):
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        self.cache = []
+        self.first_token = True
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.cache.append(text)
+        if stream_end or len(self.cache) >= 2:
+            full_text = "".join(self.cache)
+            sys.stdout.write(full_text)
+            sys.stdout.flush()
+            self.cache = []
+
+def generate_stream(model, tokenizer, text, max_length):
+    formatted_text = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
+    inputs = tokenizer(
+        formatted_text,
+        return_tensors='pt',
+        add_special_tokens=False,
+        return_attention_mask=True
+    ).to(model.device)
+    streamer = StdoutStreamer(tokenizer)
+
+    generation_kwargs = {
+        "input_ids": inputs.input_ids,
+        "attention_mask": inputs.attention_mask,
+        "max_new_tokens": max_length,
+        "pad_token_id": tokenizer.eos_token_id,
+        "temperature": 0.3,         # lower randomness
+        "top_p": 0.85,              # restrict the sampling pool
+        "repetition_penalty": 1.2,  # stronger repetition suppression
+        "do_sample": True,
+        "streamer": streamer
+    }
+
+    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    thread.join()
+    print("\n")  # newline after streaming ends
+
+def generate_text(model, tokenizer, text, max_length=50):
+    formatted_text = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
+    inputs = tokenizer(
+        formatted_text,
+        return_tensors='pt',
+        add_special_tokens=False,
+        return_attention_mask=True
+    ).to(model.device)
+
+    with torch.no_grad():
+        output = model.generate(
+            inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            max_new_tokens=max_length,
+            pad_token_id=tokenizer.eos_token_id,
+            temperature=0.3,
+            top_p=0.85,
+            repetition_penalty=1.2,
+            do_sample=True
+        )
+
+    return tokenizer.decode(
+        output[0][len(inputs.input_ids[0]):],
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=True
+    )
+
+def main():
+    try:
+        model, tokenizer = init_model()
+        sys.stderr.write("Model loaded; enter a prompt to start generating (Ctrl+C to exit)\n")
+    except Exception as e:
+        sys.stderr.write(f"Service failed to start: {e}\n")
+        return
+
+    # Standalone tokenizer sanity check
+    text = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n"
+    tokens = tokenizer.encode(text, add_special_tokens=False)
+    decoded = tokenizer.decode(tokens)
+    assert decoded == text  # verify encode/decode round-trip
+    try:
+        for line in sys.stdin:
+            text = line.strip()
+            if not text:
+                continue
+
+            # Fixed parameter settings
+            max_length = 2048  # maximum generation length
+            stream = True      # always stream
+
+            if stream:
+                generate_stream(model, tokenizer, text, max_length)
+            else:
+                result = generate_text(model, tokenizer, text, max_length)
+                print(result)
+
+    except KeyboardInterrupt:
+        sys.stderr.write("\nService terminated\n")
+    except Exception as e:
+        sys.stderr.write(f"Runtime error: {str(e)}\n")
+
+if __name__ == '__main__':
+    main()
+
+
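A likely candidate for the "open issues to analyze" in the commit message: the prompt is hard-coded in Qwen ChatML form (<|im_start|>user ... <|im_start|>assistant), while DeepSeek-R1-Distill checkpoints bundle their own chat template in the tokenizer config, and the two formats may not agree. A hedged sketch that derives the prompt from the tokenizer instead; build_prompt is a hypothetical helper and assumes a transformers release with chat-template support:

def build_prompt(tokenizer, text: str) -> str:
    # Let the tokenizer's bundled chat template format the conversation
    # instead of hard-coding <|im_start|> markers.
    messages = [{"role": "user", "content": text}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a string, not token ids
        add_generation_prompt=True,  # append the assistant prefix
    )

Either way, the script is stdin-driven: it reads one prompt per line and streams the reply to stdout, so it can be exercised by piping a single line of text into python todo/infer.py.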
todo/qwen2_infer.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 import torch
 import torch.nn as nn
 from transformers import AutoTokenizer
-import re

 class ModelConfig:
     def __init__(self):
