-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
135 lines (115 loc) · 5.21 KB
/
main.py
File metadata and controls
135 lines (115 loc) · 5.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
import os
import pickle
import sys
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
from nsga import run
from eval_functions import functions
from organism import Organism
from plot_utils import (fast_non_dominated_sort, final_pop_distribution,
final_pop_histogram, get_perfect_pop, T)
from random import seed
def plot_line(log: dict, generations, ylabel, title, save_loc, logscale=False, transparent=False):
    """Plot one line per entry of *log* against *generations* and save a PNG.

    Args:
        log: Mapping from series label to the y-values of that series; every
            series is drawn against the same ``generations`` x-values.
        generations: x-values shared by all series.
        ylabel: Figure-level y-axis label.
        title: Used only as the output file stem (``<save_loc>/<title>.png``);
            it is not drawn on the figure.
        save_loc: Directory the PNG is written into.
        logscale: When True, the y-axis uses a logarithmic scale.
        transparent: When True, the figure background alpha is set to 0.
    """
    fig, ax = plt.subplots(1, 1)
    for label, series in log.items():
        ax.plot(generations, series, label=label)
    if logscale:
        ax.set_yscale("log")

    fig.supxlabel("Generations")
    fig.supylabel(ylabel)
    fig.legend()
    if transparent:
        fig.patch.set_alpha(0.0)

    # Save and close this specific figure rather than relying on pyplot's
    # "current figure" global state.
    fig.savefig("{}/{}.png".format(save_loc, title))
    plt.close(fig)
def plotParetoFront(population, config, save_loc=None, first_front_only=False):
    """Draw the non-dominated fronts of *population* for each pair of objectives.

    One figure is produced per unordered pair (x_name, y_name) of the keys of
    ``config["eval_funcs"]``, taken in key order. Each front returned by
    ``fast_non_dominated_sort`` is plotted as a dashed, marker-annotated line
    of (x error, y error) points, labelled with its front number.

    Args:
        population: Organisms to sort; each is read via ``org.errors[name]``.
        config: Experiment config; only the keys of ``config["eval_funcs"]``
            are used.
        save_loc: Directory for ``pareto_<x>_<y>.png`` files. When None, each
            figure is displayed with ``plt.show()`` instead of being saved.
        first_front_only: When True, stop after the first (lowest-sorted)
            front of each pair.
    """
    fronts = fast_non_dominated_sort(population)
    objective_names = list(config["eval_funcs"].keys())

    for idx, x_name in enumerate(objective_names):
        # Only pairs with x before y, so each unordered pair is drawn once.
        for y_name in objective_names[idx + 1:]:
            for front_id in sorted(fronts.keys()):
                points = [(org.errors[x_name], org.errors[y_name]) for org in fronts[front_id]]
                # Two stable sorts: final order is x ascending, and points that
                # share an x keep the y-descending order from the first pass.
                points.sort(key=lambda p: p[1], reverse=True)
                points.sort(key=lambda p: p[0])
                # T presumably transposes the (x, y) pairs into x/y sequences — confirm in plot_utils.
                plt.plot(*T(points), marker="o", linestyle="--", label=front_id)
                if first_front_only:
                    break

            plt.title(x_name + " " + y_name)
            plt.xlabel(x_name + " Error")
            plt.ylabel(y_name + " Error")
            plt.legend()

            if save_loc is None:
                plt.show()
            else:
                plt.savefig("{}/pareto_{}_{}.png".format(save_loc, x_name, y_name))
                plt.close()
def diversity(perfect_pop: list[Organism], popsize: int, save_loc_i: str):
    """Write per-property diversity statistics of *perfect_pop* to ``diversity.csv``.

    For every property name in ``functions``, organisms are grouped by the
    value of ``getProperty(name)``. For names containing "distribution", the
    value is converted to a tuple so it is hashable. One CSV row is written
    per property with these columns:

    * entropy: Shannon entropy (base 2) of the value frequencies.
    * uniformity: entropy / log2(number of distinct values).
    * spread: entropy / log2(len(perfect_pop)).
    * unique_types: number of distinct values.
    * optimized_proportion: len(perfect_pop) / popsize.

    Degenerate inputs are not special-cased; numpy warns instead of raising:

    * With a single distinct value, uniformity is 0/0, which gives nan.
    * With exactly one organism, spread is 0/0, which also gives nan.
    * With an empty *perfect_pop*, uniformity and spread come out as -0.0
      (0 divided by log2(0) = -inf).

    Args:
        perfect_pop: Organisms whose properties are tallied.
        popsize: Full population size, the denominator of optimized_proportion.
        save_loc_i: Directory the CSV file is written into.
    """
    n_perfect = len(perfect_pop)
    with open("{}/diversity.csv".format(save_loc_i), 'w') as out:
        out.write("property,entropy,uniformity,spread,unique_types,optimized_proportion\n")
        for name in functions:
            as_tuple = "distribution" in name
            tally = Counter(
                tuple(org.getProperty(name)) if as_tuple else org.getProperty(name)
                for org in perfect_pop
            )
            fractions = [hits / n_perfect for hits in tally.values()]
            entropy = -sum([frac * np.log2(frac) for frac in fractions])

            n_types = len(tally)
            uniformity = entropy / np.log2(n_types)
            spread = entropy / np.log2(n_perfect)
            share_optimized = n_perfect / popsize

            out.write("{},{},{},{},{},{}\n".format(name, entropy, uniformity, spread,
                                                   n_types, share_optimized))
def run_rep(i, save_loc, config):
    """Run one replicate of the experiment, then optionally save and plot it.

    Args:
        i: Replicate index. It seeds Python's ``random`` module and names the
            per-replicate output directory ``<save_loc>/<i>``.
        save_loc: Experiment output directory.
        config: Experiment configuration. The whole dict is passed to
            ``run``. The keys read directly here are "eval_funcs",
            "save_data", "popsize", "plot_data", "tracking_frequency" and
            "num_generations".
    """
    # NOTE(review): only Python's `random` is seeded. If nsga.run also draws
    # from numpy's RNG, replicates are not reproducible — confirm.
    seed(i)
    rep_dir = "{}/{}".format(save_loc, i)
    if not os.path.exists(rep_dir):
        os.makedirs(rep_dir)

    objectives = config["eval_funcs"]
    final_pop, fitness_log, diversity_log = run(config)
    perfect_pop = get_perfect_pop(final_pop, objectives)

    if config["save_data"] == 1:
        dumps = (
            ("{}/final_pop.pkl", final_pop),
            ("{}/fitness_log.pkl", fitness_log),
            ("{}/diversity_log.pkl", diversity_log),
        )
        for path_template, payload in dumps:
            with open(path_template.format(rep_dir), "wb") as handle:
                pickle.dump(payload, handle)
        diversity(perfect_pop, config["popsize"], rep_dir)

    if config["plot_data"] != 1:
        return

    # One x-value per tracking point: 0, step, 2*step, ... up to num_generations.
    step = config["tracking_frequency"]
    n_points = (config["num_generations"] // step) + 1
    generations = [k * step for k in range(n_points)]

    if len(perfect_pop) > 0:
        final_pop_histogram(perfect_pop, objectives, rep_dir, plot_all=True)
        final_pop_distribution(perfect_pop, objectives, rep_dir, plot_all=True, with_error=True)
    plot_line(fitness_log, generations, "Error", "fitness", rep_dir, logscale=True)
    plot_line(diversity_log, generations, "Count of Unique Types", "unique_types", rep_dir)
    plotParetoFront(final_pop, config, rep_dir, first_front_only=False)
def main(config, rep=None):
    """Create the experiment directory, snapshot the config, and run replicates.

    Args:
        config: Experiment configuration. Reads "data_dir" and "name" to build
            the output directory, and "reps" when *rep* is None. The whole
            dict is written to ``<data_dir>/<name>/config.json``.
        rep: Optional single replicate index to run. This may be an int or a
            numeric string, e.g. straight from ``sys.argv``. When None,
            replicates ``0 .. config["reps"] - 1`` are run in order.

    Raises:
        ValueError: if *rep* is given but is not an integer-like value.
    """
    save_loc = "{}/{}".format(config["data_dir"], config["name"])
    # exist_ok avoids a check-then-create race when several single-rep jobs
    # for the same experiment start at the same time.
    os.makedirs(save_loc, exist_ok=True)
    config_path = "{}/config.json".format(save_loc)
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)

    if rep is not None:  # a single replicate was requested (e.g. from the command line)
        # Convert to int for two reasons:
        # * seed(rep) must match the seed the sequential loop below uses for
        #   the same index; random.seed("3") and random.seed(3) produce
        #   different streams.
        # * rep 0 must not be mistaken for "no rep given".
        run_rep(int(rep), save_loc, config)
    else:
        for i in range(config["reps"]):
            run_rep(i, save_loc, config)
if __name__ == "__main__":
    # Usage: python main.py <config.json> [rep]
    try:
        config_file = sys.argv[1]
        # `with` closes the file handle; the original json.load(open(...)) leaked it.
        with open(config_file) as config_fh:
            config = json.load(config_fh)
    except (IndexError, OSError, ValueError):
        # Catch only the failures that mean "no or unreadable config".
        # * IndexError: no config argument was given.
        # * OSError: the file is missing or unreadable.
        # * ValueError: the file is not valid JSON (this also covers
        #   JSONDecodeError and UnicodeDecodeError).
        print("Please give a valid config json to read parameters from.", file=sys.stderr)
        # Non-zero status so shells and job schedulers see the failure.
        sys.exit(1)

    if len(sys.argv) == 2:
        main(config)
    elif len(sys.argv) == 3:
        rep = sys.argv[2]
        main(config, rep)
    else:
        print("Please pass in valid arguments: config and (rep)(optional)", file=sys.stderr)
        sys.exit(1)