Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions gem/multiagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from gem.multiagent.multi_agent_env import AgentSelector, MultiAgentEnv
from gem.multiagent.multi_agent_env import MultiAgentEnv

__all__ = [
"MultiAgentEnv",
"AgentSelector",
]
303 changes: 164 additions & 139 deletions gem/multiagent/multi_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,67 +13,141 @@
# limitations under the License.

import abc
from typing import Any, Dict, List, Optional, Tuple
import warnings
from typing import Any, Dict, List, Union, Optional, Tuple

from gem.core import Env


class MultiAgentEnv(Env):
'''
Base multi-agent environment
'''

def __init__(self):
    '''Initialize bookkeeping for a multi-agent environment.'''
    super().__init__()

    # Static roster of every agent that could ever appear, vs. the
    # agents currently present in this episode.
    self.possible_agents: List[str] = []
    self.agents: List[str] = []
    # Per-agent activity flags; shared BY REFERENCE with AgentIterator
    # so mid-turn deactivations are respected by the iterator.
    self.active_mask: Dict[str, bool] = {}

    # Per-agent step state.
    self.terminations: Dict[str, bool] = {}
    self.truncations: Dict[str, bool] = {}
    self.rewards: Dict[str, float] = {}
    self.infos: Dict[str, dict] = {}
    self._cumulative_rewards: Dict[str, float] = {}

    # Lazily created/refreshed via the `agent_iter` property.
    # (The old `agent_selector` attribute is gone: AgentSelector was
    # removed from the package along with its __init__ export.)
    self._agent_iter = None

    # Simple shared message channel between agents.
    self.shared_memory = []
    self.global_context = ""

def step(self, actions: Dict[str, str]) -> Tuple[
def _reset_agent_iter(self):
    # Begin a fresh pass over the agents; the iterator holds
    # `self.active_mask` by reference, so later mask updates are seen.
    self._agent_iter = AgentIterator(self.agents, self.active_mask)

@property
def agent_iter(self):
    """Return the iterator over active agents, rebuilding it lazily.

    A fresh iterator is created on first access and again whenever the
    previous pass over the agents has been exhausted.
    """
    exhausted = self._agent_iter is None or self._agent_iter.is_end()
    if exhausted:
        self._reset_agent_iter()
    return self._agent_iter

def step(self, action_or_actions: Union[str, Dict[str, str]]) -> Tuple[
    Dict[str, str],
    Dict[str, float],
    Dict[str, bool],
    Dict[str, bool],
    Dict[str, dict],
]:
    '''
    Master function for environment stepping.

    Accepts either a dict mapping agent ids to actions (simultaneous
    stepping) or a single action string for the currently selected
    agent (sequential stepping).

    By default, will attempt to call the simultaneous `_step` function.
    If not implemented, will fall back to sequential stepping.

    Returns the (observations, rewards, terminations, truncations,
    infos) dicts once a full turn has completed. A single-action
    sequential call that does not finish the turn returns None
    (because `_handle_single_step` has no return value).
    '''
    if isinstance(action_or_actions, dict):
        try:
            ret = self._step(action_or_actions)
        except NotImplementedError:
            warnings.warn(
                "Simultaneous step not implemented, falling back to sequential stepping."
            )
            for agent in self.agent_iter:
                # Don't silently fail: a missing action for an active
                # agent should raise (KeyError) rather than be skipped.
                self._handle_single_step(action_or_actions[agent])
            ret = None
        # A full turn completed either way: apply global dynamics.
        return self._step_global_dynamics(ret)
    else:
        ret = self._handle_single_step(action_or_actions)
        # Deliberately inspects the raw attribute: the `agent_iter`
        # property would auto-reset an exhausted iterator.
        if self._agent_iter.is_end():
            ret = self._step_global_dynamics()
        return ret

def _handle_single_step(self, action: str) -> None:
'''
Validation wrapper over a sequential step function.
'''
current_agent = self._agent_iter.current_agent
if not current_agent:
raise ValueError("No active agent selected")
if current_agent not in self.agents:
raise ValueError(f"Agent {current_agent} not in environment")

self._step_single(current_agent, action)

def _step(self, actions: Dict[str, str]) -> Tuple[
    Dict[str, str],
    Dict[str, float],
    Dict[str, bool],
    Dict[str, bool],
    Dict[str, dict],
]:
    '''
    Simultaneous ("parallel") step function: apply one action per agent
    in a single call. This is ideally the default mode of operation for
    multi-agent environments; subclasses should override it.

    Raises:
        NotImplementedError: always, in this base class. The public
            `step` catches this and falls back to sequential stepping.
    '''
    raise NotImplementedError

self._remove_dead_agents()
def _step_single(self, current_agent: str, action: str) -> None:
'''
Sequential step function, the more general mode of
operation for multi-agent environments.

if self.agent_selector:
self.agent_selector.next()
Provides testing convenience and expressivity for games
whose players cannot make moves in parallel (i.e. Chess)

return observations, rewards, terminations, truncations, infos
This is because the environment steps can be broken up to be
on a per-player basis, allowing for more complex interactions
between individual players and the environment.
'''
raise NotImplementedError

def _validate_actions(self, actions: Dict[str, str], active_agents: List[str]):
for agent in active_agents:
def _step_global_dynamics(self, ret = None) -> Tuple[
    Dict[str, str],
    Dict[str, float],
    Dict[str, bool],
    Dict[str, bool],
    Dict[str, dict],
]:
    '''
    After all players have finished making their moves,
    handle all global environmental dynamics before
    starting a new "turn".

    Args:
        ret: step-return tuple produced by a simultaneous `_step`, or
            None to assemble one from per-agent state.

    Returns:
        (observations, rewards, terminations, truncations, infos).
    '''
    # New iterator shares `self.active_mask` by reference, so the mask
    # updates below remain visible to it.
    self._reset_agent_iter()
    if ret is None:
        # Snapshot per-agent state BEFORE rewards are zeroed below.
        ret = self.get_all_states()
    # Flush the cumulative rewards after each step.
    for agent in self.agents:
        self.rewards[agent] = 0.0
        self._cumulative_rewards[agent] = 0.0
        # Update the active mask: finished agents stop acting.
        if self.terminations[agent] or self.truncations[agent]:
            self.active_mask[agent] = False
    return ret

def _validate_actions(self, actions: Dict[str, str]):
for agent in self.agents:
if agent not in self.terminations or self.terminations[agent]:
continue
if agent not in self.truncations or self.truncations[agent]:
Expand All @@ -82,26 +156,16 @@ def _validate_actions(self, actions: Dict[str, str], active_agents: List[str]):
raise ValueError(f"Missing action for active agent {agent}")

for agent in actions:
if agent not in active_agents:
if not self.active_mask[agent]:
raise ValueError(f"Agent {agent} provided action but is not active")

@abc.abstractmethod
def _process_actions(self, actions: Dict[str, str]) -> Tuple[
Dict[str, str],
Dict[str, float],
Dict[str, bool],
Dict[str, bool],
Dict[str, dict],
]:
raise NotImplementedError

def reset(
self, seed: Optional[int] = None
) -> Tuple[Dict[str, str], Dict[str, Any]]:
if seed is not None:
self._np_random = self._make_np_random(seed)

self.agents = self.possible_agents.copy()
self.active_mask = {agent: True for agent in self.agents}

self.terminations = {agent: False for agent in self.agents}
self.truncations = {agent: False for agent in self.agents}
Expand All @@ -112,15 +176,13 @@ def reset(
self.shared_memory = []
self.global_context = ""

if self.agent_selector:
self.agent_selector.reinit(self.agents)
self._reset_agent_iter()

observations = {agent: self.observe(agent) for agent in self.agents}
infos = {agent: {} for agent in self.agents}

return observations, infos

@abc.abstractmethod

def observe(self, agent: str) -> str:
    # Return the current observation string for `agent`.
    # Subclasses must override; the base class raises.
    raise NotImplementedError

Expand All @@ -136,55 +198,30 @@ def get_state(self, agent: str) -> Tuple[str, float, bool, bool, dict]:
self.infos.get(agent, {}),
)

def get_active_states(self) -> Dict[str, Tuple[str, float, bool, bool, dict]]:
active_agents = (
self.agent_selector.get_active_agents()
if self.agent_selector
else self.agents
def get_all_states(self) -> Tuple[
    Dict[str, str],
    Dict[str, float],
    Dict[str, bool],
    Dict[str, bool],
    Dict[str, dict],
]:
    '''
    Collect every agent's (observation, reward, termination,
    truncation, info) state and return it "characteristic-first":
    five dicts keyed by agent id.

    Fix over the zip(*...) formulation: with zero agents,
    `zip(*[])` yields nothing and unpacking five values raised
    ValueError; building the dicts directly returns five empty
    dicts instead.
    '''
    observations: Dict[str, str] = {}
    rewards: Dict[str, float] = {}
    terminations: Dict[str, bool] = {}
    truncations: Dict[str, bool] = {}
    infos: Dict[str, dict] = {}
    for agent in self.agents:
        obs, rew, term, trunc, info = self.get_state(agent)
        observations[agent] = obs
        rewards[agent] = rew
        terminations[agent] = term
        truncations[agent] = trunc
        infos[agent] = info
    return observations, rewards, terminations, truncations, infos

return {
agent: self.get_state(agent)
for agent in active_agents
if agent in self.agents
}

def add_agent(self, agent_id: str):
if agent_id in self.agents:
return

self.agents.append(agent_id)
self.terminations[agent_id] = False
self.truncations[agent_id] = False
self.rewards[agent_id] = 0.0
self.infos[agent_id] = {}
self._cumulative_rewards[agent_id] = 0.0

if self.agent_selector:
self.agent_selector.add_agent(agent_id)

def remove_agent(self, agent_id: str):
if agent_id not in self.agents:
return

self.agents.remove(agent_id)
del self.terminations[agent_id]
del self.truncations[agent_id]
del self.rewards[agent_id]
del self.infos[agent_id]
del self._cumulative_rewards[agent_id]

if self.agent_selector:
self.agent_selector.remove_agent(agent_id)

def _remove_dead_agents(self):
dead_agents = [
agent
for agent in self.agents
if self.terminations.get(agent, False) or self.truncations.get(agent, False)
]
for agent in dead_agents:
self.remove_agent(agent)
def update_reward(self, agent: str, reward: float):
    '''
    Set the per-step reward for `agent` and add it to the running
    cumulative total.

    Raises:
        ValueError: if `agent` is not currently in the environment.
    '''
    if agent not in self.agents:
        raise ValueError(f"Agent {agent} not in environment")
    # NOTE(review): plain assignment (not +=) assumes at most one
    # reward update per agent per step — confirm against callers.
    self.rewards[agent] = reward
    self._cumulative_rewards[agent] += reward

def send_message(self, from_agent: str, to_agent: str, message: str):
if from_agent not in self.agents:
Expand All @@ -206,53 +243,41 @@ def broadcast_message(self, from_agent: str, message: str):
{"from": from_agent, "to": agent, "message": message}
)


class AgentSelector:

def __init__(self, agents: List[str], mode: str = "sequential"):
self.mode = mode
class AgentIterator:
    '''
    Iterator to help Sequential environments iterate over agents.

    Yields each agent id in registration order, skipping agents whose
    entry in `active_mask` is False. The mask dict is held by
    reference, so activity changes made by the environment mid-pass
    are respected.

    Fixes over the previous version:
      * `any(self.active_mask)` iterated the dict's KEYS (non-empty
        strings), so the "no active agents" guard never triggered.
      * Returning None from `__next__` breaks the iterator protocol
        (`for` would loop forever); exhaustion now raises StopIteration.
      * The first agent of a pass was yielded even when inactive; the
        skip now runs before yielding, not only when advancing.
    '''

    def __init__(self, agents: List[str], active_mask: Dict[str, bool]):
        self._agents = agents.copy()
        # Shared (not copied) on purpose: the env mutates this mask.
        self.active_mask = active_mask
        # Last agent yielded by __next__ (None before the first yield).
        self.current_agent: Optional[str] = None
        self._current_idx = 0

    def __next__(self) -> str:
        # Skip inactive agents, including one sitting at the current
        # pointer position (e.g. at the start of a pass).
        while not self.is_end() and not self.active_mask[self._agents[self._current_idx]]:
            self._current_idx += 1
        if self.is_end():
            # All agents consumed, or none of the remaining are active.
            raise StopIteration
        self.current_agent = self._agents[self._current_idx]
        # Point to the next agent.
        self._current_idx += 1
        return self.current_agent

    def is_alive(self, agent: str) -> bool:
        '''Return whether `agent` is active; KeyError if unknown.'''
        try:
            return self.active_mask[agent]
        except KeyError:
            raise KeyError("Agent does not exist in this iterator.") from None

    def is_end(self) -> bool:
        '''True once the pointer has moved past the last agent.'''
        return self._current_idx >= len(self._agents)

    def __iter__(self):
        return self
Loading