From 0ec8077c729257e5cb641fc2fe3b7bb970e96b9f Mon Sep 17 00:00:00 2001 From: tommymarkstein Date: Sun, 13 Aug 2023 06:35:43 +0200 Subject: [PATCH] feature: store and load statistics --- neat/checkpoint.py | 34 ++++++++++++++++++++++++++-------- neat/reporting.py | 2 +- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/neat/checkpoint.py b/neat/checkpoint.py index d2899e9a..74ad7193 100644 --- a/neat/checkpoint.py +++ b/neat/checkpoint.py @@ -7,6 +7,7 @@ from neat.population import Population from neat.reporting import BaseReporter +import neat.reporting class Checkpointer(BaseReporter): @@ -38,7 +39,7 @@ def __init__(self, generation_interval=100, time_interval_seconds=300, def start_generation(self, generation): self.current_generation = generation - def end_generation(self, config, population, species_set): + def end_generation(self, config, population, species_set, reporters): checkpoint_due = False if self.time_interval_seconds is not None: @@ -52,23 +53,40 @@ def end_generation(self, config, population, species_set): checkpoint_due = True if checkpoint_due: - self.save_checkpoint(config, population, species_set, self.current_generation) + self.save_checkpoint(config, population, species_set, self.current_generation, reporters) self.last_generation_checkpoint = self.current_generation self.last_time_checkpoint = time.time() - def save_checkpoint(self, config, population, species_set, generation): + def save_checkpoint(self, config, population, species_set, generation,reporters): """ Save the current simulation state. """ filename = '{0}{1}'.format(self.filename_prefix, generation) print("Saving checkpoint to {0}".format(filename)) with gzip.open(filename, 'w', compresslevel=5) as f: - data = (generation, config, population, species_set, random.getstate()) + data = (generation, config, population, species_set, random.getstate(),reporters) pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) @staticmethod def restore_checkpoint(filename): """Resumes the simulation from a previous saved point.""" - with gzip.open(filename) as f: - generation, config, population, species_set, rndstate = pickle.load(f) - random.setstate(rndstate) - return Population(config, (population, species_set, generation)) + try: + with gzip.open(filename) as f: + generation, config, population, species_set, rndstate, reporters = pickle.load(f) + except ValueError as e: + with gzip.open(filename) as f: + print("There was an issue loading the checkpoint. Trying old mode") + generation, config, population, species_set, rndstate = pickle.load(f) + + + random.setstate(rndstate) + p = Population(config, (population, species_set, generation)) + if 'reporters' in locals(): + print("Reporters were extracted from the save file and will be added to the population") + for reporter in reporters: + p.add_reporter(reporter) + else: + print("No reporters found. Adding an StdOut and Statistics reporter") + p.add_reporter(neat.StdOutReporter(True)) + p.add_reporter(neat.StatisticsReporter()) + return p + diff --git a/neat/reporting.py b/neat/reporting.py index d85eec82..1402cee6 100644 --- a/neat/reporting.py +++ b/neat/reporting.py @@ -29,7 +29,7 @@ def start_generation(self, gen): def end_generation(self, config, population, species_set): for r in self.reporters: - r.end_generation(config, population, species_set) + r.end_generation(config, population, species_set, self.reporters) def post_evaluate(self, config, population, species, best_genome): for r in self.reporters: