-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTree.py
More file actions
132 lines (93 loc) · 5.1 KB
/
Tree.py
File metadata and controls
132 lines (93 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import math
import time
import numpy as np
from matplotlib import pyplot as plt
# A tree is constructed per game. All nodes must be part of a tree
class Tree:
    """Game tree rooted at one node, with MCTS self-play and actor-network helpers.

    Nodes are only used through their methods (``get_state``, ``get_children``,
    ``mcts_tree_policy``, ...); their class is defined outside this file section.

    Several methods take a ``visualize`` argument that is indexed as a pair of
    flags: ``visualize[0]`` prints boards to the console and ``visualize[1]``
    draws them with matplotlib.
    """

    def __init__(self, top_node):
        """Store ``top_node`` as the root of this tree."""
        self.top = top_node

    def get_top_node(self):
        """Return the node this tree was constructed with."""
        return self.top

    def print_all_nodes_as_boards(self):
        """Print the board of every direct child of the top node."""
        for child in self.get_top_node().get_children():
            child.get_state().get_board().print_board()

    @staticmethod
    def _remaining_pause(time_start, min_pause_length, now):
        """Return the seconds left until ``min_pause_length`` has passed since ``time_start``.

        Never negative: 0.0 is returned once the minimum duration has elapsed.
        """
        return max(0.0, time_start + min_pause_length - now)

    @staticmethod
    def _enforce_min_pause(time_start, min_pause_length):
        """Sleep just long enough that the current move lasts ``min_pause_length`` seconds."""
        remaining = Tree._remaining_pause(time_start, min_pause_length, time.time())
        if remaining > 0:
            time.sleep(remaining)

    @staticmethod
    def _index_to_coordinates(move_index, board_size):
        """Split a flat cell index into ``[index // board_size, index % board_size]``."""
        return [math.floor(move_index / board_size), move_index % board_size]

    def _draw_board(self, node):
        """Redraw ``node``'s board on the figure/axes owned by the top node's board."""
        plot_owner = self.get_top_node().get_state().get_board()
        node.get_state().get_board().create_board_plot(
            plot_owner.get_fig(), plot_owner.get_ax())
        plt.pause(0.5)

    @staticmethod
    def _root_arc_distribution(node):
        """Build the training target for ``node`` from its children's scores.

        Returns an array of shape ``(board_size, board_size, 2)`` whose entries
        are ``[cell index, child score[0] / node score[0]]``; cells without a
        generated child keep 0.0. (The original comments describe ``score[0]``
        as the visit count.)

        Side effect: if a child is an end state, its score is overwritten with
        ``[1, 1]`` and every other child's score with ``[0, 0]``.
        """
        board_size = node.get_state().get_board().get_board_size()
        # Not every child has been generated, so start with one
        # [cell index, 0.0] entry per board cell.
        arcs = [[cell, 0.0] for cell in range(board_size ** 2)]

        # A directly winning child is an end node and therefore collects few
        # visits; force it to dominate instead.
        # NOTE(review): nesting was reconstructed from an unindented copy of
        # the file - confirm the break belongs to the end-state branch.
        for child in node.get_children():
            if child.is_endstate():
                child.set_score([1, 1])
                for other in node.get_children():
                    if other != child:
                        other.set_score([0, 0])
                break

        # Write each child's share into its fixed cell position.
        for child in node.get_children():
            if node.get_score()[0] == 0:
                # Guards the division below against a zero parent score.
                print("Empty child selected. Set to 0.")
                arcs[child.get_node_num()][1] = 0.0
            else:
                arcs[child.get_node_num()][1] = (
                    child.get_score()[0] / node.get_score()[0])

        return np.reshape(arcs, (board_size, board_size, 2))

    def mcts_tree_default_until_end(self, rollouts_per_episode, RBUF, visualize,
                                    min_pause_length, node_expansion=1, anet=None):
        """Play one game with Monte Carlo tree search and add training data to ``RBUF``.

        Starting at the top node, each move runs ``rollouts_per_episode`` calls
        of ``mcts_tree_policy`` on the current node, appends one
        ``[merge_boards_to_anet(), arc distribution]`` case to ``RBUF`` and then
        continues from ``get_child_with_highest_visit_count()``. The loop ends
        when the current node reports ``is_endstate()``.

        Args:
            rollouts_per_episode: Number of tree-policy calls per move.
            RBUF: Buffer with an ``append`` method; receives one case per move.
            visualize: Pair of flags (console printing, matplotlib plotting).
            min_pause_length: Minimum wall-clock seconds each move should take.
            node_expansion: Forwarded unchanged to ``mcts_tree_policy``.
            anet: Forwarded unchanged to ``mcts_tree_policy``.
        """
        current_node = self.get_top_node()
        if visualize[1]:
            current_node.get_state().get_board().initialize_board_plot()

        while not current_node.is_endstate():
            time_start = time.time()
            current_node.set_as_top_node()
            current_node.set_leaf_status()
            for _ in range(rollouts_per_episode):
                current_node.mcts_tree_policy(node_expansion, anet)

            # Computed before merge_boards_to_anet() to keep the original
            # order of calls on the node.
            root_arcs = self._root_arc_distribution(current_node)
            RBUF.append([current_node.merge_boards_to_anet(), root_arcs])

            # Move to the selected child and detach it from its parent.
            current_node = current_node.get_child_with_highest_visit_count()
            current_node.set_parent(None)

            if visualize[1]:
                self._draw_board(current_node)
            if visualize[0]:
                current_node.get_state().get_board().print_board()
            self._enforce_min_pause(time_start, min_pause_length)

        if visualize[0] or visualize[1]:
            print(str(current_node.get_state().get_next_turn().get_color()) + " won!")
        if visualize[1]:
            plt.show()

    def anet_one_turn(self, current_node, anet, visualize, min_pause_length):
        """Make a single move based on the anet's prediction and return the new node.

        The move index from ``current_node.anet_policy(anet)`` is converted to
        board coordinates, a child is created for it, and the first entry of
        ``current_node.get_children()`` is returned.

        Args:
            current_node: Node to move from.
            anet: Passed to ``current_node.anet_policy``.
            visualize: Pair of flags (console printing, matplotlib plotting).
            min_pause_length: Minimum wall-clock seconds this call should take.
        """
        time_start = time.time()
        move_index = current_node.anet_policy(anet)
        board_size = current_node.get_state().get_board().get_board_size()
        current_node.create_child_node(
            self._index_to_coordinates(move_index, board_size))
        # NOTE(review): taking children[0] presumes the freshly created child
        # is the first one - confirm the node has no earlier children here.
        current_node = current_node.get_children()[0]

        if visualize[1]:
            self._draw_board(current_node)
        if visualize[0]:
            current_node.get_state().get_board().print_board()
            print()
        self._enforce_min_pause(time_start, min_pause_length)
        return current_node

    def anet_make_move(self, current_node, anet, visualize):
        """Return the anet's chosen move for ``current_node`` as board coordinates.

        The node itself is not modified by this method; when ``visualize[0]``
        is set, the node's current board is printed followed by a blank line.

        Returns:
            ``[index // board_size, index % board_size]`` for the index
            returned by ``current_node.anet_policy(anet)``.
        """
        move_index = current_node.anet_policy(anet)
        board_size = current_node.get_state().get_board().get_board_size()
        next_move = self._index_to_coordinates(move_index, board_size)
        if visualize[0]:
            current_node.get_state().get_board().print_board()
            print()
        return next_move