From 2df2317dbd16e1c00f484cefd57f70032a8c7a94 Mon Sep 17 00:00:00 2001 From: Micael Carvalho Date: Sun, 3 May 2020 13:05:08 -0700 Subject: [PATCH 01/17] Use SQLite for Logger --- bootstrap/engines/engine.py | 97 +++++++------- bootstrap/lib/logger.py | 184 ++++++++++++++++----------- bootstrap/lib/utils.py | 8 +- bootstrap/optimizers/lr_scheduler.py | 4 +- bootstrap/run.py | 1 + 5 files changed, 167 insertions(+), 127 deletions(-) diff --git a/bootstrap/engines/engine.py b/bootstrap/engines/engine.py index de37236..f134529 100755 --- a/bootstrap/engines/engine.py +++ b/bootstrap/engines/engine.py @@ -61,9 +61,13 @@ def hook(self, name): Args: name: the name of the hook """ + return_dict = {} if name in self.hooks: for func in self.hooks[name]: - func() + return_func = func() + if return_func: + return_dict.update(return_func) + return return_dict def register_hook(self, name, func): """ Register a callback function to be triggered when the hook @@ -164,28 +168,29 @@ def train_epoch(self, model, dataset, optimizer, epoch, mode='train'): 'load': None, 'run_avg': 0 } - out_epoch = {} + epoch_dict = {} batch_loader = dataset.make_batch_loader() - self.hook(f'{mode}_on_start_epoch') + epoch_dict.update(self.hook(f'{mode}_on_start_epoch')) for i, batch in enumerate(batch_loader): + batch_dict = {} timer['load'] = time.time() - timer['elapsed'] - self.hook(f'{mode}_on_start_batch') + batch_dict.update(self.hook(f'{mode}_on_start_batch')) optimizer.zero_grad() out = model(batch) - self.hook(f'{mode}_on_forward') + batch_dict.update(self.hook(f'{mode}_on_forward')) if not torch.isnan(out['loss']): out['loss'].backward() else: Logger()('NaN detected') # torch.cuda.synchronize() - self.hook(f'{mode}_on_backward') + batch_dict.update(self.hook(f'{mode}_on_backward')) optimizer.step() # torch.cuda.synchronize() - self.hook(f'{mode}_on_update') + batch_dict.update(self.hook(f'{mode}_on_update')) timer['process'] = time.time() - timer['elapsed'] if i == 0: @@ -193,23 +198,25 @@ def train_epoch(self, model, dataset, optimizer, epoch, mode='train'): else: timer['run_avg'] = timer['run_avg'] * 0.8 + timer['process'] * 0.2 - Logger().log_value(f'{mode}_batch.epoch', epoch, should_print=False) - Logger().log_value(f'{mode}_batch.batch', i, should_print=False) - Logger().log_value(f'{mode}_batch.timer.process', timer['process'], should_print=False) - Logger().log_value(f'{mode}_batch.timer.load', timer['load'], should_print=False) + batch_dict['epoch'] = epoch + batch_dict['batch'] = i + batch_dict['timer_process'] = timer['process'] + batch_dict['timer_load'] = timer['load'] for key, value in out.items(): if torch.is_tensor(value): - if value.dim() <= 1: + if value.numel() == 1: value = value.item() # get number from a torch scalar else: continue if isinstance(value, (list, dict, tuple)): continue - if key not in out_epoch: - out_epoch[key] = [] - out_epoch[key].append(value) - Logger().log_value(f'{mode}_batch.' + key, value, should_print=False) + if key not in epoch_dict: + epoch_dict[key] = [] + epoch_dict[key].append(value) + batch_dict[key] = value + + Logger().log_dict(f'{mode}_batch', batch_dict) if i % Options()['engine']['print_freq'] == 0 or i == len(batch_loader) - 1: Logger()("{}: epoch {} | batch {}/{}".format(mode, epoch, i, len(batch_loader) - 1)) @@ -219,16 +226,17 @@ def train_epoch(self, model, dataset, optimizer, epoch, mode='train'): datetime.timedelta(seconds=math.floor(timer['run_avg'] * (len(batch_loader) - 1 - i))))) Logger()("{} process: {:.5f} | load: {:.5f}".format(' ' * len(mode), timer['process'], timer['load'])) Logger()("{} loss: {:.5f}".format(' ' * len(mode), out['loss'].data.item())) - self.hook(f'{mode}_on_print') + epoch_dict.update(self.hook(f'{mode}_on_print')) timer['elapsed'] = time.time() - self.hook(f'{mode}_on_end_batch') + epoch_dict.update(self.hook(f'{mode}_on_end_batch')) - Logger().log_value(f'{mode}_epoch.epoch', epoch, should_print=True) - for key, value in out_epoch.items(): - Logger().log_value(f'{mode}_epoch.' + key, np.asarray(value).mean(), should_print=True) + for key in epoch_dict.keys(): + epoch_dict[key] = np.asarray(epoch_dict[key]).mean() + epoch_dict['epoch'] = epoch - self.hook(f'{mode}_on_end_epoch') + epoch_dict.update(self.hook(f'{mode}_on_end_epoch')) + Logger().log_dict(f'{mode}_epoch', epoch_dict, should_print=True) Logger().flush() self.hook(f'{mode}_on_flush') @@ -259,18 +267,19 @@ def eval_epoch(self, model, dataset, epoch, mode='eval', logs_json=True): 'load': None, 'run_avg': 0 } - out_epoch = {} + epoch_dict = {} batch_loader = dataset.make_batch_loader() - self.hook('{}_on_start_epoch'.format(mode)) + epoch_dict.update(self.hook('{}_on_start_epoch'.format(mode))) for i, batch in enumerate(batch_loader): + batch_dict = {} timer['load'] = time.time() - timer['elapsed'] - self.hook('{}_on_start_batch'.format(mode)) + batch_dict.update(self.hook('{}_on_start_batch'.format(mode))) with torch.no_grad(): out = model(batch) # torch.cuda.synchronize() - self.hook('{}_on_forward'.format(mode)) + batch_dict.update(self.hook('{}_on_forward'.format(mode))) timer['process'] = time.time() - timer['elapsed'] if i == 0: @@ -278,23 +287,25 @@ def eval_epoch(self, model, dataset, epoch, mode='eval', logs_json=True): else: timer['run_avg'] = timer['run_avg'] * 0.8 + timer['process'] * 0.2 - Logger().log_value('{}_batch.batch'.format(mode), i, should_print=False) - Logger().log_value('{}_batch.epoch'.format(mode), epoch, should_print=False) - Logger().log_value('{}_batch.timer.process'.format(mode), timer['process'], should_print=False) - Logger().log_value('{}_batch.timer.load'.format(mode), timer['load'], should_print=False) + batch_dict['epoch'] = epoch + batch_dict['batch'] = i + batch_dict['timer_process'] = timer['process'] + batch_dict['timer_load'] = timer['load'] for key, value in out.items(): if torch.is_tensor(value): - if value.dim() <= 1: + if value.numel() == 1: value = value.item() # get number from a torch scalar else: continue if isinstance(value, (list, dict, tuple)): continue - if key not in out_epoch: - out_epoch[key] = [] - out_epoch[key].append(value) - Logger().log_value('{}_batch.{}'.format(mode, key), value, should_print=False) + if key not in epoch_dict: + epoch_dict[key] = [] + epoch_dict[key].append(value) + batch_dict[key] = value + + Logger().log_dict(f'{mode}_batch', batch_dict) if i % Options()['engine']['print_freq'] == 0: Logger()("{}: epoch {} | batch {}/{}".format(mode, epoch, i, len(batch_loader) - 1)) @@ -303,20 +314,18 @@ def eval_epoch(self, model, dataset, epoch, mode='eval', logs_json=True): datetime.timedelta(seconds=math.floor(time.time() - timer['begin'])), datetime.timedelta(seconds=math.floor(timer['run_avg'] * (len(batch_loader) - 1 - i))))) Logger()("{} process: {:.5f} | load: {:.5f}".format(' ' * len(mode), timer['process'], timer['load'])) - self.hook('{}_on_print'.format(mode)) + batch_dict.update(self.hook('{}_on_print'.format(mode))) timer['elapsed'] = time.time() - self.hook('{}_on_end_batch'.format(mode)) + batch_dict.update(self.hook('{}_on_end_batch'.format(mode))) - out = {} - for key, value in out_epoch.items(): - out[key] = sum(value) / len(value) + for key in epoch_dict.keys(): + epoch_dict[key] = np.asarray(epoch_dict[key]).mean() + epoch_dict['epoch'] = epoch - Logger().log_value('{}_epoch.epoch'.format(mode), epoch, should_print=True) - for key, value in out.items(): - Logger().log_value('{}_epoch.{}'.format(mode, key), value, should_print=True) + epoch_dict.update(self.hook('{}_on_end_epoch'.format(mode))) + Logger().log_dict(f'{mode}_epoch', epoch_dict, should_print=True) - self.hook('{}_on_end_epoch'.format(mode)) if logs_json: Logger().flush() diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index 7b4e42b..c3ed615 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -11,8 +11,8 @@ import os import sys -import json import inspect +import sqlite3 import datetime import collections @@ -27,9 +27,9 @@ class Logger(object): Logger(dir_logs='logs/mnist') Logger().log_value('train_epoch.epoch', epoch) Logger().log_value('train_epoch.mean_acctop1', mean_acctop1) - Logger().flush() # write the logs.json + Logger().flush() # write the logs.sqlite - Logger()("Launching training procedures") # written to logs.txt + Logger()("Launching training procedures") # written to logs.txt > [I 2018-07-23 18:58:31] ...trap/engines/engine.py.80: Launching training procedures """ @@ -69,15 +69,14 @@ def code(value): SYSTEM: Colors.code(Colors.WHITE + Colors.LIGHT) } - compactjson = True log_level = None # log level dir_logs = None - path_json = None + sqlite_cur = None + sqlite_file = None + sqlite_conn = None path_txt = None file_txt = None name = None - perf_memory = {} - values = {} max_lineno_width = 3 def __new__(cls, dir_logs=None, name='logs'): @@ -90,8 +89,8 @@ def __new__(cls, dir_logs=None, name='logs'): Logger._instance.dir_logs = dir_logs Logger._instance.path_txt = os.path.join(dir_logs, '{}.txt'.format(name)) Logger._instance.file_txt = open(os.path.join(dir_logs, '{}.txt'.format(name)), 'a+') - Logger._instance.path_json = os.path.join(dir_logs, '{}.json'.format(name)) - Logger._instance.reload_json() + Logger._instance.sqlite_file = os.path.join(dir_logs, '{}.sqlite'.format(name)) + Logger._instance.init_sqlite() else: Logger._instance.log_message('No logs files will be created (dir_logs attribute is empty)', log_level=Logger.WARNING) @@ -104,9 +103,6 @@ def __call__(self, *args, **kwargs): def set_level(self, log_level): self.log_level = log_level - def set_json_compact(self, is_compact): - self.compactjson = is_compact - def log_message(self, *message, log_level=INFO, break_line=True, print_header=True, stack_displacement=1, raise_error=True, adaptive_width=True): if log_level < self.log_level: @@ -161,49 +157,6 @@ def log_message(self, *message, log_level=INFO, break_line=True, print_header=Tr if log_level == self.ERROR and raise_error: raise Exception(message) - def log_value(self, name, value, stack_displacement=2, should_print=False, log_level=SUMMARY): - if log_level < self.log_level: - return -1 - - if name not in self.values: - self.values[name] = [] - self.values[name].append(value) - - if should_print: - if type(value) == float: - if int(value) == 0: - message = '{}: {:.6f}'.format(name, value) - else: - message = '{}: {:.2f}'.format(name, value) - else: - message = '{}: {}'.format(name, value) - self.log_message(message, log_level=log_level, stack_displacement=stack_displacement + 1) - - def log_dict(self, group, dictionary, description='', stack_displacement=2, should_print=False, log_level=SUMMARY): - if log_level < self.log_level: - return -1 - - if group not in self.perf_memory: - self.perf_memory[group] = {} - else: - for key in self.perf_memory[group].keys(): - if key not in dictionary.keys(): - self.log_message('Key "{}" not in the dictionary to be logged'.format(key), log_level=self.ERROR) - for key in dictionary.keys(): - if key not in self.perf_memory[group].keys(): - self.log_message('Key "{}" is unknown. New keys are not allowed'.format(key), log_level=self.ERROR) - - for key in dictionary.keys(): - if key in self.perf_memory[group]: - self.perf_memory[group][key].extend([dictionary[key]]) - else: - self.perf_memory[group][key] = [dictionary[key]] - - self.values[group] = self.perf_memory[group] - if should_print: - self.log_dict_message(group, dictionary, description, stack_displacement + 1, log_level) - self.flush() - def log_dict_message(self, group, dictionary, description='', stack_displacement=2, log_level=SUMMARY): if log_level < self.log_level: return -1 @@ -220,28 +173,105 @@ def print_subitem(prefix, subdictionary, stack_displacement=3): self.log_message('{}: {}'.format(group, description), log_level=log_level, stack_displacement=stack_displacement) print_subitem(' ', dictionary, stack_displacement=stack_displacement + 1) - def reload_json(self): - if os.path.isfile(self.path_json): - try: - with open(self.path_json, 'r') as json_file: - self.values = json.load(json_file) - except FileNotFoundError: - self.log_message('json log file can not be open: {}'.format(self.path_json), log_level=self.WARNING) + def _execute(self, statement, parameters=None, commit=True): + parameters = parameters or () + if not isinstance(parameters, tuple): + parameters = (parameters) + return_value = self.sqlite_cur.execute(statement, parameters) + if commit: + self.sqlite_conn.commit() + return return_value + + def _run_query(self, query, parameters=None): + return self._execute(query, parameters, commit=False) + + def _get_internal_table_name(self, table_name): + return f'_{table_name}' + + def _check_table_exists(self, table_name): + table_name = self._get_internal_table_name(table_name) + query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" + return self._run_query(query, table_name) + + def _create_table(self, table_name): + table_name = self._get_internal_table_name(table_name) + statement = f""" + CREATE TABLE {table_name} ( + "__id" INTEGER PRIMARY KEY AUTOINCREMENT, -- rowid + "__timestamp" DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """ + self._execute(statement, ()) + + def _list_columns(self, table_name): + table_name = self._get_internal_table_name(table_name) + query = "SELECT name FROM PRAGMA_TABLE_INFO(?)" + qry_cur = self._run_query(query, (table_name,)) + columns = [res[0] for res in qry_cur] + return columns + + @staticmethod + def _get_data_type(value): + # Only text and numeric are supported for now + if isinstance(value, str): + return 'TEXT' + else: + return 'NUMERIC' + + def _add_column(self, table_name, column_name, value_sample=None): + table_name = self._get_internal_table_name(table_name) + statement = "ALTER TABLE ? ADD ? {}".format(self._get_data_type(value_sample)) + return self._execute(statement, (table_name, column_name)) + + def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): + flatten_dict = flatten_dict or {} + for key, value in dictionary.items(): + local_prefix = f'{prefix}.{key}' + if isinstance(value, dict): + self._flatten_dict(value, flatten_dict, prefix=local_prefix) + else: + assert isinstance(value, (float, int)), f'Invalid value type {type(value)} for {local_prefix}' + flatten_dict[local_prefix] = value + return flatten_dict + + def _insert_row(self, table_name, flat_dictionary): + columns = self._list_columns(table_name) + table_name = self._get_internal_table_name(table_name) + column_string = ', '.join(columns) + value_placeholder = ', '.join(['?'] * len(columns)) + statement = f'INSERT INTO ?({column_string}) VALUES({value_placeholder})' + parameters = tuple(val for val in flat_dictionary.values()) + return self._execute(statement, parameters) + + def log_dict(self, group, dictionary, description='', stack_displacement=2, should_print=False, log_level=SUMMARY): + if log_level < self.log_level: + return -1 + + flat_dictionary = self._flatten_dict(dictionary) + if self._check_table_exists(group): + columns = self._list_columns(group) + for key in flat_dictionary.keys(): + if key not in columns: + self.log_message('Key "{}" is unknown. New keys are not allowed'.format(key), log_level=self.ERROR) + for column_name in columns: + if column_name not in flat_dictionary: + self.log_message('Key "{}" not in the dictionary to be logged'.format(key), log_level=self.ERROR) + else: + self._create_table(group) + for key, value in flat_dictionary.keys(): + self._add_column(group, key, value) + + self._insert_row(group, flat_dictionary) + if should_print: + self.log_dict_message(group, dictionary, description, stack_displacement + 1, log_level) + + def init_sqlite(self): + pre_existing = os.path.isfile(self.sqlite_file) + self.sqlite_conn = sqlite3.connect(self.sqlite_file) + self.sqlite_cur = self.sqlite_conn.cursor() + if not pre_existing: + self._create_table('bootstrap') def flush(self): if self.dir_logs: - self.path_tmp = self.path_json + '.tmp' - try: - with open(self.path_tmp, 'w') as json_file: - if self.compactjson: - json.dump(self.values, json_file, separators=(',', ':')) - else: - json.dump(self.values, json_file, indent=4) - if os.path.isfile(self.path_json): - os.remove(self.path_json) - os.rename(self.path_tmp, self.path_json) - except Exception as e: - print(e) - # TODO: Map what exception is this, and replace this "except Exception" for the real exception - # we cannot keep this as is, it will eventually catch things we do not want to catch, like a keyboard interrupt - raise e + self.sqlite_conn.commit() diff --git a/bootstrap/lib/utils.py b/bootstrap/lib/utils.py index c2e6394..eb54b99 100755 --- a/bootstrap/lib/utils.py +++ b/bootstrap/lib/utils.py @@ -53,15 +53,15 @@ def env_info(): info['pip_modules'] = subprocess.check_output(['pip', 'freeze'], stderr=devnull) try: git_branch_cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD'] - git_local_commit_cmd = ['git', 'rev-parse', 'HEAD'] - git_status_cmd = ['git', 'status'] + info['git_branch'] = subprocess.check_output(git_branch_cmd, stderr=devnull).strip().decode('UTF-8') git_origin_commit_cmd = ['git', 'rev-parse', 'origin/{}'.format(info['git_branch'])] git_diff_origin_commit_cmd = ['git', 'diff', 'origin/{}'.format(info['git_branch'])] + git_local_commit_cmd = ['git', 'rev-parse', 'HEAD'] + git_status_cmd = ['git', 'status'] + info['git_origin_commit'] = subprocess.check_output(git_origin_commit_cmd, stderr=devnull) git_log_since_origin_cmd = ['git', 'log', '--pretty=oneline', '{}..HEAD'.format(info['git_origin_commit'])] - info['git_branch'] = subprocess.check_output(git_branch_cmd, stderr=devnull).strip().decode('UTF-8') info['git_local_commit'] = subprocess.check_output(git_local_commit_cmd, stderr=devnull) info['git_status'] = subprocess.check_output(git_status_cmd, stderr=devnull) - info['git_origin_commit'] = subprocess.check_output(git_origin_commit_cmd, stderr=devnull) info['git_diff_origin_commit'] = subprocess.check_output(git_diff_origin_commit_cmd, stderr=devnull) info['git_log_since_origin'] = subprocess.check_output(git_log_since_origin_cmd, stderr=devnull) except subprocess.CalledProcessError: diff --git a/bootstrap/optimizers/lr_scheduler.py b/bootstrap/optimizers/lr_scheduler.py index 37f88a3..1a2d076 100644 --- a/bootstrap/optimizers/lr_scheduler.py +++ b/bootstrap/optimizers/lr_scheduler.py @@ -1,5 +1,4 @@ import torch.optim.lr_scheduler -from bootstrap.lib.logger import Logger class LearningRateScheduler(): @@ -29,7 +28,8 @@ def load_state_dict(self, state): def step_scheduler(self): self.scheduler.step() - Logger().log_value('train_epoch.lr', self.optimizer.param_groups[0]['lr']) + log_dict = {'lr': self.optimizer.param_groups[0]['lr']} + return log_dict class StepLR(LearningRateScheduler): diff --git a/bootstrap/run.py b/bootstrap/run.py index c181de7..bed8b93 100755 --- a/bootstrap/run.py +++ b/bootstrap/run.py @@ -73,6 +73,7 @@ def run(path_opts=None): Logger()('Saving environment info') Logger().log_dict('env_info', utils.env_info()) Logger().log_dict('options', Options(), should_print=True) # display options + Logger().flush() Logger()(os.uname()) # display server name if torch.cuda.is_available(): From bcab6faee7375443cf8027b84788141535861c5d Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Mon, 4 May 2020 17:39:55 -0700 Subject: [PATCH 02/17] Bugfixes --- bootstrap/lib/logger.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index c3ed615..cbbe8c5 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -9,12 +9,13 @@ # SOFTWARE. # ################################################################################# -import os -import sys +import collections +import datetime import inspect +import numbers +import os import sqlite3 -import datetime -import collections +import sys class Logger(object): @@ -176,7 +177,7 @@ def print_subitem(prefix, subdictionary, stack_displacement=3): def _execute(self, statement, parameters=None, commit=True): parameters = parameters or () if not isinstance(parameters, tuple): - parameters = (parameters) + parameters = (parameters,) return_value = self.sqlite_cur.execute(statement, parameters) if commit: self.sqlite_conn.commit() @@ -208,25 +209,28 @@ def _list_columns(self, table_name): query = "SELECT name FROM PRAGMA_TABLE_INFO(?)" qry_cur = self._run_query(query, (table_name,)) columns = [res[0] for res in qry_cur] + # remove __id and __timestamp columns + columns = [c for c in columns if not c.startswith('__')] return columns @staticmethod def _get_data_type(value): - # Only text and numeric are supported for now if isinstance(value, str): return 'TEXT' - else: + if isinstance(value, numbers.Number): return 'NUMERIC' + raise ValueError('Only text and numeric are supported for now') def _add_column(self, table_name, column_name, value_sample=None): table_name = self._get_internal_table_name(table_name) - statement = "ALTER TABLE ? ADD ? {}".format(self._get_data_type(value_sample)) - return self._execute(statement, (table_name, column_name)) + value_type = self._get_data_type(value_sample) + statement = f'ALTER TABLE {table_name} ADD COLUMN {column_name} {value_type}' + return self._execute(statement) def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): flatten_dict = flatten_dict or {} for key, value in dictionary.items(): - local_prefix = f'{prefix}.{key}' + local_prefix = f'{prefix}.{key}' if prefix else key if isinstance(value, dict): self._flatten_dict(value, flatten_dict, prefix=local_prefix) else: @@ -239,7 +243,7 @@ def _insert_row(self, table_name, flat_dictionary): table_name = self._get_internal_table_name(table_name) column_string = ', '.join(columns) value_placeholder = ', '.join(['?'] * len(columns)) - statement = f'INSERT INTO ?({column_string}) VALUES({value_placeholder})' + statement = f'INSERT INTO {table_name} ({column_string}) VALUES({value_placeholder})' parameters = tuple(val for val in flat_dictionary.values()) return self._execute(statement, parameters) @@ -248,20 +252,21 @@ def log_dict(self, group, dictionary, description='', stack_displacement=2, shou return -1 flat_dictionary = self._flatten_dict(dictionary) - if self._check_table_exists(group): + if self._check_table_exists(group).fetchone(): columns = self._list_columns(group) - for key in flat_dictionary.keys(): + for key in flat_dictionary: if key not in columns: self.log_message('Key "{}" is unknown. New keys are not allowed'.format(key), log_level=self.ERROR) for column_name in columns: if column_name not in flat_dictionary: - self.log_message('Key "{}" not in the dictionary to be logged'.format(key), log_level=self.ERROR) + self.log_message('Key "{}" not in the dictionary to be logged'.format(column_name), log_level=self.ERROR) else: self._create_table(group) - for key, value in flat_dictionary.keys(): + for key, value in flat_dictionary.items(): self._add_column(group, key, value) self._insert_row(group, flat_dictionary) + if should_print: self.log_dict_message(group, dictionary, description, stack_displacement + 1, log_level) From 8e4ef2fe2d74e12992f3dc60f644b29c5e869b32 Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 10 May 2020 12:24:27 -0700 Subject: [PATCH 03/17] Fixes (nested dicts and misc) --- bootstrap/lib/logger.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index cbbe8c5..831f54c 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -208,7 +208,7 @@ def _list_columns(self, table_name): table_name = self._get_internal_table_name(table_name) query = "SELECT name FROM PRAGMA_TABLE_INFO(?)" qry_cur = self._run_query(query, (table_name,)) - columns = [res[0] for res in qry_cur] + columns = (res[0] for res in qry_cur) # remove __id and __timestamp columns columns = [c for c in columns if not c.startswith('__')] return columns @@ -224,7 +224,7 @@ def _get_data_type(value): def _add_column(self, table_name, column_name, value_sample=None): table_name = self._get_internal_table_name(table_name) value_type = self._get_data_type(value_sample) - statement = f'ALTER TABLE {table_name} ADD COLUMN {column_name} {value_type}' + statement = f'ALTER TABLE {table_name} ADD COLUMN "{column_name}" {value_type}' return self._execute(statement) def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): @@ -233,13 +233,14 @@ def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): local_prefix = f'{prefix}.{key}' if prefix else key if isinstance(value, dict): self._flatten_dict(value, flatten_dict, prefix=local_prefix) + elif not isinstance(value, (float, int)): + raise TypeError(f'Invalid value type {type(value)} for {local_prefix}') else: - assert isinstance(value, (float, int)), f'Invalid value type {type(value)} for {local_prefix}' flatten_dict[local_prefix] = value return flatten_dict def _insert_row(self, table_name, flat_dictionary): - columns = self._list_columns(table_name) + columns = [f'"{c}"' for c in self._list_columns(table_name)] table_name = self._get_internal_table_name(table_name) column_string = ', '.join(columns) value_placeholder = ', '.join(['?'] * len(columns)) From a3fbfb469f67318e5a4227b366903fe77a7f92ac Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 10 May 2020 15:33:06 -0700 Subject: [PATCH 04/17] Fix flatten dict --- bootstrap/lib/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index 831f54c..3afcd23 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -228,7 +228,7 @@ def _add_column(self, table_name, column_name, value_sample=None): return self._execute(statement) def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): - flatten_dict = flatten_dict or {} + flatten_dict = flatten_dict if flatten_dict is not None else {} for key, value in dictionary.items(): local_prefix = f'{prefix}.{key}' if prefix else key if isinstance(value, dict): From 9a1c0fab697ec22a85eec694ef377a73e1463d67 Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 10 May 2020 15:33:26 -0700 Subject: [PATCH 05/17] Add tests for sqlite logger --- tests/test_logger.py | 78 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/test_logger.py diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000..faa51e4 --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,78 @@ +import os +import pytest +import random + +from bootstrap.lib.logger import Logger + + +def test_logger_init(tmpdir): + Logger._instance = None + Logger(dir_logs=tmpdir) + + assert os.path.isfile(Logger()._instance.sqlite_file) + + # check default _bootstrap table is empty + statement = "SELECT * FROM _bootstrap" + rows = Logger()._execute(statement).fetchall() + assert rows == [] + + +def test_nested_dict(tmpdir): + Logger._instance = None + Logger(dir_logs=tmpdir) + Logger().log_dict('batch', {'loss': {'value': 0.42}}) + Logger().log_dict('epoch', { + 'timer': { + 'value': 0.42 + }, + 'cpu': { + 'usage': { + 'float_value': 0.8, + } + } + }) + + +def test_new_key(tmpdir): + Logger._instance = None + Logger(dir_logs=tmpdir) + Logger().log_dict('batch', {'loss': .22, 'metric': 0.232}) + with pytest.raises(Exception): + Logger().log_dict('batch', { + 'loss': .22, + 'metric': 0.232, + 'new-metric': 0.42 + }) + + +def test_missing_key(tmpdir): + Logger._instance = None + Logger(dir_logs=tmpdir) + Logger().log_dict('batch', {'loss': .22, 'metric': 0.232}) + with pytest.raises(Exception): + Logger().log_dict('batch', {'loss': .22}) + + +def test_invalid_value(tmpdir): + Logger._instance = None + Logger(dir_logs=tmpdir) + with pytest.raises(TypeError): + Logger().log_dict('batch', {'loss': .22, 'metric': '0.232'}) + + +def test_read(tmpdir): + dicts = [] + for _ in range(3): + dicts.append({'loss': .22, 'metric': 0.232}) + + Logger._instance = None + Logger(dir_logs=tmpdir) + + for batch_dict in dicts: + Logger().log_dict('batch', batch_dict) + + statement = "SELECT loss,metric FROM _batch" + rows = Logger()._execute(statement).fetchall() + for (row, batch_dict) in zip(rows, dicts): + assert row[0] == batch_dict['loss'] + assert row[1] == batch_dict['metric'] From db8b5a6c66e798d96a52a99ffc583af3a5104126 Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 10 May 2020 15:43:42 -0700 Subject: [PATCH 06/17] Use python3 fstring --- bootstrap/lib/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index 3afcd23..f8d287b 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -257,10 +257,10 @@ def log_dict(self, group, dictionary, description='', stack_displacement=2, shou columns = self._list_columns(group) for key in flat_dictionary: if key not in columns: - self.log_message('Key "{}" is unknown. New keys are not allowed'.format(key), log_level=self.ERROR) + self.log_message(f'Key "{key}" is unknown. New keys are not allowed', log_level=self.ERROR) for column_name in columns: if column_name not in flat_dictionary: - self.log_message('Key "{}" not in the dictionary to be logged'.format(column_name), log_level=self.ERROR) + self.log_message(f'Key "{column_name}" not in the dictionary to be logged', log_level=self.ERROR) else: self._create_table(group) for key, value in flat_dictionary.items(): From 6f5a2dccf4956c5380552055448ec02dc48d053c Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 10 May 2020 15:45:20 -0700 Subject: [PATCH 07/17] Add select method to logger --- bootstrap/lib/logger.py | 13 +++++++++++++ tests/test_logger.py | 9 +++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index f8d287b..0a8d652 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -271,6 +271,19 @@ def log_dict(self, group, dictionary, description='', stack_displacement=2, shou if should_print: self.log_dict_message(group, dictionary, description, stack_displacement + 1, log_level) + def select(self, group, columns=None): + table_name = self._get_internal_table_name(group) + table_columns = self._list_columns(group) + if columns is None: + column_string = '*' + else: + for c in columns: + if c not in table_columns: + self.log_message(f'Unknown column "{c}"', log_level=self.ERROR) + column_string = ', '.join([f'"{c}"' for c in columns]) + statement = f'SELECT {column_string} FROM {table_name}' + return self._execute(statement) + def init_sqlite(self): pre_existing = os.path.isfile(self.sqlite_file) self.sqlite_conn = sqlite3.connect(self.sqlite_file) diff --git a/tests/test_logger.py b/tests/test_logger.py index faa51e4..a1cd42b 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -12,8 +12,7 @@ def test_logger_init(tmpdir): assert os.path.isfile(Logger()._instance.sqlite_file) # check default _bootstrap table is empty - statement = "SELECT * FROM _bootstrap" - rows = Logger()._execute(statement).fetchall() + rows = Logger().select('bootstrap').fetchall() assert rows == [] @@ -71,8 +70,10 @@ def test_read(tmpdir): for batch_dict in dicts: Logger().log_dict('batch', batch_dict) - statement = "SELECT loss,metric FROM _batch" - rows = Logger()._execute(statement).fetchall() + rows = Logger().select('batch', ['loss', 'metric']).fetchall() for (row, batch_dict) in zip(rows, dicts): assert row[0] == batch_dict['loss'] assert row[1] == batch_dict['metric'] + + with pytest.raises(Exception): + rows = Logger().select('batch', ['los', 'metric']).fetchall() From 9068aae3629e18d7046479131e424e00df69e398 Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 10 May 2020 15:57:31 -0700 Subject: [PATCH 08/17] make flake8 happy --- tests/test_logger.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index a1cd42b..36bbd94 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,6 +1,5 @@ import os import pytest -import random from bootstrap.lib.logger import Logger @@ -8,7 +7,7 @@ def test_logger_init(tmpdir): Logger._instance = None Logger(dir_logs=tmpdir) - + assert os.path.isfile(Logger()._instance.sqlite_file) # check default _bootstrap table is empty From 81be059cc369e9a711746f7acb4e67f2d66534bf Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 17 May 2020 11:18:26 -0700 Subject: [PATCH 09/17] Use (float,int) over numbers.number --- bootstrap/lib/logger.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index 0a8d652..bbb0c3f 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -12,7 +12,6 @@ import collections import datetime import inspect -import numbers import os import sqlite3 import sys @@ -217,7 +216,7 @@ def _list_columns(self, table_name): def _get_data_type(value): if isinstance(value, str): return 'TEXT' - if isinstance(value, numbers.Number): + if isinstance(value, (float, int)): return 'NUMERIC' raise ValueError('Only text and numeric are supported for now') From 3863de45580bc3c3e596688bc6e052b30e94c17f Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 17 May 2020 11:33:44 -0700 Subject: [PATCH 10/17] Use tuple type for internal _execute queries --- bootstrap/lib/logger.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index bbb0c3f..4f8a206 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -174,9 +174,8 @@ def print_subitem(prefix, subdictionary, stack_displacement=3): print_subitem(' ', dictionary, stack_displacement=stack_displacement + 1) def _execute(self, statement, parameters=None, commit=True): + assert parameters is None or isinstance(parameters, tuple) parameters = parameters or () - if not isinstance(parameters, tuple): - parameters = (parameters,) return_value = self.sqlite_cur.execute(statement, parameters) if commit: self.sqlite_conn.commit() @@ -191,7 +190,7 @@ def _get_internal_table_name(self, table_name): def _check_table_exists(self, table_name): table_name = self._get_internal_table_name(table_name) query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" - return self._run_query(query, table_name) + return self._run_query(query, (table_name,)) def _create_table(self, table_name): table_name = self._get_internal_table_name(table_name) @@ -201,7 +200,7 @@ def _create_table(self, table_name): "__timestamp" DATETIME DEFAULT CURRENT_TIMESTAMP ); """ - self._execute(statement, ()) + self._execute(statement) def _list_columns(self, table_name): table_name = self._get_internal_table_name(table_name) From ba62fdbbe8ee865c90b209b193d1e5308d5775d2 Mon Sep 17 00:00:00 2001 From: Micael Carvalho Date: Sun, 17 May 2020 12:12:38 -0700 Subject: [PATCH 11/17] Support lists and Nones --- bootstrap/lib/logger.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index 4f8a206..6037555 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -215,9 +215,9 @@ def _list_columns(self, table_name): def _get_data_type(value): if isinstance(value, str): return 'TEXT' - if isinstance(value, (float, int)): + if isinstance(value, (float, int, type(None))): return 'NUMERIC' - raise ValueError('Only text and numeric are supported for now') + raise ValueError(f'Only text and numeric are supported for now, found {type(value)}') def _add_column(self, table_name, column_name, value_sample=None): table_name = self._get_internal_table_name(table_name) @@ -231,7 +231,10 @@ def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): local_prefix = f'{prefix}.{key}' if prefix else key if isinstance(value, dict): self._flatten_dict(value, flatten_dict, prefix=local_prefix) - elif not isinstance(value, (float, int)): + elif isinstance(value, list): + dict_list = {idx: val for idx, val in enumerate(value)} + self._flatten_dict(dict_list, flatten_dict, prefix=local_prefix) + elif not isinstance(value, (float, int, str, type(None))): raise TypeError(f'Invalid value type {type(value)} for {local_prefix}') else: flatten_dict[local_prefix] = value From 1c2500575511413bffb77226af2d99d5d3afd932 Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 17 May 2020 12:15:16 -0700 Subject: [PATCH 12/17] Add test for none and str --- tests/test_logger.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 36bbd94..68908a7 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -51,11 +51,10 @@ def test_missing_key(tmpdir): Logger().log_dict('batch', {'loss': .22}) -def test_invalid_value(tmpdir): +def test_str_and_none_values(tmpdir): Logger._instance = None Logger(dir_logs=tmpdir) - with pytest.raises(TypeError): - Logger().log_dict('batch', {'loss': .22, 'metric': '0.232'}) + Logger().log_dict('batch', {'loss': None, 'metric': '0.232'}) def test_read(tmpdir): From dfeb3cdf95d655ae6de3b7ab9b8a285ee9c5ee9b Mon Sep 17 00:00:00 2001 From: Jean Begaint Date: Sun, 17 May 2020 12:25:26 -0700 Subject: [PATCH 13/17] Fix + tests --- bootstrap/lib/logger.py | 2 +- tests/test_logger.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index 6037555..792b896 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -231,7 +231,7 @@ def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): local_prefix = f'{prefix}.{key}' if prefix else key if isinstance(value, dict): self._flatten_dict(value, flatten_dict, prefix=local_prefix) - elif isinstance(value, list): + elif isinstance(value, (tuple, list)): dict_list = {idx: val for idx, val in enumerate(value)} self._flatten_dict(dict_list, flatten_dict, prefix=local_prefix) elif not isinstance(value, (float, int, str, type(None))): diff --git a/tests/test_logger.py b/tests/test_logger.py index 68908a7..c5eefe5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -54,7 +54,20 @@ def test_missing_key(tmpdir): def test_str_and_none_values(tmpdir): Logger._instance = None Logger(dir_logs=tmpdir) + Logger().log_dict('batch', {'loss': 1, 'metric': '0.232'}) Logger().log_dict('batch', {'loss': None, 'metric': '0.232'}) + Logger().log_dict('batch', {'loss': 'ewewew', 'metric': '0.232'}) + + +def test_mixed_nested_dict(tmpdir): + Logger._instance = None + Logger(dir_logs=tmpdir) + Logger().log_dict('batch', + {'loss': [1, 'toto', None, { + 'list': ['a', 'b'], + 'tuple': ('foo', 'bar'), + 'str_': None, + }]}) def test_read(tmpdir): From 8e786ee2a849660f5c805589b9a74df4b0e27070 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Thu, 21 Jan 2021 16:23:11 +0100 Subject: [PATCH 14/17] Fix critical bug with saving best model --- bootstrap/engines/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bootstrap/engines/engine.py b/bootstrap/engines/engine.py index f134529..4c14a97 100755 --- a/bootstrap/engines/engine.py +++ b/bootstrap/engines/engine.py @@ -330,7 +330,7 @@ def eval_epoch(self, model, dataset, epoch, mode='eval', logs_json=True): Logger().flush() self.hook('{}_on_flush'.format(mode)) - return out + return epoch_dict def is_best(self, out, saving_criteria): """ Verify if the last model is the best for a specific saving criteria From 53f50a883aea580a85479d270ffa52f36ef95688 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sun, 24 Jan 2021 22:40:14 +0100 Subject: [PATCH 15/17] Fix plotly for sqlite (Bug with multi-threads & processes) --- bootstrap/views/generate.py | 2 +- bootstrap/views/plotly.py | 30 ++++++++++-------------------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/bootstrap/views/generate.py b/bootstrap/views/generate.py index 442c7b8..56319e1 100644 --- a/bootstrap/views/generate.py +++ b/bootstrap/views/generate.py @@ -4,7 +4,7 @@ def generate(path_opts=None): - Options(path_yaml=path_opts) + Options(source=path_opts) view = factory() view.generate() diff --git a/bootstrap/views/plotly.py b/bootstrap/views/plotly.py index d2b77b3..1668dec 100644 --- a/bootstrap/views/plotly.py +++ b/bootstrap/views/plotly.py @@ -47,14 +47,13 @@ def generate(self): data_dict = {} for log_name in log_names: - path_json = os.path.join(self.exp_dir, - '{}.json'.format(log_name)) - if os.path.isfile(path_json): - with open(path_json, 'r') as handle: - data_json = json.load(handle) - data_dict[log_name] = data_json + path_sqlite = os.path.join(self.exp_dir, '{}.sqlite'.format(log_name)) + if os.path.isfile(path_sqlite): + Logger._instance = None + data_sqlite = Logger(dir_logs=self.exp_dir, name=log_name) + data_dict[log_name] = data_sqlite else: - Logger()("Json log file '{}' not found in '{}'".format(log_name, path_json), log_level=Logger.WARNING) + Logger()("sqlite log file '{}' not found in '{}'".format(log_name, path_sqlite), log_level=Logger.WARNING) nb_keys = len(items) nb_rows = math.ceil(nb_keys / 2) @@ -83,10 +82,6 @@ def generate(self): if view['log_name'] not in data_dict: continue - if view['view_name'] not in data_dict[view['log_name']]: - Logger()("View '{}' not in '{}.json'".format(view['view_name'], view['log_name']), log_level=Logger.WARNING) - continue - if view['split_name'] not in colors: Logger()( "Split '{}' not in colors '{}'".format(view['split_name'], list(colors.keys())), @@ -95,16 +90,11 @@ def generate(self): else: color = colors[view['split_name']] - y = data_dict[view['log_name']][view['view_name']] + group = view['view_name'].split('.')[0] + columns = [view['view_name'].split('.')[1]] - if 'epoch' in view['split_name']: - # example: data_dict['logs_last']['test_epoch.epoch'] - key = view['split_name'] + '.epoch' # TODO: ugly fix, to be remove - if key not in data_dict[view['log_name']]: - key = 'eval_epoch.epoch' - x = data_dict[view['log_name']][key] - else: - x = list(range(len(y))) + y = [x[0] for x in data_dict[view['log_name']].select(group, columns)] + x = list(range(len(y))) scatter = go.Scatter( x=x, From e4ef1f80f4990bdff7e6c6d04d0f9298d524a874 Mon Sep 17 00:00:00 2001 From: Micael Carvalho Date: Sun, 24 Jan 2021 15:58:41 -0800 Subject: [PATCH 16/17] Fix multi-threading logger --- bootstrap/lib/logger.py | 106 +++++++++++++++++++++---------------- bootstrap/views/factory.py | 4 +- bootstrap/views/plotly.py | 25 ++++----- 3 files changed, 75 insertions(+), 60 deletions(-) diff --git a/bootstrap/lib/logger.py b/bootstrap/lib/logger.py index 792b896..99c65f0 100755 --- a/bootstrap/lib/logger.py +++ b/bootstrap/lib/logger.py @@ -15,6 +15,7 @@ import os import sqlite3 import sys +from contextlib import closing class Logger(object): @@ -71,15 +72,22 @@ def code(value): log_level = None # log level dir_logs = None - sqlite_cur = None sqlite_file = None - sqlite_conn = None + connection = None path_txt = None file_txt = None name = None max_lineno_width = 3 - def __new__(cls, dir_logs=None, name='logs'): + def __new__(cls, dir_logs=None, name=None): + return Logger._get_instance(dir_logs, name) + + def __call__(self, *args, **kwargs): + return self.log_message(*args, **kwargs, stack_displacement=2) + + @staticmethod + def _get_instance(dir_logs=None, name=None): + name = name or 'logs' if Logger._instance is None: Logger._instance = object.__new__(Logger) Logger._instance.set_level(Logger._instance.INFO) @@ -90,16 +98,12 @@ def __new__(cls, dir_logs=None, name='logs'): Logger._instance.path_txt = os.path.join(dir_logs, '{}.txt'.format(name)) Logger._instance.file_txt = open(os.path.join(dir_logs, '{}.txt'.format(name)), 'a+') Logger._instance.sqlite_file = os.path.join(dir_logs, '{}.sqlite'.format(name)) - Logger._instance.init_sqlite() else: Logger._instance.log_message('No logs files will be created (dir_logs attribute is empty)', log_level=Logger.WARNING) return Logger._instance - def __call__(self, *args, **kwargs): - return self.log_message(*args, **kwargs, stack_displacement=2) - def set_level(self, log_level): self.log_level = log_level @@ -173,24 +177,24 @@ def print_subitem(prefix, subdictionary, stack_displacement=3): self.log_message('{}: {}'.format(group, description), log_level=log_level, stack_displacement=stack_displacement) print_subitem(' ', dictionary, stack_displacement=stack_displacement + 1) - def _execute(self, statement, parameters=None, commit=True): + def _execute(self, statement, parameters=None, commit=True, cursor=None): assert parameters is None or isinstance(parameters, tuple) parameters = parameters or () - return_value = self.sqlite_cur.execute(statement, parameters) + return_value = cursor.execute(statement, parameters) if commit: - self.sqlite_conn.commit() + self.get_conn().commit() return return_value - def _run_query(self, query, parameters=None): - return self._execute(query, parameters, commit=False) + def _run_query(self, query, parameters=None, cursor=None): + return self._execute(query, parameters, commit=False, cursor=cursor) def _get_internal_table_name(self, table_name): return f'_{table_name}' - def _check_table_exists(self, table_name): + def _check_table_exists(self, table_name, cursor=None): table_name = self._get_internal_table_name(table_name) query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" - return self._run_query(query, (table_name,)) + return self._run_query(query, (table_name,), cursor=cursor) def _create_table(self, table_name): table_name = self._get_internal_table_name(table_name) @@ -200,15 +204,17 @@ def _create_table(self, table_name): "__timestamp" DATETIME DEFAULT CURRENT_TIMESTAMP ); """ - self._execute(statement) + with closing(self.get_conn().cursor()) as cursor: + self._execute(statement, cursor=cursor) def _list_columns(self, table_name): table_name = self._get_internal_table_name(table_name) query = "SELECT name FROM PRAGMA_TABLE_INFO(?)" - qry_cur = self._run_query(query, (table_name,)) - columns = (res[0] for res in qry_cur) - # remove __id and __timestamp columns - columns = [c for c in columns if not c.startswith('__')] + with closing(self.get_conn().cursor()) as cursor: + qry_cur = self._run_query(query, (table_name,), cursor=cursor) + columns = (res[0] for res in qry_cur) + # remove __id and __timestamp columns + columns = [c for c in columns if not c.startswith('__')] return columns @staticmethod @@ -223,7 +229,8 @@ def _add_column(self, table_name, column_name, value_sample=None): table_name = self._get_internal_table_name(table_name) value_type = self._get_data_type(value_sample) statement = f'ALTER TABLE {table_name} ADD COLUMN "{column_name}" {value_type}' - return self._execute(statement) + with closing(self.get_conn().cursor()) as cursor: + return self._execute(statement, cursor=cursor) def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''): flatten_dict = flatten_dict if flatten_dict is not None else {} @@ -247,51 +254,58 @@ def _insert_row(self, table_name, flat_dictionary): value_placeholder = ', '.join(['?'] * len(columns)) statement = f'INSERT INTO {table_name} ({column_string}) VALUES({value_placeholder})' parameters = tuple(val for val in flat_dictionary.values()) - return self._execute(statement, parameters) + with closing(self.get_conn().cursor()) as cursor: + return self._execute(statement, parameters, cursor=cursor) def log_dict(self, group, dictionary, description='', stack_displacement=2, should_print=False, log_level=SUMMARY): if log_level < self.log_level: return -1 flat_dictionary = self._flatten_dict(dictionary) - if self._check_table_exists(group).fetchone(): - columns = self._list_columns(group) - for key in flat_dictionary: - if key not in columns: - self.log_message(f'Key "{key}" is unknown. New keys are not allowed', log_level=self.ERROR) - for column_name in columns: - if column_name not in flat_dictionary: - self.log_message(f'Key "{column_name}" not in the dictionary to be logged', log_level=self.ERROR) - else: - self._create_table(group) - for key, value in flat_dictionary.items(): - self._add_column(group, key, value) + with closing(self.get_conn().cursor()) as cursor: + if self._check_table_exists(group, cursor=cursor).fetchone(): + columns = self._list_columns(group) + for key in flat_dictionary: + if key not in columns: + self.log_message(f'Key "{key}" is unknown. New keys are not allowed', log_level=self.ERROR) + for column_name in columns: + if column_name not in flat_dictionary: + self.log_message(f'Key "{column_name}" not in the dictionary to be logged', log_level=self.ERROR) + else: + self._create_table(group) + for key, value in flat_dictionary.items(): + self._add_column(group, key, value) self._insert_row(group, flat_dictionary) if should_print: self.log_dict_message(group, dictionary, description, stack_displacement + 1, log_level) - def select(self, group, columns=None): - table_name = self._get_internal_table_name(group) - table_columns = self._list_columns(group) + @staticmethod + def select(group, columns=None): + logger = Logger._get_instance(dir_logs=None, name=None) + table_name = logger._get_internal_table_name(group) + table_columns = logger._list_columns(group) if columns is None: column_string = '*' else: for c in columns: if c not in table_columns: - self.log_message(f'Unknown column "{c}"', log_level=self.ERROR) + logger.log_message(f'Unknown column "{c}"', log_level=Logger.ERROR) column_string = ', '.join([f'"{c}"' for c in columns]) statement = f'SELECT {column_string} FROM {table_name}' - return self._execute(statement) - - def init_sqlite(self): - pre_existing = os.path.isfile(self.sqlite_file) - self.sqlite_conn = sqlite3.connect(self.sqlite_file) - self.sqlite_cur = self.sqlite_conn.cursor() - if not pre_existing: - self._create_table('bootstrap') + with closing(logger.get_conn().cursor()) as cursor: + return logger._execute(statement, cursor=cursor, commit=False).fetchall() + + def get_conn(self): + if self.connection is None: + pre_existing = os.path.isfile(self.sqlite_file) + connection = sqlite3.connect(self.sqlite_file, check_same_thread=False, isolation_level='IMMEDIATE') + self.connection = connection + if not pre_existing: + self._create_table('bootstrap') + return self.connection def flush(self): if self.dir_logs: - self.sqlite_conn.commit() + self.get_conn().commit() diff --git a/bootstrap/views/factory.py b/bootstrap/views/factory.py index 17a4e20..f75b566 100644 --- a/bootstrap/views/factory.py +++ b/bootstrap/views/factory.py @@ -5,8 +5,6 @@ def factory(engine=None): - Logger()('Creating views...') - # if views does not exist, pick view # to support backward compatibility if 'views' in Options(): @@ -24,6 +22,8 @@ def factory(engine=None): exp_dir = Options()['exp.dir'] + Logger(exp_dir)('Creating views...') + if 'names' in opt: view = make_multi_views(opt, exp_dir) return view diff --git a/bootstrap/views/plotly.py b/bootstrap/views/plotly.py index 1668dec..07bb013 100644 --- a/bootstrap/views/plotly.py +++ b/bootstrap/views/plotly.py @@ -1,5 +1,4 @@ import os -import json import math import plotly.graph_objects as go from plotly.subplots import make_subplots @@ -49,7 +48,6 @@ def generate(self): for log_name in log_names: path_sqlite = os.path.join(self.exp_dir, '{}.sqlite'.format(log_name)) if os.path.isfile(path_sqlite): - Logger._instance = None data_sqlite = Logger(dir_logs=self.exp_dir, name=log_name) data_dict[log_name] = data_sqlite else: @@ -93,16 +91,19 @@ def generate(self): group = view['view_name'].split('.')[0] columns = [view['view_name'].split('.')[1]] - y = [x[0] for x in data_dict[view['log_name']].select(group, columns)] - x = list(range(len(y))) - - scatter = go.Scatter( - x=x, - y=y, - name=view['view_interim'], - line={'color': color} - ) - figure.append_trace(scatter, figure_pos_y, figure_pos_x) + try: + y = [x[0] for x in data_dict[view['log_name']].select(group, columns)] + x = list(range(len(y))) + + scatter = go.Scatter( + x=x, + y=y, + name=view['view_interim'], + line={'color': color} + ) + figure.append_trace(scatter, figure_pos_y, figure_pos_x) + except Exception: + pass figure['layout'].update( autosize=True, From 7782b4c3e0a46215789595452890fd096cd2bdeb Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Fri, 29 Jan 2021 14:37:35 +0100 Subject: [PATCH 17/17] Fix compare for sqlite --- bootstrap/compare.py | 229 ++++++++++++++++++++++--------------------- 1 file changed, 118 insertions(+), 111 deletions(-) diff --git a/bootstrap/compare.py b/bootstrap/compare.py index 51f0ab1..7585644 100644 --- a/bootstrap/compare.py +++ b/bootstrap/compare.py @@ -1,121 +1,128 @@ -import json -import numpy as np import argparse from os import path as osp from tabulate import tabulate - - -def load_values(dir_logs, metrics, nb_epochs=-1, best=None): - json_files = {} - values = {} - - # load argsup of best - if best: - if best['json'] not in json_files: - with open(osp.join(dir_logs, f'{best["json"]}.json')) as f: - json_files[best['json']] = json.load(f) - - jfile = json_files[best['json']] - vals = jfile[best['name']] - end = len(vals) if nb_epochs == -1 else nb_epochs - argsup = np.__dict__[f'arg{best["order"]}'](vals[:end]) - - # load logs - for _key, metric in metrics.items(): - # open json_files - if metric['json'] not in json_files: - with open(osp.join(dir_logs, f'{metric["json"]}.json')) as f: - json_files[metric['json']] = json.load(f) - - jfile = json_files[metric['json']] - - if 'train' in metric['name']: - epoch_key = 'train_epoch.epoch' - else: - epoch_key = 'eval_epoch.epoch' - - if epoch_key in jfile: - epochs = jfile[epoch_key] - else: - epochs = jfile['epoch'] - - vals = jfile[metric['name']] - if not best: - end = len(vals) if nb_epochs == -1 else nb_epochs - argsup = np.__dict__[f'arg{metric["order"]}'](vals[:end]) - - try: - values[metric['name']] = epochs[argsup], vals[argsup] - except IndexError: - values[metric['name']] = epochs[argsup - 1], vals[argsup - 1] - return values - - -def main(args): - dir_logs = {} - for raw in args.dir_logs: - tmp = raw.split(':') - if len(tmp) == 2: - key, path = tmp - elif len(tmp) == 1: - path = tmp[0] - key = osp.basename(osp.normpath(path)) - else: - raise ValueError(raw) - dir_logs[key] = path - - metrics = {} - for json_obj, name, order in args.metrics: - metrics[f'{json_obj}_{name}'] = { - 'json': json_obj, - 'name': name, - 'order': order - } - - if args.best: - json_obj, name, order = args.best - best = { - 'json': json_obj, - 'name': name, - 'order': order - } - else: - best = None - - logs = {} - for name, dir_log in dir_logs.items(): - logs[name] = load_values(dir_log, metrics, - nb_epochs=args.nb_epochs, - best=best) - - for _key, metric in metrics.items(): - names = [] - values = [] - epochs = [] - for name, vals in logs.items(): - if metric['name'] in vals: - names.append(name) - epoch, value = vals[metric['name']] - epochs.append(epoch) - values.append(value) - if values: - values_names = sorted(zip(values, names, epochs), reverse=metric['order'] == 'max') - values_names = [[i + 1, name, value, epoch] for i, (value, name, epoch) in enumerate(values_names)] - print('\n\n## {}\n'.format(metric['name'])) - print(tabulate(values_names, headers=['Place', 'Method', 'Score', 'Epoch'])) +import sqlite3 +from contextlib import closing + +# def get_internal_table_name(table_name): +# return f'_{table_name}' + +# def run_query(conn, query, parameters=None, cursor=None): +# return execute(conn, query, parameters, commit=False, cursor=cursor) + +# def list_columns(conn, table_name): +# table_name = get_internal_table_name(table_name) +# query = "SELECT name FROM PRAGMA_TABLE_INFO(?)" +# with closing(conn.cursor()) as cursor: +# qry_cur = run_query(conn, query, (table_name,), cursor=cursor) +# columns = (res[0] for res in qry_cur) +# # remove __id and __timestamp columns +# columns = [c for c in columns if not c.startswith('__')] +# return columns + +# def select(conn, group, columns=None, where=None): +# table_name = get_internal_table_name(group) +# table_columns = list_columns(conn, group) +# if columns is None: +# column_string = '*' +# else: +# for c in columns: +# if c not in table_columns: +# Logger()(f'Unknown column "{c}"', log_level=Logger.ERROR) +# column_string = ', '.join([f'"{c}"' for c in columns]) +# statement = f'SELECT {column_string} FROM {table_name}' +# with closing(conn.cursor()) as cursor: +# return execute(conn, statement, cursor=cursor, commit=False).fetchall() + + +def execute(conn, statement, parameters=None, commit=True, cursor=None): + assert parameters is None or isinstance(parameters, tuple) + parameters = parameters or () + return_value = cursor.execute(statement, parameters) + if commit: + conn.commit() + return return_value + + +def load_table(list_dir, metric, nb_epochs=None, best=None): + table = [] + for dir_logs in list_dir: + # if metric['fname'] == best['fname']: + # path_sql = osp.join(dir_logs, f'{metric["fname"]}.sqlite') + # conn = sqlite3.connect(path_sql, check_same_thread=False, isolation_level='IMMEDIATE') + # statement = f'SELECT m.{metric["column"]}, m.epoch FROM _{metric["group"]} AS m, _{best["group"]} AS b' + # if nb_epochs: + # statement += f' WHERE m.epoch < {nb_epochs}' + # if best['order'] == 'max': + # order = 'DESC' + # elif best['order'] == 'min': + # order = 'ASC' + # statement += f' ORDER BY b.{best["column"]} {order} LIMIT 1' + # with closing(conn.cursor()) as cursor: + # score, epoch = execute(conn, statement, cursor=cursor).fetchone() + # else: + path_sql = osp.join(dir_logs, f'{best["fname"]}.sqlite') + conn = sqlite3.connect(path_sql, check_same_thread=False, isolation_level='IMMEDIATE') + statement = f'SELECT {best["column"]}, epoch FROM _{best["group"]}' + if nb_epochs: + statement += f' WHERE epoch < {nb_epochs}' + if best['order'] == 'max': + order = 'DESC' + elif best['order'] == 'min': + order = 'ASC' + statement += f' ORDER BY {best["column"]} {order} LIMIT 1' + with closing(conn.cursor()) as cursor: + best_score, best_epoch = execute(conn, statement, cursor=cursor).fetchone() + + path_sql = osp.join(dir_logs, f'{metric["fname"]}.sqlite') + conn = sqlite3.connect(path_sql, check_same_thread=False, isolation_level='IMMEDIATE') + statement = f'SELECT {metric["column"]}, epoch FROM _{metric["group"]}' + statement += f' WHERE epoch == {best_epoch}' + with closing(conn.cursor()) as cursor: + score, epoch = execute(conn, statement, cursor=cursor).fetchone() + + table.append([dir_logs, score, epoch]) + + if best['order'] == 'max': + reverse = True + elif best['order'] == 'min': + reverse = False + table.sort(key=lambda x: x[1], reverse=reverse) + + for i, x in enumerate(table): + x.insert(0, f'# {i+1}') + return table + + +def metric_str_to_dict(metric): + split_ = metric.split('.') + return { + 'fname': split_[0], + 'group': split_[1], + 'column': split_[2], + 'order': split_[3] + } + + +def display_metrics(list_dir, metrics, nb_epochs=None, best=None): + best = metric_str_to_dict(best) + for mstr in metrics: + metric = metric_str_to_dict(mstr) + table = load_table(list_dir, metric, nb_epochs=nb_epochs, best=best) + print(f'\n\n## {mstr}\n') + print(tabulate(table, headers=['Place', 'Method', 'Score', 'Epoch'])) if __name__ == '__main__': parser = argparse.ArgumentParser(description='') parser.add_argument('-n', '--nb_epochs', default=-1, type=int) parser.add_argument('-d', '--dir_logs', default='', type=str, nargs='*') - parser.add_argument('-m', '--metrics', type=str, action='append', nargs=3, - metavar=('json', 'name', 'order'), - default=[['logs', 'eval_epoch.accuracy_top1', 'max'], - ['logs', 'eval_epoch.accuracy_top5', 'max'], - ['logs', 'eval_epoch.loss', 'min']]) - parser.add_argument('-b', '--best', type=str, nargs=3, - metavar=('json', 'name', 'order'), - default=['logs', 'eval_epoch.accuracy_top1', 'max']) + parser.add_argument('-m', '--metrics', type=str, nargs='*', + default=['logs.eval_epoch.accuracy.max', + 'logs.train_epoch.loss.min', + 'logs.train_epoch.accuracy.max']) + parser.add_argument('-b', '--best', type=str, + default='logs.eval_epoch.accuracy.max') args = parser.parse_args() - main(args) + nb_epochs = None if args.nb_epochs == -1 else args.nb_epochs + display_metrics(args.dir_logs, args.metrics, nb_epochs, args.best)