Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 27 additions & 28 deletions logger.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,43 @@
import torch


class Logger(object):
class Logger:
def __init__(self, runs, info=None):
self.info = info
self.results = [[] for _ in range(runs)]

def add_result(self, run, result):
assert len(result) == 3
assert run >= 0 and run < len(self.results)
if len(result) != 3 or not (0 <= run < len(self.results)):
raise ValueError("Invalid result format or run index.")
self.results[run].append(result)

def print_statistics(self, run=None):
def calculate_statistics(data):
train_max = data[:, 0].max().item()
valid_max = data[:, 1].max().item()
best_index = data[:, 1].argmax()
final_train = data[best_index, 0].item()
final_test = data[best_index, 2].item()
return train_max, valid_max, final_train, final_test

if run is not None:
result = 100 * torch.tensor(self.results[run])
argmax = result[:, 1].argmax().item()
data = 100 * torch.tensor(self.results[run])
train_max, valid_max, final_train, final_test = calculate_statistics(data)
print(f'Run {run + 1:02d}:')
print(f'Highest Train: {result[:, 0].max():.2f}')
print(f'Highest Valid: {result[:, 1].max():.2f}')
print(f' Final Train: {result[argmax, 0]:.2f}')
print(f' Final Test: {result[argmax, 2]:.2f}')
print(f'Highest Train: {train_max:.2f}')
print(f'Highest Valid: {valid_max:.2f}')
print(f' Final Train: {final_train:.2f}')
print(f' Final Test: {final_test:.2f}')
else:
result = 100 * torch.tensor(self.results)

best_results = []
for r in result:
train1 = r[:, 0].max().item()
valid = r[:, 1].max().item()
train2 = r[r[:, 1].argmax(), 0].item()
test = r[r[:, 1].argmax(), 2].item()
best_results.append((train1, valid, train2, test))
all_results = 100 * torch.tensor(self.results)
stats = [calculate_statistics(run_data) for run_data in all_results]
stats_tensor = torch.tensor(stats)

best_result = torch.tensor(best_results)
def print_mean_std(idx, label):
values = stats_tensor[:, idx]
print(f'{label}: {values.mean():.2f} ± {values.std():.2f}')

print(f'All runs:')
r = best_result[:, 0]
print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}')
r = best_result[:, 1]
print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
r = best_result[:, 2]
print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}')
r = best_result[:, 3]
print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}')
print_mean_std(0, 'Highest Train')
print_mean_std(1, 'Highest Valid')
print_mean_std(2, 'Final Train')
print_mean_std(3, 'Final Test')