forked from Triang-jyed-driung/Albatross
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbatch.py
More file actions
148 lines (116 loc) · 5.14 KB
/
batch.py
File metadata and controls
148 lines (116 loc) · 5.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
########################################################################################################
#
# The RWKV-7 "Goose" Language Model - https://github.com/BlinkDL/RWKV-LM
#
########################################################################################################
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time, random, json, math, gc
from tqdm import tqdm
from torch.nn import functional as F
SEED = 42
# Seed every RNG the script may draw from (Python, NumPy, torch CPU, torch CUDA) so runs are
# reproducible. The call order matches the original: random -> numpy -> torch -> torch.cuda.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn
########################################################################################################
# Model configuration. Weights can be downloaded from https://huggingface.co/BlinkDL/rwkv7-g1
args = types.SimpleNamespace(
    vocab_size=65536,
    head_size=64,
    MODEL_NAME="/root/models/rwkv7-g1a-0.1b-20250728-ctx4096",
)
print(f'\nUsing CUDA fp16. Loading {args.MODEL_NAME} ...\n')

# Project-local imports stay here (after the banner) so the message appears before loading starts.
from reference.rwkv7 import RWKV_x070
model = RWKV_x070(args)

from reference.utils import TRIE_TOKENIZER, sampler_simple_batch
tokenizer = TRIE_TOKENIZER("reference/rwkv_vocab_v20230424.txt")
########################################################################################################
# Batched next-token demo: run several prompts through the model in one forward_batch call and print,
# for each prompt, its five most likely next tokens with their probabilities.
# prompts = ["The apple can be", "The cat can be"]
# prompts = ["The apple can't be", "The cat can't be"]
prompts = ["The apple can be", "The cat can't be", "他们发现,这", "Q: 1+1=?\nA: 1+1=2."]
BATCH_SIZE = len(prompts)
state = model.generate_zero_state(BATCH_SIZE)
init_outs = model.forward_batch([tokenizer.encode(prompt) for prompt in prompts], state)
for n, (prompt, init_out) in enumerate(zip(prompts, init_outs)):
    print(prompt)
    probs = F.softmax(init_out.float(), dim=-1)  # compute softmax in float (more accurate)
    top_probs, top_ids = torch.topk(probs, 5)  # top-5 possibilities; values are the probabilities
    # One .tolist() per tensor instead of an .item() per element: a single host transfer, and the
    # probabilities come straight from topk rather than being looked up again in `probs`.
    for token_prob, token_id in zip(top_probs.tolist(), top_ids.tolist()):
        print(repr(tokenizer.decode([token_id])), f'[probability {token_prob:.2%}]')
    if n != BATCH_SIZE - 1:
        print()  # blank line between prompts, none after the last
########################################################################################################
# Batched generation demo: extend every prompt in parallel for GENERATE_LENGTH steps, then print each
# prompt followed by its decoded continuation.
prompts = ["也许", "我看到", "他们发现", "我认为", "哈哈", "这是一个有趣的", "List of Emojis:"]
BATCH_SIZE = len(prompts)
# To repeat a single prompt across the whole batch instead, use one of:
# prompts = ["这是一个有趣的"] * BATCH_SIZE
# prompts = ["他们发现"] * BATCH_SIZE
# prompts = ["我看到"] * BATCH_SIZE
state = model.generate_zero_state(BATCH_SIZE)
out = model.forward_batch([tokenizer.encode(prompt) for prompt in prompts], state)
GENERATE_LENGTH = 10
step_history = []
for _ in range(GENERATE_LENGTH):
    # noise=0 presumably makes sampling deterministic (greedy) — verify in reference.utils.
    next_ids = sampler_simple_batch(out, noise=0).tolist()
    step_history.append(next_ids)
    # The same `state` object is passed on every step, so forward_batch is relied on to update it.
    out = model.forward_batch(next_ids, state)
# step_history is (steps, batch, 1): drop the trailing axis, then put batch first -> (batch, steps).
tokens = np.squeeze(np.array(step_history), axis=-1).T
print('\n')
for prompt, continuation in zip(prompts, tokens):
    print(prompt, end='')
    print(tokenizer.decode(continuation, utf8_errors="ignore"))
    print('#' * 80)
########################################################################################################
# LAMBADA evaluation: number of sequences scored per forward_batch call.
BATCH_SIZE = 2 ** 8
print('BATCH_SIZE', BATCH_SIZE, 'LAMBADA eval')
def eval_qa_batch(todo, print_interval, pad_eod = True, loss_mode = False, BATCH_SIZE = 1):
    """Score (context, target) pairs with the global ``model`` and print running metrics.

    Each item's context and target strings are tokenized with the global ``tokenizer``.
    Their concatenation is run through ``model.forward_batch`` (with ``full_output=True``)
    in groups of at most ``BATCH_SIZE`` sequences. Every group starts from a fresh zero
    state. For each item, the log-probabilities of its target tokens are summed, each
    conditioned on everything before it. The item counts as correct only if every target
    token is also the model's argmax prediction at that position.

    Args:
        todo: sized sequence of items where ``item[0]`` is the context string and
            ``item[1]`` is the target string.
        print_interval: print running metrics every ``print_interval`` items, and always
            after the last item.
        pad_eod: if true, prepend token id 0 to the context. This is presumably the
            end-of-document token; verify against the tokenizer. It also guarantees that
            the first target token has a preceding prediction position.
        loss_mode: if true, print the mean negative log-likelihood ("loss"); otherwise print
            its exponential ("ppl").
        BATCH_SIZE: maximum number of sequences per forward call.

    Raises:
        ValueError: if ``pad_eod`` is false and a context tokenizes to no tokens, because the
            first target token would then have no position that predicts it.

    Note: "loss"/"ppl" average the *summed* target log-likelihood per item, not per token.
    Accuracy is printed as a percentage.
    """
    xsum = 0  # running sum of per-item target log-likelihoods
    xcnt = 0  # items scored so far
    xacc = 0  # items whose every target token was the argmax prediction
    n_items = len(todo)
    fwd_tokens = []
    fwd_desc = []
    for i, d in enumerate(todo):
        # get src and dst
        src = tokenizer.encode(d[0])
        if pad_eod:
            src = [0] + src
        if not src:
            raise ValueError(f"item {i}: empty context cannot predict the first target token (pad_eod=False)")
        dst = tokenizer.encode(d[1])
        # store jobs
        fwd_tokens.append(src + dst)
        fwd_desc.append((src, dst))
        if len(fwd_tokens) < BATCH_SIZE and i != n_items - 1:
            continue
        # Size the zero state to the sequences actually in this batch. The final batch is
        # usually partial, and elsewhere the state is always sized to the number of sequences.
        out_batch = model.forward_batch(fwd_tokens, model.generate_zero_state(len(fwd_tokens)), full_output=True)
        for out, (src, dst) in zip(out_batch, fwd_desc):
            logprob = 0
            correct = True
            for n, tok in enumerate(dst):
                # Output at position len(src)-1+n is the prediction for the n-th target token.
                ooo = out[len(src) - 1 + n].float()
                # log_softmax stays finite where softmax underflows to 0 (math.log(0) would raise).
                logprob += F.log_softmax(ooo, dim=-1)[tok].item()
                if torch.argmax(ooo).item() != tok:
                    correct = False
            xcnt += 1
            xsum += logprob
            xacc += int(correct)
            if xcnt % print_interval == 0 or xcnt == n_items:
                if loss_mode:
                    print('loss', round(-xsum / xcnt, 2), 'acc', round(xacc/xcnt*100, 1))
                else:
                    print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 1))
        fwd_tokens = []
        fwd_desc = []
# Build (context, target) pairs from LAMBADA. Each record's "text" is split at its last space.
# The context is everything before that space. The target is the final word with the space
# restored in front of it, so context + target reproduces the original text.
with open("eval/lambada_test.jsonl", "r", encoding="utf-8") as f:
    todo = []
    for line in f:
        if not line.strip():
            continue  # tolerate stray blank lines instead of crashing in json.loads
        context, last_word = json.loads(line)['text'].rsplit(' ', 1)
        todo.append([context, " " + last_word])
eval_qa_batch(todo, print_interval=1000, BATCH_SIZE=BATCH_SIZE)