-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexperience.py
More file actions
45 lines (32 loc) · 973 Bytes
/
experience.py
File metadata and controls
45 lines (32 loc) · 973 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from mbpo import *
from stable_baselines3 import SAC, TD3
import gymnasium as gym
import sys
# Command-line driver: pick the model estimators and the policy-optimisation
# class named on the command line, train an MBPOAgent, then save it.
#
# Usage:
#   python3 experience.py env_name estimator_cls pol_optim_cls nb_iter exp_name
#
#   env_name       environment id handed to gym.make
#   estimator_cls  "tree", "mlp" or "cvtree"
#   pol_optim_cls  "sac" or "td3"
#   nb_iter        integer handed to MBPOAgent.learn
#   exp_name       value handed to MBPOAgent.save
args = sys.argv[1:]

# Explicit check instead of `assert` so the validation still runs under
# `python -O`; the exception type and message are the same as before.
if len(args) != 5:
    raise AssertionError(
        "python3 experience.py env_name estimator_cls pol_optim_cls nb_iter exp_name"
    )

env_name = args[0]
iters = int(args[3])  # raises ValueError if nb_iter is not an integer literal
exp_name = args[4]

# Transition model / done model pair selected by the estimator argument.
if args[1] == "tree":
    transi = FullTransitionTreeModel()
    done = DoneTreeModel()
elif args[1] == "mlp":
    transi = FullTransitionMLPModel()
    done = DoneMLPModel()
elif args[1] == "cvtree":
    transi = FullTransitionTreeCVModel()
    done = DoneTreeCVModel()  # maybe also cv tree
else:
    # This used to be a bare expression with no effect, so an unknown value
    # fell through and only failed later with NameError on `transi`.
    raise AssertionError(
        "Only Model estimators are Decision Tree, Best CV Tree, and MLP"
    )

# Policy-optimisation algorithm class (passed to MBPOAgent uninstantiated).
if args[2] == "sac":
    agent_cls = SAC
elif args[2] == "td3":
    agent_cls = TD3
else:
    # Same fix as above: actually raise instead of evaluating a no-op tuple.
    raise AssertionError("Only Pol Ooptim algos are SAC and TD3")

# The environment is wrapped in NormalizeObservation before being given to the agent.
env = gym.wrappers.NormalizeObservation(gym.make(env_name))
mbpo = MBPOAgent(env, transi, done, agent_cls)
mbpo.learn(iters)
mbpo.save(exp_name)