reversi_robot.py
# -*- coding: utf-8 -*-
import numpy as np
class reversi_robot:
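    """Self-play Q-learning agent for Reversi.

    Wraps an environment, a TensorFlow session, and a Q-network. During an
    episode it records each player's states and legal moves; after the game,
    update_Q replays both trajectories with a win/loss reward.
    """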
    def __init__(self, env, sess, model, player, eps=0.5):
        self.env = env
        self.sess = sess
        self.player = player
        self.model = model
        self.eps = eps  # exploration rate for the epsilon-greedy policy
        # Per-instance histories keyed by player (0 or 1). Defining these as
        # class attributes would share one mutable dict across all instances.
        self.status = {0: [], 1: []}
        self.enables = {0: [], 1: []}
    def fixed_reward(self, r):
        # Hook for reward shaping; currently returns the raw reward unchanged.
        return r
    def step(self, action, player):
        s, r, done, info = self.env.step((action, player))
        self.status[player].append(s)  # player is 0 or 1, so index directly
        r = self.fixed_reward(r)
        return s, r, done, info
    def get_next_action(self, enables, Q):
        # Epsilon-greedy: explore with probability eps, otherwise pick the
        # legal move with the highest predicted Q value.
        if np.random.rand() < self.eps:
            a = np.random.choice(enables)
        else:
            Q_flatted = np.ravel(Q)
            a = enables[np.argmax(Q_flatted[enables])]
        return a
    def get_possible_actions(self, status, player):
        enable = self.env.get_possible_actions(status, player)
        self.enables[player].append(enable)  # record legal moves per player
        return enable
    def flat(self, s):
        # Flatten a board state into a (1, input_length) batch for the network.
        return np.reshape(s, (1, self.model.input_length))
    def _replay(self, player, reward, gamma):
        # Replay one player's recorded trajectory, nudging Q(s_i, a) toward
        # the Bellman target: reward + gamma * max over legal moves of Q(s_{i+1}).
        states = self.status[player]
        Q = self.sess.run(self.model.Q,
                          feed_dict={self.model.input_s: self.flat(states[0])})
        for i in range(len(states) - 1):
            next_Q = self.sess.run(self.model.Q,
                                   feed_dict={self.model.input_s: self.flat(states[i + 1])})
            max_next_Q = np.max(np.ravel(next_Q)[self.enables[player][i]])
            Q_target = Q.copy()  # copy so the target is not aliased to the live Q output
            a = self.get_next_action(self.enables[player][i], Q)
            Q_target[0][a] = reward + gamma * max_next_Q
            # Train on state i, the state this target was built for.
            self.sess.run(self.model.update,
                          feed_dict={self.model.Q_target: Q_target,
                                     self.model.input_s: self.flat(states[i])})
            # Re-evaluate Q at the next state for the following iteration.
            Q = self.sess.run(self.model.Q,
                              feed_dict={self.model.input_s: self.flat(states[i + 1])})

    def update_Q(self, winner, gamma):
        # After a game ends, reinforce the winner's moves (+10) and penalize
        # the loser's (-10). winner == 1 means White won.
        loser = 1 - winner
        self._replay(winner, reward=10, gamma=gamma)
        self._replay(loser, reward=-10, gamma=gamma)
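
# A minimal self-play sketch of how this class might be driven. The env and
# model here are assumptions, not defined in this file: any Reversi environment
# whose step() accepts an (action, player) tuple, plus any TF 1.x model
# exposing input_s, Q, Q_target, update, and input_length, would fit.
# make_reversi_env and build_q_network are hypothetical placeholders.
if __name__ == "__main__":
    import tensorflow as tf  # assumes the TF 1.x Session API used above

    env = make_reversi_env()    # hypothetical factory for the environment
    model = build_q_network()   # hypothetical Q-network with the attributes above
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        robot = reversi_robot(env, sess, model, player=1, eps=0.5)
        s = env.reset()
        player, done, info = 1, False, {}
        while not done:
            enables = robot.get_possible_actions(s, player)
            Q = sess.run(model.Q, feed_dict={model.input_s: robot.flat(s)})
            a = robot.get_next_action(enables, Q)
            s, r, done, info = robot.step(a, player)
            player = 1 - player  # alternate sides for self-play
        winner = info.get("winner", 1)  # hypothetical: env reports the winner
        robot.update_Q(winner, gamma=0.9)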