Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions neat/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from neat.population import Population
from neat.reporting import BaseReporter
import neat.reporting


class Checkpointer(BaseReporter):
Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(self, generation_interval=100, time_interval_seconds=300,
def start_generation(self, generation):
self.current_generation = generation

def end_generation(self, config, population, species_set):
def end_generation(self, config, population, species_set, reporters):
checkpoint_due = False

if self.time_interval_seconds is not None:
Expand All @@ -52,23 +53,40 @@ def end_generation(self, config, population, species_set):
checkpoint_due = True

if checkpoint_due:
self.save_checkpoint(config, population, species_set, self.current_generation)
self.save_checkpoint(config, population, species_set, self.current_generation, reporters)
self.last_generation_checkpoint = self.current_generation
self.last_time_checkpoint = time.time()

def save_checkpoint(self, config, population, species_set, generation):
def save_checkpoint(self, config, population, species_set, generation,reporters):
""" Save the current simulation state. """
filename = '{0}{1}'.format(self.filename_prefix, generation)
print("Saving checkpoint to {0}".format(filename))

with gzip.open(filename, 'w', compresslevel=5) as f:
data = (generation, config, population, species_set, random.getstate())
data = (generation, config, population, species_set, random.getstate(),reporters)
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

@staticmethod
def restore_checkpoint(filename):
"""Resumes the simulation from a previous saved point."""
with gzip.open(filename) as f:
generation, config, population, species_set, rndstate = pickle.load(f)
random.setstate(rndstate)
return Population(config, (population, species_set, generation))
try:
with gzip.open(filename) as f:
generation, config, population, species_set, rndstate, reporters = pickle.load(f)
except ValueError as e:
with gzip.open(filename) as f:
print("There was an issue loading the checkpoint. Trying old mode")
generation, config, population, species_set, rndstate = pickle.load(f)


random.setstate(rndstate)
p = Population(config, (population, species_set, generation))
if 'reporters' in locals():
print("Reporters were extracted from the save file and will be added to the population")
for reporter in reporters:
p.add_reporter(reporter)
else:
print("No reporters found. Adding an StdOut and Statistics reporter")
p.add_reporter(neat.StdOutReporter(True))
p.add_reporter(neat.StatisticsReporter())
return p

2 changes: 1 addition & 1 deletion neat/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def start_generation(self, gen):

def end_generation(self, config, population, species_set):
for r in self.reporters:
r.end_generation(config, population, species_set)
r.end_generation(config, population, species_set, self.reporters)

def post_evaluate(self, config, population, species, best_genome):
for r in self.reporters:
Expand Down