-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
27 lines (22 loc) · 1016 Bytes
/
train.py
File metadata and controls
27 lines (22 loc) · 1016 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os

from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.wrappers import FrameStackObservation as FrameStack
from nes_py.wrappers import JoypadSpace
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecMonitor

from mario_env import MarioEnv
from roboflow_env import RoboflowEnvironment
from save_best_training import SaveOnBestTrainingRewardCallback
# Directory handed to VecMonitor below as the destination for its log output.
log_dir = "./tmp/"
def make_env(*, skip=4, model_id="mario-ibyfv/2", api_key=None, max_boxes=10, n_stack=2):
    """Build one fully wrapped Mario environment.

    The wrapper chain is, innermost first: ``MarioEnv`` -> ``JoypadSpace``
    (restricted to ``SIMPLE_MOVEMENT``) -> ``RoboflowEnvironment`` ->
    ``FrameStack``. Called with no arguments it builds the same environment
    the script has always used, so it can be passed directly as a factory to
    a vectorized-env constructor.

    Args:
        skip: Forwarded to ``MarioEnv(skip=...)``.
        model_id: Roboflow model identifier passed to ``RoboflowEnvironment``.
        api_key: Roboflow API key. When ``None``, the ``ROBOFLOW_API_KEY``
            environment variable is used if set.
        max_boxes: Forwarded to ``RoboflowEnvironment(max_boxes=...)``.
        n_stack: Second positional argument to the frame-stack wrapper
            (presumably the number of stacked observations; confirm against
            the installed gymnasium version).

    Returns:
        The outermost wrapper (the frame-stacked environment).
    """
    if api_key is None:
        # NOTE(review): the literal fallback below is a credential committed to
        # source control. It is kept only so existing runs keep working; rotate
        # the key and remove the fallback once ROBOFLOW_API_KEY is configured
        # wherever this script runs.
        api_key = os.environ.get("ROBOFLOW_API_KEY", "ITukAND4XqHSos8UA9me")

    mario_env = MarioEnv(skip=skip)
    mario_env = JoypadSpace(mario_env, SIMPLE_MOVEMENT)
    env = RoboflowEnvironment(mario_env, model_id, api_key=api_key, max_boxes=max_boxes)
    env = FrameStack(env, n_stack)
    return env
# Create the log directory before VecMonitor sees it. When the directory is
# missing, stable-baselines3 treats the path as a filename prefix instead of a
# directory and writes a hidden ".monitor.csv", which monitor-file discovery
# (a "*monitor.csv" glob) does not pick up.
os.makedirs(log_dir, exist_ok=True)

# Four environments, each built by make_env, wrapped in a single vectorized env.
env = DummyVecEnv([make_env for _ in range(4)])
env = VecMonitor(env, log_dir)

model = PPO('MlpPolicy', env, verbose=1)

# Use the same log_dir the monitor writes to, rather than a separately spelled
# literal, so the two cannot drift apart.
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)

model.learn(total_timesteps=3_000_000_000, callback=callback)
model.save("./mario_model.zip")