-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
80 lines (63 loc) · 2.52 KB
/
run.py
File metadata and controls
80 lines (63 loc) · 2.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import gymnasium as gym
import hydra
from omegaconf import DictConfig
from omegaconf import OmegaConf
from OperationOnly import OpOnlyenv
from arcle.loaders import Loader
from loader import SizeConstrainedLoader, EntireSelectionLoader
import wandb
import argparse
from ppo.ppo import learn
class TestLoader(Loader):
    """Loader serving one synthetic task made of random colour grids.

    Every grid has shape ``(size_x, size_y)``, dtype ``uint8`` and values drawn
    uniformly from ``[0, 10)`` by a generator seeded with 12345, so a fresh
    process always produces the same sequence of tasks.  It appears intended
    only for smoke tests (its task description is ``"just for test"``).
    """

    def __init__(self, size_x, size_y, **kwargs):
        # Our own fields are set before delegating to the base initialiser.
        # NOTE(review): presumably because Loader.__init__ may call
        # get_path()/parse(); confirm against arcle.loaders.Loader.
        self.size_x, self.size_y = size_x, size_y
        self.rng = np.random.default_rng(12345)
        super().__init__(**kwargs)

    def get_path(self, **kwargs):
        """Return a single placeholder path; nothing is read from disk."""
        return ['']

    def pick(self, **kwargs):
        """Return one freshly generated task (each call draws new grids)."""
        tasks = self.parse()
        return tasks[0]

    def _random_grid(self):
        """Draw one ``(size_x, size_y)`` uint8 grid with values in ``[0, 10)``."""
        shape = (self.size_x, self.size_y)
        return self.rng.integers(0, 10, size=shape).astype(np.uint8)

    def parse(self, **kwargs):
        """Build a one-element task list of four random grids plus metadata.

        The four grids are drawn in a fixed order so seeded runs stay
        reproducible.  NOTE(review): the tuple slots are assumed to be
        (test input, test output, example input, example output) -- confirm
        against the format arcle's Loader expects.
        """
        t_in, t_out, e_in, e_out = (self._random_grid() for _ in range(4))
        return [([t_in], [t_out], [e_in], [e_out], {'desc': "just for test"})]
# Value forwarded as the env's ``max_trial`` argument when the config does not
# set ``env.max_trial``; 3 matches the previously hard-coded value.
_DEFAULT_MAX_TRIAL = 3


def _make_env(cfg: DictConfig):
    """Construct the environment selected by the ``cfg.env`` flags.

    Selection precedence:

    1. ``cfg.env.use_arc`` -- ``ARCLE/O2ARCv2Env-v0`` fed by a
       ``SizeConstrainedLoader``.
    2. ``cfg.env.ent_sel`` -- ``OpOnlyenv`` fed by an ``EntireSelectionLoader``
       built from ``cfg.train.task`` / ``cfg.eval.task``; its action space is
       printed for inspection.
    3. otherwise -- ``ARCLE/O2ARCv2Env-v0`` fed by the synthetic ``TestLoader``.

    All variants receive identical ``max_trial``, ``max_grid_size`` and
    ``colors`` keyword arguments.
    """
    # Optional override. The ``in`` test is safe on Hydra's struct-mode
    # configs, where attribute access to an absent key would raise.
    if "max_trial" in cfg.env:
        max_trial = cfg.env.max_trial
    else:
        max_trial = _DEFAULT_MAX_TRIAL

    shared_kwargs = dict(
        max_trial=max_trial,
        max_grid_size=(cfg.env.grid_x, cfg.env.grid_y),
        colors=cfg.env.num_colors,
    )

    if cfg.env.use_arc:
        # NOTE(review): only grid_x is passed to the loader; confirm
        # SizeConstrainedLoader does not also need grid_y for non-square grids.
        return gym.make(
            'ARCLE/O2ARCv2Env-v0',
            data_loader=SizeConstrainedLoader(cfg.env.grid_x),
            **shared_kwargs,
        )

    if cfg.env.ent_sel:
        env = OpOnlyenv(
            data_loader=EntireSelectionLoader(
                train_task=cfg.train.task, eval_task=cfg.eval.task
            ),
            **shared_kwargs,
        )
        print(env.action_space)
        return env

    return gym.make(
        'ARCLE/O2ARCv2Env-v0',
        data_loader=TestLoader(cfg.env.grid_x, cfg.env.grid_y),
        **shared_kwargs,
    )


@hydra.main(config_path="ppo", config_name="ppo_config_entsel")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: build the configured environment and run PPO.

    ``cfg`` is composed by Hydra from the ``ppo_config_entsel`` config in the
    ``ppo`` config directory, plus any command-line overrides.
    """
    # NOTE: Weights & Biases logging is currently disabled. To enable it, call
    # wandb.init(entity=..., project=..., config=OmegaConf.to_container(cfg))
    # here before training starts.
    env = _make_env(cfg)
    learn(cfg, env)


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator supplies ``cfg``.
    main()