-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenv.py
More file actions
89 lines (67 loc) · 2.2 KB
/
env.py
File metadata and controls
89 lines (67 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from logging import info
from collections import deque
import gym
import time
from gym import spaces
from torch.nn.modules.module import register_module_forward_hook
from cozmo_actions import actions
import rospy
import keyboard
from openface2_ros.msg import Faces
import time
from stable_baselines.common.callbacks import BaseCallback
import numpy as np
class CozmoEnv(gym.Env):
    """Gym environment that plays a Cozmo robot animation and rewards smiles.

    Each step triggers one of four expressive robot actions, then polls the
    user's face (OpenFace action units delivered via a ROS subscriber into
    ``data_callback``) for 5 seconds at ~10 Hz. The reward is 1 when enough
    samples look like a smile, else 0. Episodes are single-step: ``step``
    always returns ``done=True``.
    """

    def __init__(self):
        super(CozmoEnv, self).__init__()
        # Sliding window of the last 7 rewards, used for a moving average.
        self.queue = deque(maxlen=7)
        # Define action and observation space; they must be gym.spaces objects.
        # Discrete action index -> Cozmo animation callable.
        self.action_maps = {0: actions.action_party,
                            1: actions.action_sadness,
                            2: actions.action_confused_cute,
                            3: actions.action_cute}
        self.action_space = spaces.Discrete(4)
        # Observation: two normalized action-unit intensities in [0, 1].
        self.observation_space = spaces.Box(0, 1, [2])
        # Latest OpenFace action units; populated by data_callback, None
        # until the first face message arrives.
        self.action_units = None
        self.movingAvgRewards = []
        # Per-action usage counts, indexed like action_maps.
        self.actionCounter = [0, 0, 0, 0]

    def data_callback(self, data):
        """ROS subscriber callback: cache the action units of the first face.

        NOTE(review): assumes data.faces is non-empty — raises IndexError
        when no face is detected; confirm the publisher's contract.
        """
        self.action_units = data.faces[0].action_units

    def step(self, action):
        """Play the chosen robot action, then watch for a smile for 5 s.

        Args:
            action: int in [0, 3], key into ``self.action_maps``.

        Returns:
            tuple ``(observation, reward, done, info)`` where observation is
            a 2-list of normalized AU intensities, reward is 0 or 1, done is
            always True (single-step episodes), and info is an empty dict.
        """
        reward = 0
        self.actionCounter[action] += 1
        step_action = self.action_maps[action]
        step_action(self)

        # BUGFIX: bind observation before the loop so the return statement
        # cannot hit an UnboundLocalError if the loop body never runs.
        observation = [0] * 2
        start_time = time.time()
        smile_count = 0
        # Poll the face for 5 seconds at ~10 Hz, counting smile-like samples.
        while time.time() - start_time < 5.0:
            if self.action_units:
                # Indices 8 and 13 of the published AU list, rescaled from
                # [0, 5] to [0, 1] — presumably the smile-related action
                # units; TODO confirm against openface2_ros message ordering.
                observation = [self.action_units[8].intensity / 5,
                               self.action_units[13].intensity / 5]
            else:
                observation = [0] * 2
            if observation[0] > 0.3 / 5 and observation[1] > 0.5 / 5:
                smile_count += 1
            time.sleep(0.1)

        # Enough smile samples (>= 5 of ~50) earns the reward.
        if smile_count >= 5:
            reward = 1
            print(smile_count)
            print("observation registered")

        # Maintain the moving average of the last 7 rewards.
        self.queue.appendleft(reward)
        if len(self.queue) == 7:
            # Idiom: sum() over the deque replaces the original manual
            # copy-and-drain loop; same arithmetic, same result.
            movingAvgReward = sum(self.queue) / 7
            self.movingAvgRewards.append(movingAvgReward)
            print(movingAvgReward)
            print(self.movingAvgRewards)
            # Drop the oldest reward so the window refills to 7 before the
            # next average is taken (preserves the original cadence).
            self.queue.pop()

        return observation, reward, True, {}

    def reset(self):
        """Reset the episode; returns the all-zero observation."""
        return [0] * 2

    def close(self):
        """No resources to release."""
        pass
# rospy.spin()