-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAnswers.py
More file actions
69 lines (55 loc) · 2.27 KB
/
Answers.py
File metadata and controls
69 lines (55 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Answer key for the Flappy Bird project. Alternative implementations are kept alongside the active ones, disabled as comments or string literals.
import random
import torch
import torch.nn as nn
# Genetic-algorithm hyperparameters.
# Probability that any single weight/bias value is perturbed during mutation.
MUTATION_RATE: float = 0.2
# Maximum absolute size of a single weight/bias perturbation.
MUTATION_STRENGTH: float = 0.3
# Number of birds selected to seed the next round.
N: int = 10
class Model(nn.Module):
    """A single fully connected layer whose outputs pass through a sigmoid.

    Maps ``in_channels`` input features to ``out_channels`` outputs, each
    squashed by ``torch.sigmoid``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute name `fc1` is part of the state_dict keys; keep it stable.
        self.fc1 = nn.Linear(in_channels, out_channels)

    def forward(self, input):
        """Return ``sigmoid(fc1(input))``."""
        logits = self.fc1(input)
        return torch.sigmoid(logits)
def mutate_bird(bird, rate=None, strength=None):
    """Return a mutated clone of *bird*.

    Each weight and bias value of the clone's ``model`` is independently
    chosen for mutation with probability *rate*. A chosen value is shifted
    by a uniform random offset in ``[-strength, strength)``.

    Args:
        bird: Object exposing ``clone()`` and a ``model`` attribute whose
            ``parameters()`` yields tensors. Assumes ``clone()`` deep-copies
            the model (TODO confirm); otherwise the parent's weights would be
            mutated as well.
        rate: Per-value mutation probability. Defaults to ``MUTATION_RATE``,
            read at call time.
        strength: Maximum absolute size of a mutation. Defaults to
            ``MUTATION_STRENGTH``, read at call time.

    Returns:
        The mutated clone.
    """
    # Resolve defaults at call time (not definition time) so runtime changes
    # to the module constants are still honoured, as in the original version.
    if rate is None:
        rate = MUTATION_RATE
    if strength is None:
        strength = MUTATION_STRENGTH

    new_bird = bird.clone()
    with torch.no_grad():
        for param in new_bird.model.parameters():
            # Uniform offsets in [-strength, strength).
            mutation = 2 * (torch.rand_like(param) - 0.5) * strength
            # Keep each offset with probability `rate`; zero out the rest.
            mutation = mutation * (torch.rand_like(param) < rate).float()
            # PyTorch discourages `.data`; an in-place add under no_grad is
            # the supported way to edit parameters without recording autograd.
            param.add_(mutation)
    return new_bird
# Alternative solution (disabled): truncation selection that keeps the birds
# with the largest deathTime. It is wrapped in a string literal, so the def
# below is never executed.
# NOTE(review): it keeps a hard-coded 3 birds rather than N, sorts the
# caller's `birds` list in place, and its print would raise IndexError if
# fewer than 3 birds were passed.
"""def select_next_round(birds):
# Sorts birds in descending order
birds.sort(key=lambda bird: bird.deathTime, reverse=True)
# Gets first 3 birds
best_birds = birds[:3]
# Prints best 3 birds' death times
print([best_birds[i].deathTime for i in range(3)])
return best_birds"""
# Active solution: fitness-proportional (roulette-wheel) selection.
def select_next_round(birds, k=None):
    """Pick birds for the next round by fitness-proportional selection.

    Each pick is drawn independently, with probability proportional to the
    bird's ``deathTime`` (roulette-wheel selection). Picks are made with
    replacement, so the same bird may appear more than once.

    Args:
        birds: Sequence of birds exposing a numeric ``deathTime`` attribute,
            which is used as fitness (presumably non-negative; verify).
        k: Number of birds to return. Defaults to ``N``, read at call time.

    Returns:
        A list of ``k`` birds. If the total fitness is zero (including an
        empty *birds*), the first ``k`` birds are returned instead, which
        may be fewer than ``k``.
    """
    if k is None:
        k = N

    fitnesses = [bird.deathTime for bird in birds]
    total_fitness = sum(fitnesses)
    # random.choices rejects weights that sum to zero, so fall back to a
    # plain slice when no bird has any fitness.
    if total_fitness == 0:
        return birds[:k]

    # Alternative solution: a hand-written roulette wheel equivalent to the
    # random.choices call below.
    # NOTE: floating-point rounding can leave random_value slightly above 0
    # after the inner loop; that pick is then skipped and fewer than k birds
    # are returned.
    # best_birds = []
    # probabilities = [fitness / total_fitness for fitness in fitnesses]
    # for _ in range(k):
    #     random_value = random.random()
    #     for i in range(len(birds)):
    #         random_value -= probabilities[i]
    #         if random_value <= 0:
    #             best_birds.append(birds[i])
    #             break
    # return best_birds

    return random.choices(birds, weights=fitnesses, k=k)