-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluation_utils.py
More file actions
37 lines (29 loc) · 1.27 KB
/
evaluation_utils.py
File metadata and controls
37 lines (29 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# evaluation_utils.py
import torch
def evaluate_model(model, tokenizer, prompts, labels, answer_logits, progress_callback=None):
    """Evaluate a model on a set of multiple-choice prompts.

    For each prompt, the last encoded token is dropped, the model's logits at
    the final position are restricted to the candidate answer tokens, and the
    softmax probability assigned to the correct answer is accumulated.

    Args:
        model: object exposing ``forward(ids, last_id_only=True)`` that returns
            a logits tensor of shape (1, 1, vocab) — presumably an exllama-style
            wrapper; TODO confirm against the caller.
        tokenizer: object exposing ``encode(prompt)`` returning a 2D id tensor.
        prompts: sequence of prompt strings.
        labels: parallel sequence of correct answers, either letter strings
            ('A', 'B', ...) or 0-based integer indices into ``answer_logits``.
        answer_logits: vocabulary ids of the candidate answer tokens, in
            letter order (index 0 -> 'A', 1 -> 'B', ...).
        progress_callback: optional no-argument callable invoked once per prompt.

    Returns:
        A tuple ``(mean_correct_prob, model_outputs)`` where ``mean_correct_prob``
        is a Python float (average probability assigned to the correct answer)
        and ``model_outputs`` is a list of per-example dicts with keys
        'prompt', 'correct_label', 'predicted_label', 'is_correct'.
    """
    # Guard the empty case: the original divided by len(prompts) unconditionally,
    # raising ZeroDivisionError on an empty evaluation set.
    if len(prompts) == 0:
        return 0.0, []
    score = 0.0
    model_outputs = []
    # Pure evaluation: disable gradient tracking to save memory/compute.
    with torch.no_grad():
        for prompt, label in zip(prompts, labels):
            # Drop the trailing token of the encoding — presumably an EOS/pad
            # appended by the tokenizer; TODO confirm.
            prompt_ids = tokenizer.encode(prompt)[:, :-1]
            logits = model.forward(prompt_ids, last_id_only=True).float()
            # Restrict to the candidate answer tokens, then normalize over them.
            logits_ans = logits[:, :, answer_logits]
            prob_ans = torch.softmax(logits_ans, dim=-1)
            predicted_label = torch.argmax(prob_ans).item()
            # Normalize string labels ('A', 'B', ...) to 0-based indices.
            if isinstance(label, str):
                label_index = ord(label) - ord('A')
            else:
                label_index = label
            # BUG FIX: .item() keeps score a Python float; the original
            # accumulated a 0-dim tensor and returned a tensor mean.
            score += prob_ans[0, 0, label_index].item()
            # Decode the predicted index back to a letter.
            predicted_answer = chr(ord('A') + predicted_label)
            model_outputs.append({
                'prompt': prompt,
                'correct_label': chr(ord('A') + label_index) if isinstance(label, int) else label,
                'predicted_label': predicted_answer,
                'is_correct': predicted_label == label_index
            })
            # Report progress after each example if a callback was supplied.
            if progress_callback:
                progress_callback()
    return score / len(prompts), model_outputs