forked from Triang-jyed-driung/Albatross
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdemo.py
More file actions
414 lines (369 loc) · 16.2 KB
/
demo.py
File metadata and controls
414 lines (369 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
import curses, time, threading, random
from collections import deque
from typing import List, Optional
import math
import multiprocessing as mp
import queue
import torch, types, copy
import numpy as np
from torch.nn import functional as F
GRID_W, GRID_H = 16, 8
NUM_PANELS = GRID_W * GRID_H
BATCH_SIZE = NUM_PANELS
GENERATION_LENGTH = 4000
# GENERATION_LENGTH = 1000
SAMPLER_NOISE = 3.0 # here we use simple (fast) sampling = greedy(logits + noise)
# find some random tokens as first token?
# prompts = []
# with open("reference/rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
# lines = f.readlines()
# for l in lines:
# x = eval(l[l.index(' '):l.rindex(' ')])
# if isinstance(x, str) and all(c.isalpha() for c in x) and x[0].isupper() and all(c.islower() for c in x[1:]) and ' ' not in x:
# prompts.append(x)
# prompts = random.sample(prompts, BATCH_SIZE)
# or, use "The" for all panels?
prompts = ["The" for _ in range(BATCH_SIZE)]
# prompts = ["List of Emojis:" for _ in range(BATCH_SIZE)]
# prompts = ["Q: 1+1=?\nA: 1+1=2." for _ in range(BATCH_SIZE)]
# prompts = ["Assistant: <think" for _ in range(BATCH_SIZE)]
# prompts = ["Assistant: <think>嗯" for _ in range(BATCH_SIZE)]
# prompts = ["Assistant: <think>私" for _ in range(BATCH_SIZE)]
SHOW_SPEED_PERCENTILE = 50
LOG_FILE = open("demo.log", "w")
########################################################################################################
args = types.SimpleNamespace()
args.vocab_size = 65536
args.head_size = 64
#
# model download: https://huggingface.co/BlinkDL/rwkv7-g1
#
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1a-0.1b-20250728-ctx4096"
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1a-0.4b-20250905-ctx4096"
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-1.5b-20250429-ctx4096"
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-2.9b-20250519-ctx4096"
args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g0a-7.2b-20250829-ctx4096"
from reference.rwkv7 import RWKV_x070
from reference.utils import TRIE_TOKENIZER, sampler_simple_batch
########################################################################################################
def computation_process(text_queue, shutdown_event):
"""Computation process that runs the model inference and sends results to UI"""
try:
# Initialize model
model = RWKV_x070(args)
tokenizer = TRIE_TOKENIZER("reference/rwkv_vocab_v20230424.txt")
# Initialize state
state = model.generate_zero_state(BATCH_SIZE)
# Send initial prompts to UI
for i, prompt in enumerate(prompts):
text_queue.put(("text", i, prompt))
# Initial state with initial prompts
out = model.forward_batch([tokenizer.encode(prompt) for prompt in prompts], state)
perf_interval = 10
times = []
all_times = []
tokens = [[] for _ in range(BATCH_SIZE)]
for i in range(GENERATION_LENGTH):
if shutdown_event.is_set():
break
t00 = time.perf_counter()
new_tokens = sampler_simple_batch(out, SAMPLER_NOISE).tolist()
tokens = [tokens[n] + new_tokens[n] for n in range(BATCH_SIZE)]
# Send decoded tokens to UI
for n in range(BATCH_SIZE):
try:
decoded_text = tokenizer.decode(tokens[n], utf8_errors="strict") # only send full utf-8 tokens
text_queue.put(("text", n, decoded_text))
tokens[n] = []
except:
pass
torch.cuda.synchronize()
t0 = time.perf_counter()
out = model.forward_batch(new_tokens, state)
torch.cuda.synchronize()
t1 = time.perf_counter()
times.append(t1 - t0)
all_times.append(t1 - t00)
# time.sleep(0.1)
if i % perf_interval == 0:
times_tmp = np.percentile(times, SHOW_SPEED_PERCENTILE) if times else 0
all_times_tmp = np.percentile(all_times, SHOW_SPEED_PERCENTILE) if all_times else 0
times.clear()
all_times.clear()
# Send performance info to main process
if times_tmp > 0 and all_times_tmp > 0:
text_queue.put(("perf", -1, f'RWKV-7 7.2B FP16 bsz{BATCH_SIZE} inference @ RTX5090 || Token/s = {round(BATCH_SIZE/times_tmp,2)} (forward), {round(BATCH_SIZE/all_times_tmp,2)} (full) || Const speed & VRAM because this is RNN || https://github.com/BlinkDL/Albatross https://rwkv.com'))
except Exception as e:
text_queue.put(("error", -1, f"Error: {str(e)}"))
LOG_FILE.write(str(e) + "\n")
finally:
text_queue.put(("done", -1, "Finished"))
########################################################################################################
_ui_singleton = None
_singleton_lock = threading.Lock()
class PanelState:
__slots__ = ("width","height","lines","cur","pending","lock","dirty")
def __init__(self,w:int,h:int,history_lines:int=4096):
self.width=max(4,w)
self.height=max(1,h)
self.lines=deque(maxlen=history_lines)
self.cur=""
self.pending=deque()
self.lock=threading.Lock()
self.dirty=True
def set_size(self,w:int,h:int):
self.width=max(4,w)
self.height=max(1,h)
def enqueue(self,s:str):
with self.lock:
self.pending.append(s)
def drain_and_wrap(self)->bool:
buf=None
with self.lock:
if self.pending:
buf="".join(self.pending); self.pending.clear()
if not buf: return False
w=self.width; cur=self.cur; i=0; L=len(buf)
while i<L:
ch=buf[i]
if ch=="\n":
self.lines.append(cur); cur=""
else:
cur+=ch
if len(cur)>=w:
self.lines.append(cur); cur=""
i+=1
self.cur=cur
self.dirty=True
return True
def visible_tail(self)->List[str]:
tail=list(self.lines)
if self.cur: tail.append(self.cur)
h=self.height
if len(tail)>=h: return tail[-h:]
return [""]*(h-len(tail))+tail
class TextGridUI:
def __init__(self,fps:int=30,history_lines:int=4096,text_queue=None):
self.fps=max(1,fps)
self.history_lines=history_lines
self.stop_event=threading.Event()
self.stdscr=None
self.panels:List[PanelState]=[]
self.windows:List[curses.window]=[]
self.grid_cell_w=None
self.grid_cell_h=None
self._started=False
self.text_queue=text_queue
self.perf_info="Loading..."
self.perf_dirty=True
self.first_text_received=False
def add_text(self,idx:int,s:str):
if 0<=idx<NUM_PANELS: self.panels[idx].enqueue(s)
def stop(self):
"""Stop the UI and clean up"""
self.stop_event.set()
def start(self):
if self._started: return
self._started=True
curses.wrapper(self._curses_main)
def _compute_layout(self):
max_y,max_x=self.stdscr.getmaxyx()
# Reserve 1 line at the top for performance info
available_height = max_y - 1
cell_w=max_x//GRID_W
cell_h=available_height//GRID_H
# Account for borders: need minimum space for border + content
if cell_w<10 or cell_h<4:
raise RuntimeError(f"Terminal too small: need >= {10*GRID_W}x{4*GRID_H}, current={max_x}x{max_y}")
self.grid_cell_w, self.grid_cell_h = cell_w, cell_h
def _init_windows(self):
self.windows=[]; self.panels=[]
cw=self.grid_cell_w; ch=self.grid_cell_h
for r in range(GRID_H):
for c in range(GRID_W):
# Offset by 1 row to make space for performance info at top
top=r*ch+1; left=c*cw
win=self.stdscr.derwin(ch,cw,top,left)
win.scrollok(False); win.nodelay(True)
self.windows.append(win)
# Account for border space: 1 char on each side horizontally, 1 line for top border, 1 for bottom
# Content area is reduced by 2 chars width and 2 lines height
p=PanelState(cw-2,ch-3,history_lines=self.history_lines)
self.panels.append(p)
for i,win in enumerate(self.windows):
try:
win.erase()
# Draw border
win.box()
# Add label inside the border
label=f"[{i:03d}]"
win.addnstr(1,1,label,len(label))
win.noutrefresh()
except curses.error:
pass
curses.doupdate()
def _curses_main(self,stdscr):
self.stdscr=stdscr
curses.curs_set(0); curses.noecho(); curses.cbreak()
stdscr.nodelay(True); stdscr.keypad(True)
# Initialize colors if supported
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_WHITE, -1) # idle - white
curses.init_pair(2, curses.COLOR_YELLOW, -1) # processing - yellow
curses.init_pair(3, curses.COLOR_GREEN, -1) # completed - green
curses.init_pair(4, curses.COLOR_RED, -1) # error - red
curses.init_pair(5, curses.COLOR_CYAN, -1) # header - cyan
try:
self._compute_layout(); self._init_windows()
except Exception as e:
curses.endwin(); print(e); return
target_dt=1.0/self.fps; last=time.perf_counter()
while not self.stop_event.is_set():
# Process messages from computation process
if self.text_queue:
try:
while True:
msg_type, panel_id, content = self.text_queue.get_nowait()
if msg_type == "text" and 0 <= panel_id < NUM_PANELS:
# Clear screen on first text to remove any warnings
if not self.first_text_received:
self.stdscr.clear()
self.first_text_received = True
# Force re-initialization of windows and performance display
self._init_windows()
self.perf_dirty = True
self.add_text(panel_id, content)
elif msg_type == "perf":
# Store performance info for display at top
self.perf_info = content
self.perf_dirty = True
elif msg_type == "error":
self.perf_info = f"Error: {content}"
self.perf_dirty = True
elif msg_type == "done":
# Computation finished, we can exit
self.perf_info = "Computation finished"
self.perf_dirty = True
self.stop_event.set()
break
except queue.Empty:
pass
# Display performance info at the top
if self.perf_dirty:
try:
max_y, max_x = self.stdscr.getmaxyx()
# Clear the top line
self.stdscr.move(0, 0)
self.stdscr.clrtoeol()
# Display performance info with color
perf_text = self.perf_info[:max_x-1] # Truncate if too long
if curses.has_colors():
self.stdscr.addnstr(0, 0, perf_text, max_x-1, curses.color_pair(3) | curses.A_BOLD) # Yellow and bold
else:
self.stdscr.addnstr(0, 0, perf_text, max_x-1)
self.perf_dirty = False
except curses.error:
pass
any_dirty=False
for idx,p in enumerate(self.panels):
changed=p.drain_and_wrap()
if not (changed or p.dirty): continue
win=self.windows[idx]
cw=p.width; ch=self.grid_cell_h
try:
win.erase()
# Draw border
win.box()
# Simple header with batch info only - inside the border
header = f"[{idx:03d}]"
# Choose color based on status for header only
header_color = 5 # cyan for header
if curses.has_colors():
win.addnstr(1, 1, header, cw, curses.color_pair(header_color) | curses.A_BOLD)
else:
win.addnstr(1, 1, header, cw)
# Content - display inside border area
visible=p.visible_tail()
row=2 # Start from row 2 (after border and header)
max_rows=ch-3 # Account for top border, header, and bottom border
for line in visible[-max_rows:]:
if len(line)>cw: line=line[:cw]
# Use normal color for content text, positioned inside border
win.addnstr(row,1,line,cw)
row+=1
if row>=ch-1: break # Stop before bottom border
win.noutrefresh()
p.dirty=False
any_dirty=True
except curses.error:
pass
# Also mark as dirty if performance info was updated
if self.perf_dirty:
any_dirty = True
if any_dirty: curses.doupdate()
try:
ch=self.stdscr.getch()
if ch in (ord('q'),ord('Q'),27): self.stop_event.set()
except curses.error:
pass
now=time.perf_counter(); dt=now-last
if dt<target_dt: time.sleep(target_dt-dt)
last=now
curses.nocbreak(); self.stdscr.keypad(False); curses.echo(); curses.endwin()
def start_ui(fps:int=30,history_lines:int=4096,text_queue=None):
global _ui_singleton
with _singleton_lock:
if _ui_singleton is None:
_ui_singleton=TextGridUI(fps=fps,history_lines=history_lines,text_queue=text_queue)
_ui_singleton.start()
def add_text(n:int,s:str):
ui=_ui_singleton
if ui is None: raise RuntimeError("UI not started. Call start_ui().")
ui.add_text(n,s)
def stop_ui():
"""Stop the UI and clean up resources"""
global _ui_singleton
if _ui_singleton is not None:
_ui_singleton.stop()
time.sleep(0.1) # Give a moment for the UI thread to clean up
if __name__=="__main__":
# Set multiprocessing start method
mp.set_start_method('spawn', force=True)
# Create communication queue and shutdown event
text_queue = mp.Queue()
shutdown_event = mp.Event()
# Start computation process
comp_process = mp.Process(
target=computation_process,
args=(text_queue, shutdown_event),
daemon=False
)
comp_process.start()
# computation_process(text_queue, shutdown_event)
# quit()
try:
# Start UI in main process
start_ui(fps=30, text_queue=text_queue)
except Exception as e:
LOG_FILE.write(str(e) + "\n")
finally:
# Signal computation process to shutdown
shutdown_event.set()
# Wait for computation process to finish
comp_process.join(timeout=5.0)
if comp_process.is_alive():
print("Terminating computation process...")
comp_process.terminate()
comp_process.join(timeout=2.0)
if comp_process.is_alive():
comp_process.kill()
comp_process.join()
# Always stop UI and restore terminal
stop_ui()
# Ensure terminal is fully restored
import os
os.system('stty sane') # Reset terminal settings
print("\nDemo completed! Terminal restored.")
LOG_FILE.close()