-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllama_server.py
More file actions
277 lines (234 loc) · 12 KB
/
llama_server.py
File metadata and controls
277 lines (234 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
from __future__ import annotations
import sys, argparse
import json

from tinygrad import Tensor, nn, UOp, TinyJit, getenv
class SimpleTokenizer:
    """Greedy longest-match tokenizer for a GPT-2 style byte-level vocabulary.

    Vocabulary entries are expected to spell raw bytes in the GPT-2 ``bytes_to_unicode`` alphabet
    (a space is "Ġ", a newline is "Ċ", a tab is "ĉ", ...). Text is UTF-8 encoded and mapped into that
    alphabet before matching, and decoding maps back to bytes. As a result, non-ASCII text and
    whitespace other than space and newline round-trip instead of failing or turning into mojibake.

    NOTE: matching is greedy longest-prefix, not BPE merge order, so ids can differ from the
    reference tokenizer for the same text.
    """

    def __init__(self, vocab: list[str]):
        self.vocab: list[str] = vocab
        self.biggest_token: int = max(map(len, vocab), default=0)
        self.token_to_id: dict[str, int] = {tok: i for i, tok in enumerate(vocab)}
        # Kept for callers that read them: the byte-level spellings of b" " and b"\n".
        self.replace_space = "Ġ"
        self.replace_newline = "Ċ"
        # GPT-2 bytes_to_unicode: these printable Latin-1 bytes stand for themselves. The remaining
        # 68 bytes are shifted, in increasing byte order, to chr(256), chr(257), ...
        printable = {*range(ord("!"), ord("~") + 1), *range(ord("¡"), ord("¬") + 1), *range(ord("®"), ord("ÿ") + 1)}
        shifted = (b for b in range(256) if b not in printable)
        self._byte_to_char: dict[int, str] = {b: chr(b) for b in printable}
        self._byte_to_char.update({b: chr(256 + n) for n, b in enumerate(shifted)})
        self._char_to_byte: dict[str, int] = {c: b for b, c in self._byte_to_char.items()}

    def encode(self, text: str) -> list[int]:
        """Return token ids for ``text``, always taking the longest vocabulary entry at each position.

        Raises RuntimeError when no vocabulary entry matches at some position.
        """
        s = "".join(self._byte_to_char[b] for b in text.encode("utf-8"))
        out: list[int] = []
        i = 0
        while i < len(s):
            # Try the longest candidate first and shrink until a vocabulary entry matches.
            for j in range(min(i + self.biggest_token, len(s)), i, -1):
                tid = self.token_to_id.get(s[i:j])
                if tid is not None:
                    break
            else:
                raise RuntimeError(f"token not found in {s}")
            out.append(tid)
            i = j
        return out

    def decode(self, ids: list[int]) -> str:
        """Return the text for ``ids``, mapping byte-level characters back to UTF-8 bytes."""
        raw = bytearray()
        for ch in "".join(self.vocab[tid] for tid in ids):
            byte = self._char_to_byte.get(ch)
            if byte is None:
                raw += ch.encode("utf-8")  # not part of the byte alphabet: keep the character verbatim
            else:
                raw.append(byte)
        # A single token can end in the middle of a multi-byte character, so don't raise on it.
        return raw.decode("utf-8", errors="replace")

    def role(self, role: str):
        """Return ids for a llama-style role header: <|start_header_id|>role<|end_header_id|> plus two newlines."""
        return [t for x in ["<|start_header_id|>", role, "<|end_header_id|>\n\n"] for t in self.encode(x)]  # llama style
def apply_rope(x:Tensor, start_pos:int|UOp, base:float=10000) -> Tensor:
    """Apply rotary position embedding to ``x`` and return a tensor of the same shape.

    ``x`` is unpacked as (B, H, T, Hd). Pair i of the last axis, (x[..., 2i], x[..., 2i+1]), at
    position p is rotated by the angle p * base**(-2i/Hd). Positions run from start_pos to
    start_pos+T-1.

    Raises ValueError if Hd is odd, because the last axis is split into pairs.
    """
    B, H, T, Hd = x.shape
    if Hd % 2:
        raise ValueError(f"RoPE needs an even head dimension, got {Hd}")
    # NOTE: this is usually in a RoPE cache, but tinygrad JIT should prune it outside the kernel
    # TODO: make it do that
    # Build the exponents 2i/Hd from exact integers. A float step of 2/Hd accumulates rounding and
    # may give the wrong element count when 2/Hd is not exactly representable.
    freq = base ** (-(Tensor.arange(0, Hd, 2, dtype='float32') / Hd))
    angles = Tensor.arange(start_pos, start_pos+T, dtype='float32')[None, None, :, None] * freq  # (1,1,T,Hd/2)
    cos, sin = angles.cos(), angles.sin()
    x = x.reshape(B, H, T, Hd // 2, 2)  # split into pairs
    y1 = x[..., 0] * cos - x[..., 1] * sin
    y2 = x[..., 0] * sin + x[..., 1] * cos
    return Tensor.stack(y1, y2, dim=-1).reshape(B, H, T, Hd)
class TransformerBlock:
    """One Llama-style decoder layer.

    The layer runs RMSNorm then attention (RoPE, grouped KV heads, KV cache), followed by RMSNorm
    then a gated SiLU feed-forward. Each sub-layer adds a residual connection.

    ``rope_theta`` is the RoPE base passed to ``apply_rope``. It defaults to 10000, the value used
    before it was configurable. ``Transformer.from_gguf`` passes the checkpoint's
    ``{arch}.rope.freq_base`` when the metadata has it.
    """

    def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0,
                 rope_theta:float=10000.0):
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.max_context = max_context
        self.rope_theta = rope_theta
        # --- attention projections (all linear, bias-free) ------------------
        kv_proj_out = self.head_dim * n_kv_heads  # Llama-3 uses the same dim for K/V
        self.attn_q = nn.Linear(dim, dim, bias=False)
        self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
        self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
        self.attn_output = nn.Linear(dim, dim, bias=False)
        # --- RMSNorms --------------------------------------------------------
        self.attn_norm = nn.RMSNorm(dim, norm_eps)
        self.ffn_norm = nn.RMSNorm(dim, norm_eps)
        # --- feed-forward ----------------------------------------------------
        self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
        self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)

    def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
        """Run pre-norm self-attention for x (B,T,D) at positions start_pos..start_pos+T-1 and return x + attention."""
        x_norm = self.attn_norm(x)  # (B,T,D)
        q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
        B, T, _ = x.shape
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B,H,T,Hd)
        k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
        v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
        q = apply_rope(q, start_pos, self.rope_theta)
        k = apply_rope(k, start_pos, self.rope_theta)
        # TODO: remove these kv cache realizes
        # The cache is allocated on first use, sized for that call's batch and for max_context positions.
        if not hasattr(self, "cache_kv"):
            self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
        # Store this step's K/V at [start_pos, start_pos+T), then attend over every cached position up to there.
        self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize()  # type: ignore
        k = self.cache_kv[0, :, :, 0:start_pos+T, :]
        v = self.cache_kv[1, :, :, 0:start_pos+T, :]
        # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
        mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
        attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
        attn = attn.transpose(1, 2).reshape(B, T, -1)  # back to (B,T,D)
        attn = self.attn_output(attn)
        return x + attn

    def _feed_forward(self, h: Tensor) -> Tensor:
        """Apply the pre-norm gated feed-forward down(silu(gate(n)) * up(n)) and add the residual."""
        h_norm = self.ffn_norm(h)
        gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
        return h + self.ffn_down(gated)

    def __call__(self, x: Tensor, start_pos: int|UOp):
        return self._feed_forward(self._attention(x, start_pos))


class Transformer:
    """Llama-style decoder.

    The model is a token embedding, a stack of TransformerBlocks, a final RMSNorm and an output
    projection. It predicts the next token greedily (argmax).
    """

    def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context,
                 rope_theta=10000.0):
        self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, rope_theta)
                    for _ in range(num_blocks)]
        self.token_embd = nn.Embedding(vocab_size, dim)
        self.output_norm = nn.RMSNorm(dim, norm_eps)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.max_context = max_context
        # JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
        self.forward_jit = TinyJit(self.forward)

    def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
        """Return the argmax next-token id for the last position of each sequence, shape (B, 1)."""
        x = self.token_embd(tokens)  # (B, T, D)
        for block in self.blk: x = block(x, start_pos)
        # TODO: add temperature
        return self.output(self.output_norm(x))[:, -1, :].softmax(-1).argmax(-1, keepdim=True)

    def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
        use_jit = getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp)
        return (self.forward_jit if use_jit else self.forward)(tokens, start_pos)

    @staticmethod
    def from_gguf(gguf:Tensor, max_context:int|None=None) -> tuple[Transformer, dict]:
        """Build a Transformer from a GGUF file tensor. Returns (model, gguf metadata dict).

        ``max_context`` is clamped to the checkpoint's ``{arch}.context_length``. The RoPE base is
        read from ``{arch}.rope.freq_base``; when that key is missing it falls back to 10000.
        """
        # TODO: remove the need for copy to default device
        kv, state_dict = nn.state.gguf_load(gguf.to(None))
        # all state items should be float16, not float32
        state_dict = {k:v.cast('float16') for k,v in state_dict.items()}
        # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
        if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']
        arch = kv['general.architecture']
        max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
        model = Transformer(num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'], hidden_dim=kv[f'{arch}.feed_forward_length'],
                            n_heads=kv[f'{arch}.attention.head_count'], n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
                            norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']),
                            max_context=max_context, rope_theta=kv.get(f'{arch}.rope.freq_base', 10000.0))
        # NOTE: rope_freqs.weight (32,) is unused (presumably per-frequency RoPE scaling factors; not applied here)
        nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)
        return model, kv

    def generate(self, tokens:list[int], start_pos=0):
        """Yield greedily decoded token ids until ``tokens`` reaches max_context; callers stop early (e.g. at EOS).

        Every yielded id is also appended to ``tokens`` in place.
        NOTE: ``start_pos`` is overwritten with 0 below, so the whole ``tokens`` list is prefilled on every call.
        """
        v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
        start_pos = 0
        t = Tensor([tokens[start_pos:]], dtype="int32")
        self.forward_jit.reset()  # TODO: why is this required? root cause the issue and make it not be needed
        while len(tokens) < self.max_context:
            t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
            next_id = int(t.item())
            tokens.append(next_id)
            start_pos = len(tokens) - 1
            yield next_id
def _bartowski_gguf_url(repo: str, filename: str) -> str:
    """Return the Hugging Face download URL of *filename* in the bartowski/*repo* repository."""
    return f"https://huggingface.co/bartowski/{repo}/resolve/main/{filename}"


# Downloadable GGUF checkpoints, keyed by short model name.
models = {
    "1B": _bartowski_gguf_url("Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K.gguf"),
    "3B": _bartowski_gguf_url("Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K.gguf"),
    "3B_f16": _bartowski_gguf_url("Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-f16.gguf"),
    "8B": _bartowski_gguf_url("Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf"),
}
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import Optional
from fastapi.responses import StreamingResponse, JSONResponse
# FastAPI application that the /v1/completions and /v1/chat/completions handlers below register on.
app = FastAPI()
class CompletionRequest(BaseModel):
    """JSON body accepted by POST /v1/completions (a subset of the OpenAI completions request)."""
    model: str  # echoed back in the response; it does not choose which model runs
    prompt: str  # raw prompt text; the handler tokenizes it after a BOS token
    max_tokens: int = 64  # requested completion length in tokens
    temperature: Optional[float] = 1.0  # accepted for API compatibility; not read by the handler
    top_p: Optional[float] = 1.0  # accepted for API compatibility; not read by the handler
    stop: Optional[list[str]] = None  # accepted for API compatibility; not read by the handler
    stream: Optional[bool] = False  # true -> the handler returns a text/event-stream response
# Load the model and tokenizer once at server startup; both endpoints share these globals.
# The MODEL environment variable (a key of `models`) selects the checkpoint and MAX_CONTEXT sets the
# context length. The defaults are the previously hard-coded "1B" and 4096.
MODEL_NAME = getenv("MODEL", "1B")
if MODEL_NAME not in models:
    raise SystemExit(f"unknown MODEL {MODEL_NAME!r}; expected one of: {', '.join(models)}")
print(models[MODEL_NAME])  # printed before the (possibly long) download so the source is visible up front
MODEL, KV = Transformer.from_gguf(Tensor.from_url(models[MODEL_NAME]), getenv("MAX_CONTEXT", 4096))
TOKENIZER = SimpleTokenizer(KV["tokenizer.ggml.tokens"])
BOS_ID = KV["tokenizer.ggml.bos_token_id"]
EOS_ID = KV["tokenizer.ggml.eos_token_id"]
@app.post("/v1/completions")
async def openai_completion(req: CompletionRequest):
    """OpenAI-style text completion with greedy decoding.

    The handler generates from BOS plus the encoded prompt. Generation stops at EOS, after
    ``req.max_tokens`` tokens, or when the context is full. ``temperature``, ``top_p`` and ``stop``
    are accepted but not applied.

    With ``req.stream`` the response is a server-sent-event stream of JSON chunks ending with
    ``data: [DONE]``. Otherwise it is a single JSON body. Text the tokenizer cannot encode gets a 422.
    """
    try:
        prompt_ids = [BOS_ID] + TOKENIZER.encode(req.prompt)
    except RuntimeError as err:  # the tokenizer raises RuntimeError for text it cannot cover
        return JSONResponse(status_code=422, content={"error": str(err)})
    finish_reason = "length"  # switched to "stop" once the model emits EOS

    # NOTE(review): MODEL keeps one KV cache per block and nothing here serializes access, so
    # concurrent requests may interfere with each other; consider a lock around generation.
    def completion_tokens():
        """Yield generated ids, excluding EOS, capped at req.max_tokens."""
        nonlocal finish_reason
        if req.max_tokens <= 0:
            return
        produced = 0
        # MODEL.generate appends every id it yields to the list it gets. Pass a copy so prompt_ids
        # stays the prompt, which is reported in "usage".
        for token in MODEL.generate(list(prompt_ids)):
            if token == EOS_ID:
                finish_reason = "stop"
                return
            yield token
            produced += 1
            if produced >= req.max_tokens:
                return

    if req.stream:
        def event_stream():
            for token in completion_tokens():
                chunk = {"id": "cmpl-tinygrad", "object": "text_completion", "model": req.model,
                         "choices": [{"text": TOKENIZER.decode([token]), "index": 0, "logprobs": None, "finish_reason": None}]}
                yield f"data: {json.dumps(chunk)}\n\n"
            final = {"id": "cmpl-tinygrad", "object": "text_completion", "model": req.model,
                     "choices": [{"text": "", "index": 0, "logprobs": None, "finish_reason": finish_reason}]}
            yield f"data: {json.dumps(final)}\n\n"
            yield "data: [DONE]\n\n"
        return StreamingResponse(event_stream(), media_type="text/event-stream")

    completion = list(completion_tokens())
    return JSONResponse({
        "id": "cmpl-tinygrad",
        "object": "text_completion",
        "model": req.model,
        "choices": [{
            "text": TOKENIZER.decode(completion),
            "index": 0,
            "logprobs": None,
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": len(prompt_ids),
            "completion_tokens": len(completion),
            "total_tokens": len(prompt_ids) + len(completion),
        },
    })
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-style chat completion with greedy decoding.

    The prompt is laid out as BOS, then for each message its role header, its content and EOS_ID,
    followed by an open "assistant" header. Generation stops at EOS, after ``max_tokens`` tokens
    (default 64), or when the context is full.

    With ``"stream": true`` the reply is a server-sent-event stream of ``chat.completion.chunk``
    objects ending with ``data: [DONE]``. Otherwise it is one ``chat.completion`` JSON body.
    Malformed input gets a 422 response with an ``error`` message.
    """
    try:
        body = await request.json()
    except ValueError:  # invalid JSON or invalid UTF-8 in the request body
        return JSONResponse(status_code=422, content={"error": "Request body must be valid JSON"})
    if not isinstance(body, dict):
        return JSONResponse(status_code=422, content={"error": "Request body must be a JSON object"})
    messages = body.get("messages", [])
    max_tokens = body.get("max_tokens")
    if max_tokens is None:  # absent or explicit null
        max_tokens = 64
    stream = body.get("stream", False)
    model_name = body.get("model", "1B")
    if not messages:
        return JSONResponse(status_code=422, content={"error": "Missing 'messages'"})
    if not isinstance(messages, list) or not all(isinstance(msg, dict) for msg in messages):
        return JSONResponse(status_code=422, content={"error": "'messages' must be a list of objects"})
    if isinstance(max_tokens, bool) or not isinstance(max_tokens, int):
        return JSONResponse(status_code=422, content={"error": "'max_tokens' must be an integer"})

    # Reconstruct the prompt from messages
    try:
        prompt_ids = [BOS_ID]
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content") or ""
            if isinstance(content, list):  # OpenAI "content parts" form: keep only the text parts
                content = "".join(part["text"] for part in content
                                  if isinstance(part, dict) and part.get("type") == "text" and isinstance(part.get("text"), str))
            if not isinstance(role, str) or not isinstance(content, str):
                return JSONResponse(status_code=422, content={"error": "Each message needs a string 'role' and 'content'"})
            prompt_ids += TOKENIZER.role(role) + TOKENIZER.encode(content) + [EOS_ID]
        prompt_ids += TOKENIZER.role("assistant")
    except RuntimeError as err:  # the tokenizer raises RuntimeError for text it cannot cover
        return JSONResponse(status_code=422, content={"error": str(err)})
    finish_reason = "length"  # switched to "stop" once the model emits EOS

    # NOTE(review): MODEL keeps one KV cache per block and nothing here serializes access, so
    # concurrent requests may interfere with each other; consider a lock around generation.
    def completion_tokens():
        """Yield generated ids, excluding EOS, capped at max_tokens."""
        nonlocal finish_reason
        if max_tokens <= 0:
            return
        produced = 0
        # MODEL.generate appends every id it yields to the list it gets. Pass a copy so prompt_ids
        # stays the prompt, which is reported in "usage".
        for token in MODEL.generate(list(prompt_ids)):
            if token == EOS_ID:
                finish_reason = "stop"
                return
            yield token
            produced += 1
            if produced >= max_tokens:
                return

    if stream:
        def sse_chunk(delta: dict, reason=None) -> str:
            chunk = {"id": "chatcmpl-tinygrad", "object": "chat.completion.chunk", "model": model_name,
                     "choices": [{"index": 0, "delta": delta, "finish_reason": reason}]}
            return f"data: {json.dumps(chunk)}\n\n"

        def event_stream():
            yield sse_chunk({"role": "assistant"})
            for token in completion_tokens():
                yield sse_chunk({"content": TOKENIZER.decode([token])})
            yield sse_chunk({}, finish_reason)
            yield "data: [DONE]\n\n"
        return StreamingResponse(event_stream(), media_type="text/event-stream")

    completion = list(completion_tokens())
    return JSONResponse({
        "id": "chatcmpl-tinygrad",
        "object": "chat.completion",
        "model": model_name,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": TOKENIZER.decode(completion),
            },
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": len(prompt_ids),
            "completion_tokens": len(completion),
            "total_tokens": len(prompt_ids) + len(completion),
        },
    })