-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluatingmodelwithtest.py
More file actions
72 lines (61 loc) · 1.85 KB
/
evaluatingmodelwithtest.py
File metadata and controls
72 lines (61 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# ==========================
# Load Unseen Test Dataset
# ==========================
# The test split is built with the training set's vocab and label encoder
# (train_dataset.vocab / train_dataset.le) rather than its own, so both splits
# share the same token and label mappings.
test_dataset = CodeDataset(
    "test_snippets.csv",
    vocab=train_dataset.vocab,
    label_encoder=train_dataset.le,
)
# Evaluation loader: batches of 64 in file order (no shuffling), assembled by
# collate_fn.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_fn,
)
# ==========================
# Evaluate Model
# ==========================
# Put the model in inference mode and score it on every test batch, keeping
# both an overall tally and per-language tallies keyed by the decoded label.
model.eval()
correct = 0
total = 0
per_language_correct = {}  # decoded label -> number of correct predictions
per_language_total = {}    # decoded label -> number of test samples
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)
        outputs = model(x_batch)
        # Predicted class = highest-scoring column (assumes outputs is
        # (batch, num_classes) — TODO confirm against the model definition).
        preds = outputs.argmax(dim=1)
        correct += (preds == y_batch).sum().item()
        total += y_batch.size(0)
        # Track per-language accuracy. Copy labels/predictions to the host once
        # per batch and decode all labels in one inverse_transform call instead
        # of a device sync and an encoder call for every single sample.
        # NOTE(review): assumes `le` is an sklearn-style LabelEncoder whose
        # inverse_transform accepts a whole list of class indices.
        labels = y_batch.cpu().tolist()
        predicted = preds.cpu().tolist()
        langs = test_dataset.le.inverse_transform(labels)
        for lang, label, pred in zip(langs, labels, predicted):
            per_language_total[lang] = per_language_total.get(lang, 0) + 1
            per_language_correct[lang] = (
                per_language_correct.get(lang, 0) + int(label == pred)
            )
# An empty test set would otherwise raise ZeroDivisionError; report NaN instead.
overall_acc = correct / total if total else float("nan")
print(f"Accuracy on Unseen Test Snippets: {overall_acc:.4f}")
# ==========================
# Plot Per-Language Accuracy
# ==========================
# One bar per language, in the insertion order of per_language_correct; each
# bar's height is that language's correct/total ratio.
languages = [*per_language_correct]
accs = [
    hits / per_language_total[name]
    for name, hits in per_language_correct.items()
]
plt.figure(figsize=(10, 5))
plt.bar(languages, accs, color="skyblue")
plt.ylim(0, 1.0)
plt.xlabel("Language")
plt.ylabel("Accuracy")
plt.title("Model Accuracy on Completely Unseen Code Snippets")
plt.xticks(rotation=30)
# Tighten padding before saving so the rotated tick labels fit in the image.
plt.tight_layout()
plt.savefig("unseen_results.png", dpi=300)
plt.show()