-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathOperationOnly.py
More file actions
115 lines (91 loc) · 4.31 KB
/
OperationOnly.py
File metadata and controls
115 lines (91 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from arcle.envs import O2ARCv2Env
from arcle.loaders import ARCLoader, Loader
import numpy as np
from arcle.loaders import ARCLoader
from gymnasium.core import ObsType, ActType
import numpy as np
from typing import Dict,Optional, Tuple, SupportsFloat, SupportsInt, SupportsIndex, Any
from functools import wraps
class OpOnlyenv(O2ARCv2Env):
    """O2ARCv2Env variant that exposes only a fixed subset of operations.

    Only the parent's operations at ``OP_INDICES`` are kept, in that order.
    ``reward`` scores the grid only when the most recent operation is the
    last kept one (parent index 34). That makes it presumably the submit
    action; TODO confirm against arcle's operation table.
    """

    # Indices into the parent's operation list that this env keeps.
    # The last entry must stay last: ``reward`` only scores when the most
    # recent operation is ``self.operations[-1]``.
    OP_INDICES: Tuple[int, ...] = (24, 25, 26, 27, 34)

    def __init__(
        self,
        data_loader: Optional[Loader] = None,
        max_grid_size: Tuple[SupportsInt, SupportsInt] = (30, 30),
        colors: SupportsInt = 10,
        max_trial: SupportsInt = -1,
        render_mode: Optional[str] = None,
        render_size: Optional[Tuple[SupportsInt, SupportsInt]] = None,
    ) -> None:
        """Create the environment.

        Args:
            data_loader: problem loader. When omitted, a new ``ARCLoader()`` is
                created for this instance.
            max_grid_size, colors, max_trial, render_mode, render_size:
                forwarded unchanged to ``O2ARCv2Env.__init__``.
        """
        # Build the default loader per instance. An ``ARCLoader()`` default in
        # the signature would be constructed once at import time and shared by
        # every env created without an explicit loader.
        if data_loader is None:
            data_loader = ARCLoader()
        super().__init__(data_loader, max_grid_size, colors, max_trial, render_mode, render_size)
        self.reset_options = {
            'adaptation': True,  # Default is true (adaptation first!). To change this mode, call 'post_adaptation()'
            'prob_index': None,
        }
        # Number of operations exposed by this env.
        self.num_func = len(self.OP_INDICES)

    def create_operations(self):
        """Return the parent's operations restricted to ``OP_INDICES``.

        Raises:
            IndexError: if the parent provides fewer operations than the
                largest index in ``OP_INDICES``. Silently dropping the final
                op would break ``reward``, which keys on the last operation.
        """
        ops = super().create_operations()
        return [ops[i] for i in self.OP_INDICES]

    def reset(self, seed=None, options: Optional[Dict] = None, subprob=None):
        """Reset episode state and load a problem pair.

        Args:
            seed: forwarded to the parent ``reset``.
            options: optional dict. Keys read here are ``prob_index``,
                ``subprob_index``, ``adaptation`` (default True) and
                ``reset_on_submit`` (default False). The dict is also passed
                to ``init_state``.
            subprob: index of the test pair to load when ``adaptation`` is
                false. If omitted, ``options['subprob_index']`` is used,
                falling back to 0.

        Returns:
            ``(obs, info)``, where ``obs`` is ``self.current_state`` and
            ``info`` is ``self.init_info()``.
        """
        super().reset(seed=seed, options=options)

        # Reset internal states.
        self.truncated = False
        self.submit_count = 0
        self.last_action: ActType = None
        self.last_action_op: SupportsIndex = None
        self.last_reward: SupportsFloat = 0
        self.action_steps: SupportsInt = 0
        self.eval_subprob = subprob

        # Env options.
        self.prob_index = None
        self.subprob_index = None
        self.adaptation = True
        self.reset_on_submit = False
        self.options = options
        if options is not None:
            self.prob_index = options.get('prob_index')
            self.subprob_index = options.get('subprob_index')
            _ad = options.get('adaptation')
            self.adaptation = True if _ad is None else bool(_ad)
            _ros = options.get('reset_on_submit')
            self.reset_on_submit = False if _ros is None else _ros

        ex_in, ex_out, tt_in, tt_out, desc = self.loader.pick(data_index=self.prob_index)

        if self.adaptation:
            # Example (train) pair: a random one unless explicitly requested.
            # NOTE(review): this draws from the global NumPy RNG, so ``seed``
            # does not make the choice reproducible; consider self.np_random.
            if self.subprob_index is None:
                self.subprob_index = np.random.randint(0, len(ex_in))
            self.input_ = ex_in[self.subprob_index]
            self.answer = ex_out[self.subprob_index]
        else:
            # Test pair (eval problems). Previously a missing ``subprob``
            # led to indexing with ``None``; now fall back to the option, then 0.
            if self.eval_subprob is not None:
                self.subprob_index = self.eval_subprob
            elif self.subprob_index is None:
                self.subprob_index = 0
            self.input_ = tt_in[self.subprob_index]
            self.answer = tt_out[self.subprob_index]

        self.init_state(self.input_.copy(), options)
        self.description = desc

        if self.render_mode:
            self.render()

        obs = self.current_state
        self.info = self.init_info()
        return obs, self.info

    def reward(self, state) -> SupportsFloat:
        """Return 1 if the last op was the final (scoring) op and the grid matches the answer.

        The match requires two things: ``state['grid_dim']`` equals the
        answer's shape, and the top-left region of ``state['grid']`` of that
        shape equals the answer element-wise. Otherwise the result is 0.
        """
        if self.last_action_op != len(self.operations) - 1:
            return 0
        if tuple(state['grid_dim']) != self.answer.shape:
            return 0
        h, w = self.answer.shape
        return 1 if np.all(state['grid'][0:h, 0:w] == self.answer) else 0

    def step(self, action: ActType):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        ``action['operation']`` selects an entry of ``self.operations``.
        ``terminated`` is read from ``state['terminated'][0]``.
        """
        operation = int(action['operation'])
        self.transition(self.current_state, action)
        # Re-record after the op has run. ``transition`` already set these,
        # but an op that calls ``reset()`` (e.g. with reset_on_submit) would
        # have cleared them before ``reward`` reads ``last_action_op``.
        self.last_action_op = operation
        self.last_action = action

        state = self.current_state
        reward = self.reward(state)
        self.last_reward = reward
        self.action_steps += 1
        self.info['steps'] = self.action_steps
        self.info['submit_count'] = self.submit_count
        # Render only when a render mode is set, matching ``reset``.
        if self.render_mode:
            self.render()
        return self.current_state, reward, bool(state["terminated"][0]), self.truncated, self.info

    def transition(self, state: ObsType, action: ActType) -> None:
        """Record ``action`` as the latest action, then apply its operation to ``state``."""
        op = int(action['operation'])
        self.last_action_op = op
        self.last_action = action
        self.operations[op](state, action)