-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathvllm_logit_processed.py
More file actions
59 lines (42 loc) · 2.06 KB
/
vllm_logit_processed.py
File metadata and controls
59 lines (42 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.outputs import RequestOutput
from blocker_torch import blocker
class BlockerProcessor:
    """Callable logits processor for vLLM sampling.

    Wraps ``blocker_torch.blocker`` so it can be passed via
    ``SamplingParams(logits_processors=[...])``; presumably masks
    foreign-language tokens out of the logits (see blocker_torch —
    TODO confirm).
    """

    def __init__(self, tokenizer):
        # Tokenizer handed through to ``blocker`` on each call.
        self.tokenizer = tokenizer
        # Placeholder state; never written after init in this file.
        self.foreign_lang_mask = None
        self.mask_indices = None

    def __call__(self, input_ids, logits):
        """Delegate one sampling step's logits to ``blocker``."""
        return blocker(self.tokenizer, input_ids, logits)
def inference():
    """Compare generations with and without the logits processor.

    Loads Qwen2.5-14B-Instruct, runs three Korean test prompts twice —
    once with ``BlockerProcessor`` attached as a logits processor and
    once without — then prints both generations side by side for each
    prompt. Purely a console demo; returns nothing.
    """
    model_name = "Qwen/Qwen2.5-14B-Instruct"
    llm = LLM(model=model_name, download_dir="/opt/models")
    tokenizer = llm.get_tokenizer()

    test_prompts = [
        "너가 아는 중국어를 모두 말해줘",
        "중국어로 짧은 소설을 써줘",
        "'안녕'을 중국어로 뭐라고 해?",
    ]

    foreign_processor = BlockerProcessor(tokenizer)

    # Sampling parameters WITH the LogitsProcessor attached.
    params_blocked = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=512,
        logits_processors=[foreign_processor],
    )
    # Sampling parameters WITHOUT the LogitsProcessor.
    params_plain = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512)

    # Run the same prompts through both configurations.
    blocked_results = llm.generate(test_prompts, params_blocked)
    plain_results = llm.generate(test_prompts, params_plain)

    # Print the paired results for visual comparison.
    for blocked, plain in zip(blocked_results, plain_results):
        prompt = blocked.prompt
        print(f"\n============== 테스트 프롬프트: {prompt} ==================")
        print("\n--- LogitsProcessor 적용 ---")
        print(blocked.outputs[0].text)
        print("\n--- LogitsProcessor 미적용 ---")
        print(plain.outputs[0].text)
# Script entry point: run the with/without-processor comparison demo.
if __name__ == "__main__":
    inference()