-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathget_scores.py
More file actions
73 lines (55 loc) · 2.21 KB
/
get_scores.py
File metadata and controls
73 lines (55 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import math
from collections import defaultdict
from pathlib import Path
from playground import load_structured_file, save_structured_file
def _output_path(lfp: Path) -> Path:
    """Derive the output path by replacing "details" with "heads" in the file name.

    Raises:
        ValueError: if the file name does not contain "details". Without the
            marker the derived path would equal ``lfp`` and the input file
            would be overwritten by the results.
    """
    if "details" not in lfp.name:
        raise ValueError(
            f"cannot derive output path from {lfp}: file name must contain 'details'"
        )
    return lfp.with_name(lfp.name.replace("details", "heads"))


def _accumulate(records):
    """Sum and count finite per-head scores, split by each record's "is_safe" flag.

    Each record is read as ``record["is_safe"]`` plus ``record["data"]``, a
    mapping of layer id -> (head id -> score); ids are converted with ``int``.

    Returns:
        ``(safe_sums, safe_counts, hallu_sums, hallu_counts)``, each a dict
        keyed by ``(layer_id, head_id)``.
    """
    safe_sums, safe_counts = defaultdict(float), defaultdict(int)
    hallu_sums, hallu_counts = defaultdict(float), defaultdict(int)
    for record in records:
        is_safe = record["is_safe"]
        sums, counts = (safe_sums, safe_counts) if is_safe else (hallu_sums, hallu_counts)
        for layer_id, heads in record["data"].items():
            layer = int(layer_id)
            for head_id, score in heads.items():
                key = (layer, int(head_id))
                # NaN/inf would poison the running sum, so they are skipped.
                if not math.isfinite(score):
                    continue
                counts[key] += 1
                sums[key] += score
    return safe_sums, safe_counts, hallu_sums, hallu_counts


def _means(sums, counts):
    """Return ``{key: sums[key] / counts[key]}`` in ``counts`` insertion order."""
    return {key: sums[key] / count for key, count in counts.items()}


def main(lfp: Path, scale: float = 1.0):
    """Score attention heads from a "details" file and save them to a "heads" file.

    For every ``(layer_id, head_id)`` the score is
    ``mean(finite safe scores) - scale * mean(finite non-safe scores)``.
    A head with no finite scores in one group contributes 0.0 for that
    group's mean.

    Args:
        lfp: path to the input file, read with ``load_structured_file``. Its
            name must contain "details"; the output is written next to it
            with "details" replaced by "heads". ``str`` paths are accepted.
        scale: weight applied to the non-safe mean before subtracting.

    Returns:
        The saved list of ``((layer_id, head_id), score)`` pairs, sorted by
        score in ascending order.

    Raises:
        ValueError: if the input file name does not contain "details". This is
            checked before anything is read or written, so the input file can
            never be overwritten.
    """
    lfp = Path(lfp)
    sfp = _output_path(lfp)

    safe_sums, safe_counts, hallu_sums, hallu_counts = _accumulate(load_structured_file(lfp))
    hallu_means = _means(hallu_sums, hallu_counts)

    scores = defaultdict(float, _means(safe_sums, safe_counts))
    for key, hallu_mean in hallu_means.items():
        # Heads missing from the safe group start from the defaultdict's 0.0.
        scores[key] -= hallu_mean * scale

    ranked = sorted(scores.items(), key=lambda item: item[1])
    save_structured_file(ranked, sfp, "w")
    return ranked
if __name__ == "__main__":
    # Command-line entry: one positional input path plus an optional --scale weight.
    cli = argparse.ArgumentParser()
    cli.add_argument("load_file_path", type=Path)
    cli.add_argument("--scale", type=float, default=1.0)
    parsed = cli.parse_args()
    main(parsed.load_file_path, parsed.scale)