forked from mfederici/dl-kit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path tensorboard.py
More file actions
18 lines (15 loc) · 915 Bytes
/
tensorboard.py
File metadata and controls
18 lines (15 loc) · 915 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import pytorch_lightning.loggers as loggers
import matplotlib.pyplot as plt
from code.loggers.log_entry import SCALAR_ENTRY, SCALARS_ENTRY, IMAGE_ENTRY, LogEntry
class TensorBoardLogger(loggers.TensorBoardLogger):
    """PyTorch Lightning TensorBoard logger that understands ``LogEntry`` objects.

    Dispatches on ``LogEntry.data_type`` so callers can log single scalars,
    groups of scalars and images through the single :meth:`log` entry point.
    """

    def log(self, name: str, log_entry: LogEntry, global_step=None):
        """Write ``log_entry`` to TensorBoard under the tag ``name``.

        Args:
            name: TensorBoard tag. For ``SCALARS_ENTRY`` it is used as a
                prefix: each sub-value is logged as ``"<name>/<sub_name>"``.
            log_entry: Entry whose ``data_type`` selects the logging routine
                and whose ``value`` is the payload:

                - ``SCALAR_ENTRY``: a single scalar.
                - ``SCALARS_ENTRY``: a mapping ``{sub_name: scalar}``.
                - ``IMAGE_ENTRY``: a matplotlib ``Figure``, or image data
                  accepted by ``SummaryWriter.add_image`` (presumably CHW,
                  the ``add_image`` default layout — TODO confirm with callers).
            global_step: Optional step index recorded with the value.

        Raises:
            ValueError: If ``log_entry.data_type`` is not one of the types above.
        """
        data_type = log_entry.data_type
        value = log_entry.value

        if data_type == SCALAR_ENTRY:
            # Lightning's ``log_metrics`` names its step argument ``step``;
            # passing ``global_step=`` would raise TypeError.
            self.log_metrics({name: value}, step=global_step)
        elif data_type == SCALARS_ENTRY:
            self.log_metrics(
                {f'{name}/{sub_name}': v for sub_name, v in value.items()},
                step=global_step,
            )
        elif data_type == IMAGE_ENTRY:
            if isinstance(value, plt.Figure):
                # ``add_image`` cannot consume a Figure; ``add_figure`` renders it
                # and, with close=True, closes it so pyplot does not accumulate
                # open figures.
                self.experiment.add_figure(name, value, global_step=global_step, close=True)
            else:
                # Already-rendered image data: there is no pyplot figure to close
                # (``plt.close`` would raise TypeError on an array).
                self.experiment.add_image(name, value, global_step=global_step)
        else:
            # ValueError subclasses Exception, so existing ``except Exception``
            # handlers keep working.
            raise ValueError('Data type %s is not recognized by TensorBoardLogger' % data_type)