Minimal repo for running the policy of any IRIS agent trained on the 26 games of the Atari 100k benchmark.
Install with pip:
```bash
pip install git+https://github.com/eloialonso/iris_agent.git
```

Create an agent:
```python
from iris_agent import Agent

agent = Agent()  # no game name given: choose from the list of games, or
# agent = Agent('Breakout')  # specify the game name directly
```
Use the policy:
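```python
import torch

n = 1  # number of observations in the batch
agent.reset(n)

# obs must be a float tensor of shape (n, 3, 64, 64) with values in [0., 1.],
# produced with the standard Atari wrappers (see the IRIS codebase).
obs = torch.rand(n, 3, 64, 64)  # random placeholder; rand (not randn) keeps values in [0, 1)
act = agent.act(obs)  # act is a (n,) long tensor with values in {0, ..., num_actions - 1}
```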
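Real environment frames usually arrive as channels-last `uint8` images, while `agent.act` expects a channels-first float batch in [0., 1.]. The helper below is a minimal sketch of that last conversion step. It assumes the standard Atari wrappers already yield RGB frames resized to 64x64 and stacked into an array of shape (n, 64, 64, 3); the name `frames_to_obs` is hypothetical and not part of `iris_agent`, and it does not reproduce the wrappers themselves (frame skipping, resizing, etc.).

```python
import numpy as np


def frames_to_obs(frames: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert uint8 RGB frames of shape (n, 64, 64, 3)
    into a float32 array of shape (n, 3, 64, 64) with values in [0, 1].

    Assumes the frames were already resized to 64x64 by the Atari wrappers.
    """
    if frames.dtype != np.uint8:
        raise TypeError(f"expected uint8 frames, got {frames.dtype}")
    if frames.ndim != 4 or frames.shape[1:] != (64, 64, 3):
        raise ValueError(f"expected shape (n, 64, 64, 3), got {frames.shape}")
    # Reorder channels-last (n, H, W, C) to channels-first (n, C, H, W).
    chw = np.transpose(frames, (0, 3, 1, 2))
    # Rescale pixel values from {0, ..., 255} to [0, 1] as float32.
    obs = chw.astype(np.float32) / 255.0
    # Return a C-contiguous copy so it converts cleanly to a tensor.
    return np.ascontiguousarray(obs)
```

The result can then be passed to the policy with `obs = torch.from_numpy(frames_to_obs(frames))` followed by `act = agent.act(obs)`. Keeping the helper in NumPy makes the shape and range checks testable without loading an agent.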