-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPolicyNetwork.py
More file actions
29 lines (23 loc) · 928 Bytes
/
PolicyNetwork.py
File metadata and controls
29 lines (23 loc) · 928 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf
# Define the policy network
class PolicyNetwork(tf.keras.Model):
def __init__(self, num_actions, hidden_units):
super(PolicyNetwork, self).__init__()
self.dense1 = tf.keras.layers.Dense(hidden_units, activation='relu')
self.dense2 = tf.keras.layers.Dense(num_actions, activation='softmax')
'''def call(self, inputs):
x = self.dense1(inputs)
x = self.dense2(x)
return x'''
# Instantiate the policy network
num_actions = 10
hidden_units = 32
policy_network = PolicyNetwork(num_actions, hidden_units)
# Define the loss function
def policy_loss(y_true, y_pred, returns):
neg_log_prob = tf.reduce_sum(-tf.math.log(y_pred) * y_true, axis=-1)
return tf.reduce_mean(neg_log_prob * returns)
# Define the optimizer
optimizer = tf.keras.optimizers.Adam(lr=0.001)
# Compile the model
policy_network.compile(optimizer=optimizer, loss=policy_loss)