-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
89 lines (70 loc) · 2.47 KB
/
main.py
File metadata and controls
89 lines (70 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
/main.py
-------------------------
Adaptive speculative decoding system using LM Studio
Automatically selects optimal model precision (FP16, Q8, Q4)
based on prompt complexity and performs speculative decoding
for faster inference without compromising quality.
"""
from __future__ import annotations
import argparse
import logging
import lmstudio as lms
from complexity import ClassifyPrompt, ComplexityLevels
from utils import detect_device
# Configure root logging once at import time: timestamped, level-tagged lines.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module logger (named "main" rather than __name__ by project convention).
logger = logging.getLogger("main")

# Available precisions of the same base model, keyed by quantization level.
# Q16 = full fp16 target; Q8/Q4 are progressively smaller quantized variants
# used as draft models for speculative decoding.
MODELS = {
    "Q16": "qwen2.5-1.5b-instruct@fp16",
    "Q8": "qwen2.5-1.5b-instruct@q8_0",
    "Q4": "qwen2.5-1.5b-instruct@q4_k_m",
}
def speculative_decode(prompt: str) -> str:
    """
    Perform speculative decoding for *prompt*.

    The complexity classifier decides which model precision to load
    and whether to use a draft model for speculative decoding.

    Args:
        prompt: The user prompt to send to the model.

    Returns:
        The model response as a string, or an ``"Error: ..."`` string if
        classification, model loading, or generation failed.
    """
    try:
        complexity = ClassifyPrompt.get_complexity(prompt)
        # Lazy %-style args: the message is only formatted if the record
        # is actually emitted.
        logger.info("Prompt complexity classified as: %s", complexity)

        # Target is always full precision; lower-complexity prompts can
        # afford a more aggressively quantized draft model for speed.
        if complexity == ComplexityLevels.low:
            model_key, draft_key = MODELS["Q16"], MODELS["Q4"]
        elif complexity == ComplexityLevels.mid:
            model_key, draft_key = MODELS["Q16"], MODELS["Q8"]
        else:
            # High complexity: skip speculative decoding entirely rather
            # than risk quality loss from draft mispredictions.
            model_key, draft_key = MODELS["Q16"], None

        logger.info("Loading target model: %s", model_key)
        model = lms.llm(model_key)

        if draft_key:
            logger.info("Using draft model: %s", draft_key)
            result = model.respond(prompt, config={"draftModel": draft_key})
        else:
            result = model.respond(prompt)
        return str(result)
    except Exception as e:
        # Top-level boundary: log the full traceback and hand the caller a
        # printable error marker instead of propagating the exception.
        logger.exception("Error during speculative decoding: %s", e)
        return f"Error: {str(e)}"
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Adaptive speculative decoding using LM Studio.")
parser.add_argument(
"--prompt",
type=str,
required=False,
default="Explain the theorem of CAP in distributed systems for GPUs, highlighting trade-offs and failure cases.",
help="Prompt text for the model",
)
args = parser.parse_args()
device = detect_device()
logger.info(f"Running on device: {device}")
output = speculative_decode(args.prompt)
print("\n" + "=" * 80)
print(output)
print("=" * 80)