-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_eval_positioning.py
More file actions
108 lines (91 loc) · 4.54 KB
/
main_eval_positioning.py
File metadata and controls
108 lines (91 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from tqdm import tqdm
import models_vit
from dataset_classes.positioning import Positioning5G
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
from pathlib import Path
from torch.utils.data import random_split, DataLoader
def reverse_normalize(x, coord_min, coord_max):
    """Undo [-1, 1] normalization, mapping ``x`` back into [coord_min, coord_max].

    ``coord_min``/``coord_max`` are expected to broadcast against ``x``
    (here: shape (1, n_dims) bounds against a (batch, n_dims) tensor).
    """
    unit = (x + 1) / 2  # rescale from [-1, 1] to [0, 1]
    return unit * (coord_max - coord_min) + coord_min
# Seed every RNG source so the evaluation run is reproducible.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# Scenario tag passed to the dataset class; presumably selects the matching
# coordinate bounds inside Positioning5G — confirm against its implementation.
scene = 'outdoor'
dataset_train = Positioning5G(Path('fine-tuning_datasets/5G_NR_Positioning/outdoor/train'), scene=scene)
dataset_test = Positioning5G(Path('fine-tuning_datasets/5G_NR_Positioning/outdoor/test'), scene=scene)
# Nominal coordinate bounds, reshaped to (1, n_dims) so they broadcast over a
# (batch, n_dims) tensor in reverse_normalize.
coord_min, coord_max = dataset_train.coord_nominal_min.view((1, -1)), dataset_train.coord_nominal_max.view((1, -1))
# shuffle=False keeps sample order deterministic and aligned with the
# per-sample distance buffers filled in the loops below.
dataloader_train = DataLoader(dataset_train, batch_size=256, shuffle=False, num_workers=0)
dataloader_test = DataLoader(dataset_test, batch_size=256, shuffle=False, num_workers=0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
# ViT-Small regressor: 4 input channels, 3 regression outputs (the position
# coordinates fed to reverse_normalize below).
model = models_vit.__dict__['vit_small_patch16'](global_pool='token', num_classes=3, drop_path_rate=0.1, in_chans=4)
# NOTE(review): absolute, machine-specific checkpoint path — consider making
# this a CLI argument so the script runs outside this one workstation.
checkpoint = torch.load(Path('/home/ict317-3/Mohammad/mae/output_dir/checkpoint-40.pth'), map_location='cpu', weights_only=False)
# strict=True: fail loudly if the checkpoint and model architecture disagree.
msg = model.load_state_dict(checkpoint['model'], strict=True)
print(msg)
model = model.to(device)
# Per-sample Euclidean positioning error (metres) over the training split.
distances_train = torch.zeros((len(dataset_train),))
# BUGFIX: the model was never switched to eval mode, so stochastic depth
# (drop_path_rate=0.1) stayed active and made "evaluation" nondeterministic.
model.eval()
with torch.no_grad():
    # BUGFIX: the original indexed with i * num_samples where num_samples was
    # the *current* batch size; when the final batch is smaller than 256 it was
    # written to the wrong offset, clobbering earlier results. A running offset
    # places every batch correctly regardless of its size.
    offset = 0
    for batch in tqdm(dataloader_train, desc='Train Batch', total=len(dataloader_train)):
        image, target = batch
        image = image.to(device)
        # Map predictions and targets from [-1, 1] back to real coordinates
        # before measuring the error in physical units.
        pred_position = reverse_normalize(model(image).cpu(), coord_min, coord_max)
        position = reverse_normalize(target.cpu(), coord_min, coord_max)
        num_samples = target.shape[0]
        distances_train[offset: offset + num_samples] = torch.sqrt(torch.sum((pred_position - position) ** 2, dim=1))
        offset += num_samples
# Per-sample Euclidean positioning error (metres) over the test split.
distances_test = torch.zeros((len(dataset_test),))
# BUGFIX: ensure dropout / stochastic depth are disabled during inference
# (idempotent if already set elsewhere).
model.eval()
with torch.no_grad():
    # BUGFIX: same off-by-offset defect as the train loop — i * num_samples
    # misplaces the final partial batch; use a running write offset instead.
    offset = 0
    for batch in tqdm(dataloader_test, desc='Test Batch', total=len(dataloader_test)):
        image, target = batch
        image = image.to(device)
        # Convert both prediction and ground truth back to physical coordinates.
        pred_position = reverse_normalize(model(image).cpu(), coord_min, coord_max)
        position = reverse_normalize(target.cpu(), coord_min, coord_max)
        num_samples = target.shape[0]
        distances_test[offset: offset + num_samples] = torch.sqrt(torch.sum((pred_position - position) ** 2, dim=1))
        offset += num_samples
# Move the error tensors to NumPy for statistics and matplotlib plotting.
distances_train = distances_train.numpy()
distances_test = distances_test.numpy()
# NOTE(review): the block below is a disabled empirical-CDF plot with a
# 90th-percentile error marker; kept for reference — delete it once the
# histogram view below is confirmed as the final figure.
# distances_train.sort()
# distances_test.sort()
# cdf_train = np.linspace(0, 1, len(dataset_train))
# cdf_test = np.linspace(0, 1, len(dataset_test))
# idx_90_train = np.argmin(np.abs(cdf_train - 0.1))
# idx_90_test = np.argmin(np.abs(cdf_test - 0.1))
#
# plt.rcParams['font.family'] = 'serif'
# fig, axs = plt.subplots(1, 1)
# axs.plot(distances_train, cdf_train, label='train', linewidth=2, color='r')
# axs.plot(distances_test, cdf_test, label='test', linewidth=2, color='b')
# axs.axhline(y=0.1, linewidth=1, linestyle='--', label='90% likely', color='k')
# axs.axvline(x=distances_train[idx_90_train], linewidth=1, linestyle='--', color='r', alpha=0.8)
# axs.axvline(x=distances_test[idx_90_test], linewidth=1, linestyle='--', color='b', alpha=0.8)
# axs.set_xlabel('Positioning Error (m)')
# axs.set_ylabel('CDF')
# axs.legend(loc='lower right')
# plt.tight_layout()
# # plt.savefig('cdf_positioning.png', dpi=300)
# plt.show()
plt.rcParams['font.family'] = 'serif'
mean_train = np.mean(distances_train)
mean_test = np.mean(distances_test)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# model = 'Finetuning ViT-M'
# other = '(2 out of 12 blocks + linear layer)'
# fig.suptitle(f'{model} {other}\n{scene} scenario')
n_bins = 25
# Side-by-side error histograms: train split (red, left) and test split
# (blue, right), each annotated with its mean positioning error.
panels = (
    (axes[0], distances_train, mean_train, 'red'),
    (axes[1], distances_test, mean_test, 'blue'),
)
for ax, errors, mean_err, color in panels:
    ax.hist(errors, bins=n_bins, color=color, edgecolor='w', alpha=0.7, density=True)
    ax.axvline(mean_err, color='black', linestyle='--', linewidth=2, label=f'Mean: {mean_err:.2f} (m)')
    ax.set_xlabel('Positioning Error (m)', fontsize=16)
    ax.set_ylabel('Probability Density', fontsize=16)
    ax.legend(fontsize=16)
plt.tight_layout()
plt.savefig('hist_positioning.png', dpi=300)
plt.show()