diff --git a/examples/mis/lwd/README.md b/examples/mis/lwd/README.md
new file mode 100644
index 000000000..1d6070137
--- /dev/null
+++ b/examples/mis/lwd/README.md
@@ -0,0 +1,150 @@
+# LWD: Learning What to Defer for Maximum Independent Sets
+
+This example is an application of the [LwD](http://proceedings.mlr.press/v119/ahn20a.html) method (published in ICML 2020).
+Some of the code is adapted from the original GitHub [repository](https://github.com/sungsoo-ahn/learning_what_to_defer)
+maintained by the authors.
+
+## Algorithm Introduction
+
+### Deferred Markov Decision Process
+
+
+![The deferred Markov decision process](figures/fig1_mdp.png)
+
+#### State
+
+Each state of the MDP is represented as a *vertex-state* vector:
+
+
+
+$s = [s_i: i \in V] \in \{0, 1, *\}^V$
+
+
+
+where *0*, *1*, and *\** indicate that vertex *i* is *excluded*, *included*,
+or *deferred* (i.e., the decision is postponed to a later iteration), respectively.
+The MDP is *initialized* with all vertex-states deferred, i.e., $s_i = *, \forall i \in V$,
+and *terminated* when (a) there is no deferred vertex-state left or (b) the time limit is reached.
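+
+As a minimal illustration, the accompanying code (see *simulator/common.py*) encodes the deferred symbol $*$ as the
+integer 2, so the initial state of a 5-vertex graph can be sketched as follows (a toy example, not part of the
+simulator itself):
+
+```python
+import torch
+
+# Integer encoding used in examples/mis/lwd/simulator/common.py:
+# 0 = excluded, 1 = included, 2 = deferred (the "*" symbol above).
+EXCLUDED, INCLUDED, DEFERRED = 0, 1, 2
+
+num_vertices = 5
+# The MDP is initialized with every vertex in the deferred state.
+state = torch.full((num_vertices,), DEFERRED, dtype=torch.long)
+print(state)  # tensor([2, 2, 2, 2, 2])
+```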
+
+#### Action
+
+Actions assign new vertex-states for the next step and are defined only on the currently deferred vertices:
+
+
+
+$a_* = [a_i: i \in V_*] \in \{0, 1, *\}^{V_*}$
+
+
+
+where $V_* = \{i: i \in V, s_i = *\}$ is the set of deferred vertices.
+
+#### Transition
+
+The transition $P_{a_*}(s, s')$ consists of two deterministic phases:
+
+- *update phase*: takes the action $a_*$ to get an intermediate vertex-state $\hat{s}$,
+i.e., $\hat{s}_i = a_i$ if $i \in V_*$ and $\hat{s}_i = s_i$ otherwise.
+- *clean-up phase*: modifies $\hat{s}$ to yield a valid vertex-state $s'$:
+
+ - Whenever there exists a pair of included vertices adjacent to each other,
+ they are both mapped back to the deferred vertex-state.
+ - Any deferred vertex that neighbors an included vertex is excluded.
+
+Here is an illustration of the transition function:
+
+
+![Illustration of the transition function](figures/fig2_transition.png)
+
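+
+For illustration only, here is a minimal dictionary-based sketch of the two-phase transition described above
+(the actual simulator in *simulator/mis_business_engine.py* performs these steps on batched DGL graphs; `adj` is
+assumed to map every vertex id to its neighbor list):
+
+```python
+from typing import Dict, List
+
+EXCLUDED, INCLUDED, DEFERRED = 0, 1, 2
+
+
+def transition(state: List[int], action: Dict[int, int], adj: Dict[int, List[int]]) -> List[int]:
+    # Update phase: apply the action to the deferred vertices only.
+    s_hat = [action.get(i, s) if s == DEFERRED else s for i, s in enumerate(state)]
+
+    # Clean-up phase (a): adjacent included vertices clash -> map both back to deferred.
+    clashed = {
+        i for i, s in enumerate(s_hat)
+        if s == INCLUDED and any(s_hat[j] == INCLUDED for j in adj[i])
+    }
+    for i in clashed:
+        s_hat[i] = DEFERRED
+
+    # Clean-up phase (b): exclude any deferred vertex adjacent to an included vertex.
+    return [
+        EXCLUDED if s == DEFERRED and any(s_hat[j] == INCLUDED for j in adj[i]) else s
+        for i, s in enumerate(s_hat)
+    ]
+```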
+#### Reward
+
+A *cardinality reward* is defined here:
+
+
+
+$R(s, s') = \sum_{i \in V_* \setminus V_*'}{s_i'}$
+
+
+
+where $V_*$ and $V_*'$ are the sets of vertices with deferred vertex-state in $s$ and $s'$, respectively.
+By doing so, the overall reward of the MDP corresponds to the cardinality of the returned independent set.
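+
+As a toy check of this definition (using the same 0/1/2 integer encoding as the accompanying code, where 2 stands
+for the deferred symbol $*$), the reward simply counts the vertices that were newly included in this step:
+
+```python
+DEFERRED = 2
+
+
+def cardinality_reward(prev_state, next_state):
+    # Sum s'_i over the vertices that were deferred in s but are determined in s',
+    # which counts the newly included vertices.
+    return sum(
+        s_next
+        for s_prev, s_next in zip(prev_state, next_state)
+        if s_prev == DEFERRED and s_next != DEFERRED
+    )
+
+
+# Two vertices become included and one becomes excluded, so the reward is 2.
+print(cardinality_reward([2, 2, 2, 0], [1, 0, 1, 0]))  # 2
+```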
+
+### Diversification Reward
+
+Couple two copies of the MDP defined on an identical graph $G$ into a new MDP.
+The new MDP is then associated with a pair of distinct vertex-state vectors $(s, \bar{s})$;
+let the resulting solutions be $(x, \bar{x})$.
+We directly reward the deviation between the coupled solutions in terms of the $l_1$-norm, i.e., $||x-\bar{x}||_1$.
+To be specific, the deviation is decomposed into per-iteration rewards of the MDP, defined by:
+
+
+
+$R_{div}(s, s', \bar{s}, \bar{s}') = \sum_{i \in \hat{V}}|s_i'-\bar{s}_i'|$, where $\hat{V}=(V_* \setminus V_*')\cup(\bar{V}_* \setminus \bar{V}_*')$
+
+
+
+Here is an example of the diversity reward:
+
+
+![Example of the diversity reward](figures/fig3_diversity_reward.png)
+
+*Entropy regularization plays a role similar to the diversity reward introduced above.
+Note, however, that entropy regularization only encourages diverse trajectories of the same MDP,
+which does not necessarily lead to diverse solutions in the end,
+since many different trajectories can result in the same solution.*
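+
+A minimal sketch of the per-step diversity reward follows (again with 2 standing for the deferred symbol; as in the
+implementation, vertex pairs where either copy is still deferred contribute zero):
+
+```python
+DEFERRED = 2
+
+
+def diversity_reward(prev_a, next_a, prev_b, next_b):
+    # Sum |s'_i - s_bar'_i| over vertices that were just determined in at least one copy
+    # and are now determined in both copies.
+    reward = 0
+    for pa, na, pb, nb in zip(prev_a, next_a, prev_b, next_b):
+        newly_touched = (pa == DEFERRED) or (pb == DEFERRED)
+        both_determined = (na != DEFERRED) and (nb != DEFERRED)
+        if newly_touched and both_determined:
+            reward += abs(na - nb)
+    return reward
+```
+
+During training, the overall step reward used by this example combines both terms as
+`cardinality_reward + diversity_reward_coef * diversity_reward`, normalized by `reward_normalization_base`
+(see *config.py* and `MISEnvSampler._get_reward`).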
+
+### Design of the Neural Network
+
+The policy network $\pi(a|s)$ and the value network $V(s)$ are designed to follow the
+[GraphSAGE](https://proceedings.neurips.cc/paper/2017/hash/5dd9db5e033da9c6fb5ba83c7a7ebea9-Abstract.html) architecture,
+which is a general inductive framework that leverages node feature information
+to efficiently generate node embeddings by sampling and aggregating features from a node's local neighborhood.
+Each network consists of multiple layers $h^{(n)}$ with $n = 1, ..., N$,
+where the $n$-th layer with weights $W_1^{(n)}$ and $W_2^{(n)}$ performs the following transformation on input $H$:
+
+
+
+$h^{(n)} = \mathrm{ReLU}(HW_1^{(n)}+D^{-\frac{1}{2}}BD^{-\frac{1}{2}}HW_2^{(n)})$.
+
+
+
+Here $B$ and $D$ correspond to the adjacency and degree matrices of the graph $G$, respectively. At the final layer,
+instead of ReLU, the policy network applies a softmax function and the value network applies a graph readout function
+with sum pooling, to generate actions and value estimates, respectively.
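+
+For illustration, here is a compact dense-matrix sketch of the layer transformation above (the example implementation
+in *ppo/graph_net.py* uses DGL message passing on sparse graphs instead):
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ToyGraphSAGELayer(nn.Module):
+    """One layer of h = ReLU(H W1 + D^{-1/2} B D^{-1/2} H W2), written with dense matrices."""
+
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(in_dim, out_dim, bias=False)
+        self.w2 = nn.Linear(in_dim, out_dim, bias=False)
+
+    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
+        # adj is the dense (num_nodes, num_nodes) adjacency matrix B.
+        deg = adj.sum(dim=1).clamp(min=1.0)       # degrees, clamped to avoid division by zero
+        d_inv_sqrt = torch.diag(deg.pow(-0.5))    # D^{-1/2}
+        norm_adj = d_inv_sqrt @ adj @ d_inv_sqrt  # D^{-1/2} B D^{-1/2}
+        return F.relu(self.w1(h) + self.w2(norm_adj @ h))
+```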
+
+### Input of the Neural Network
+
+- The subgraph induced on the deferred vertices $V_*$ is used as the input of the networks,
+since the already-determined part of the graph no longer affects the future rewards of the MDP.
+- Input features:
+
+ - Vertex degrees;
+ - The current iteration-index of the MDP, normalized by the maximum number of iterations.
+
+### Training Algorithm
+
+Proximal Policy Optimization (PPO) is used to train the policy in this solution.
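+
+For reference, the standard PPO clipped surrogate objective is
+
+
+
+$L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \mathrm{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right], \quad r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$
+
+
+
+where the clip ratio $\epsilon$ corresponds to `clip_ratio` in *config.py*. In this example, the policy outputs a
+per-vertex action distribution, so each graph's probability ratio is computed from the sum of its per-vertex
+log-probabilities (see *ppo/ppo.py*).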
+
+## Quick Start
+
+Please make sure the environment is correctly set up; refer to
+[MARO](https://github.com/microsoft/maro#install-maro-from-source) for more installation guidance.
+To try the example code, simply run:
+
+```sh
+python examples/rl/run.py examples/mis/lwd/config.yml
+```
+
+The default log path is set to *examples/mis/lwd/log/test_40_50*; the recorded metrics and training curves can be found there.
+
+To adjust the configurations of the training workflow, edit *examples/mis/lwd/config.yml*.
+To adjust the problem formulation, network settings, and other detailed configurations,
+edit *examples/mis/lwd/config.py*.
diff --git a/examples/mis/lwd/__init__.py b/examples/mis/lwd/__init__.py
new file mode 100644
index 000000000..a1517b73c
--- /dev/null
+++ b/examples/mis/lwd/__init__.py
@@ -0,0 +1,99 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+
+from maro.rl.rl_component.rl_component_bundle import RLComponentBundle
+from maro.rl.utils.common import get_env
+from maro.simulator import Env
+
+from examples.mis.lwd.config import Config
+from examples.mis.lwd.env_sampler.mis_env_sampler import MISEnvSampler, MISPlottingCallback
+from examples.mis.lwd.simulator.mis_business_engine import MISBusinessEngine
+from examples.mis.lwd.ppo import get_ppo_policy, get_ppo_trainer
+
+
+config = Config()
+
+# Environments
+learn_env = Env(
+ business_engine_cls=MISBusinessEngine,
+ durations=config.max_tick,
+ options={
+ "graph_batch_size": config.train_graph_batch_size,
+ "num_samples": config.train_num_samples,
+ "device": torch.device(config.device),
+ "num_node_lower_bound": config.num_node_lower_bound,
+ "num_node_upper_bound": config.num_node_upper_bound,
+ "node_sample_probability": config.node_sample_probability,
+ },
+)
+
+test_env = Env(
+ business_engine_cls=MISBusinessEngine,
+ durations=config.max_tick,
+ options={
+ "graph_batch_size": config.eval_graph_batch_size,
+ "num_samples": config.eval_num_samples,
+ "device": torch.device(config.device),
+ "num_node_lower_bound": config.num_node_lower_bound,
+ "num_node_upper_bound": config.num_node_upper_bound,
+ "node_sample_probability": config.node_sample_probability,
+ },
+)
+
+# Agent, policy, and trainers
+agent2policy = {agent: f"ppo_{agent}.policy" for agent in learn_env.agent_idx_list}
+
+policies = [
+ get_ppo_policy(
+ name=f"ppo_{agent}.policy",
+ state_dim=config.input_dim,
+ action_num=config.output_dim,
+ hidden_dim=config.hidden_dim,
+ num_layers=config.num_layers,
+ init_lr=config.init_lr,
+ )
+ for agent in learn_env.agent_idx_list
+]
+
+trainers = [
+ get_ppo_trainer(
+ name=f"ppo_{agent}",
+ state_dim=config.input_dim,
+ hidden_dim=config.hidden_dim,
+ num_layers=config.num_layers,
+ init_lr=config.init_lr,
+ clip_ratio=config.clip_ratio,
+ max_tick=config.max_tick,
+ batch_size=config.batch_size,
+ reward_discount=config.reward_discount,
+ graph_batch_size=config.train_graph_batch_size,
+ graph_num_samples=config.train_num_samples,
+ num_train_epochs=config.num_train_epochs,
+ norm_base=config.reward_normalization_base,
+ )
+ for agent in learn_env.agent_idx_list
+]
+
+device_mapping = {f"ppo_{agent}.policy": config.device for agent in learn_env.agent_idx_list}
+
+# Build RLComponentBundle
+rl_component_bundle = RLComponentBundle(
+ env_sampler=MISEnvSampler(
+ learn_env=learn_env,
+ test_env=test_env,
+ policies=policies,
+ agent2policy=agent2policy,
+ diversity_reward_coef=config.diversity_reward_coef,
+ reward_normalization_base=config.reward_normalization_base,
+ ),
+ agent2policy=agent2policy,
+ policies=policies,
+ trainers=trainers,
+ device_mapping=device_mapping,
+ customized_callbacks=[MISPlottingCallback(log_dir=get_env("LOG_PATH", required=False, default="./"))],
+)
+
+
+__all__ = ["rl_component_bundle"]
diff --git a/examples/mis/lwd/config.py b/examples/mis/lwd/config.py
new file mode 100644
index 000000000..eb59ae8ef
--- /dev/null
+++ b/examples/mis/lwd/config.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+
+class Config(object):
+ device: str = "cuda:0"
+
+ # Configuration for graph batch size
+ train_graph_batch_size = 32
+ eval_graph_batch_size = 32
+
+ # Configuration for num_samples
+ train_num_samples = 2
+ eval_num_samples = 10
+
+ # Configuration for the MISEnv
+ max_tick = 32 # Once max_tick is reached, the timeout processing will set all deferred nodes to excluded
+ num_node_lower_bound: int = 40
+ num_node_upper_bound: int = 50
+ node_sample_probability: float = 0.15
+
+ # Configuration for the reward definition
+ diversity_reward_coef = 0.1 # reward = cardinality reward + coef * diversity reward
+ reward_normalization_base = 20
+
+ # Configuration for the GraphBasedActorCritic
+ input_dim = 2
+ output_dim = 3
+ hidden_dim = 128
+ num_layers = 5
+
+ # Configuration for PPO update
+ init_lr = 1e-4
+ clip_ratio = 0.2
+ reward_discount = 1.0
+
+ # Configuration for main loop
+ batch_size = 16
+ num_train_epochs = 4
diff --git a/examples/mis/lwd/config.yml b/examples/mis/lwd/config.yml
new file mode 100644
index 000000000..b057b2c8b
--- /dev/null
+++ b/examples/mis/lwd/config.yml
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations.
+
+job: mis_lwd
+scenario_path: "examples/mis/lwd"
+# The log dir where you want to save the training loggings and model checkpoints.
+log_path: "examples/mis/lwd/log/test_40_50"
+main:
+ # Number of episodes to run. Each episode is one cycle of roll-out and training.
+ num_episodes: 1000
+ # This can be an integer or a list of integers. An integer indicates the interval at which policies are evaluated.
+ # A list indicates the episodes at the end of which policies are to be evaluated. Note that episode indexes are
+ # 1-based.
+ eval_schedule: [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 600, 700, 800, 900, 1000]
+ # Number of episodes to run in evaluation.
+ num_eval_episodes: 5
+ min_n_sample: 1
+ logging:
+ stdout: INFO
+ file: DEBUG
+rollout:
+ logging:
+ stdout: INFO
+ file: DEBUG
+training:
+ mode: simple
+ load_path: null
+ load_episode: null
+ checkpointing:
+ path: null
+ # Interval at which trained policies / models are persisted to disk.
+ interval: 200
+ logging:
+ stdout: INFO
+ file: DEBUG
diff --git a/examples/mis/lwd/env_sampler/baseline.py b/examples/mis/lwd/env_sampler/baseline.py
new file mode 100644
index 000000000..3282d1a88
--- /dev/null
+++ b/examples/mis/lwd/env_sampler/baseline.py
@@ -0,0 +1,46 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import random
+from typing import Dict, List
+
+
+def _choose_by_weight(graph: Dict[int, List[int]], node2weight: Dict[int, float]) -> List[int]:
+ """Choose node in the order of descending weight if not blocked..
+
+ Args:
+ graph (Dict[int, List[int]]): The adjacency list of the target graph. The key is the node id of each node. The
+ value is a list of each node's neighbor nodes.
+ node2weight (Dict[int, float]): The node to weight dictionary with node id as key and node weight as value.
+
+ Returns:
+ List[int]: A list of chosen node ids.
+ """
+ node_weight_list = [(node, weight) for node, weight in node2weight.items()]
+ # Shuffle the candidates to get random result in the case there are nodes sharing the same weight.
+ random.shuffle(node_weight_list)
+ # Sort node candidates with descending weight.
+ sorted_nodes = sorted(node_weight_list, key=lambda x: x[1], reverse=True)
+
+ chosen_node_id_set: set = set()
+ blocked_node_id_set: set = set()
+ # Choose node in the order of descending weight if it is not blocked yet by the chosen nodes.
+ for node, _ in sorted_nodes:
+ if node in blocked_node_id_set:
+ continue
+ chosen_node_id_set.add(node)
+ for neighbor_node in graph[node]:
+ blocked_node_id_set.add(neighbor_node)
+
+ chosen_node_ids = list(chosen_node_id_set)
+ return chosen_node_ids
+
+def uniform_mis_solver(graph: Dict[int, List[int]]) -> List[int]:
+ node2weight: Dict[int, float] = {node: 1 for node in graph.keys()}
+ chosen_node_list = _choose_by_weight(graph, node2weight)
+ return chosen_node_list
+
+def greedy_mis_solver(graph: Dict[int, List[int]]) -> List[int]:
+ node2weight: Dict[int, float] = {node: 1 / (1 + len(neighbor_list)) for node, neighbor_list in graph.items()}
+ chosen_node_list = _choose_by_weight(graph, node2weight)
+ return chosen_node_list
diff --git a/examples/mis/lwd/env_sampler/mis_env_sampler.py b/examples/mis/lwd/env_sampler/mis_env_sampler.py
new file mode 100644
index 000000000..c0aaabe60
--- /dev/null
+++ b/examples/mis/lwd/env_sampler/mis_env_sampler.py
@@ -0,0 +1,343 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+
+from maro.rl.policy.abs_policy import AbsPolicy
+from maro.rl.rollout import SimpleAgentWrapper, AbsEnvSampler, CacheElement
+from maro.rl.utils import ndarray_to_tensor
+from maro.rl.workflows.callback import Callback
+from maro.simulator.core import Env
+
+from examples.mis.lwd.env_sampler.baseline import greedy_mis_solver, uniform_mis_solver
+from examples.mis.lwd.ppo.ppo import GraphBasedPPOPolicy
+from examples.mis.lwd.ppo.replay_memory import GraphBasedExpElement
+from examples.mis.lwd.simulator import Action, MISDecisionPayload, MISEnvMetrics, MISBusinessEngine
+
+
+class MISAgentWrapper(SimpleAgentWrapper):
+ def __init__(self, policy_dict: Dict[str, AbsPolicy], agent2policy: Dict[Any, str]) -> None:
+ super().__init__(policy_dict, agent2policy)
+
+ def _choose_actions_impl(self, state_by_agent: Dict[Any, torch.Tensor]) -> Dict[Any, np.ndarray]:
+ assert len(state_by_agent) == 1
+ for agent_name, state in state_by_agent.items():
+ break
+
+ policy_name = self._agent2policy[agent_name]
+ policy = self._policy_dict[policy_name]
+ assert isinstance(policy, GraphBasedPPOPolicy)
+
+ assert isinstance(state, tuple)
+ assert len(state) == 2
+ action = policy.get_actions(state[0], graph=state[1])
+ return {agent_name: action}
+
+
+class MISEnvSampler(AbsEnvSampler):
+ def __init__(
+ self,
+ learn_env: Env,
+ test_env: Env,
+ policies: List[AbsPolicy],
+ agent2policy: Dict[Any, str],
+ trainable_policies: List[str] = None,
+ reward_eval_delay: int = None,
+ max_episode_length: int = None,
+ diversity_reward_coef: float = 0.1,
+ reward_normalization_base: float = None
+ ) -> None:
+ super(MISEnvSampler, self).__init__(
+ learn_env=learn_env,
+ test_env=test_env,
+ policies=policies,
+ agent2policy=agent2policy,
+ trainable_policies=trainable_policies,
+ agent_wrapper_cls=MISAgentWrapper,
+ reward_eval_delay=reward_eval_delay,
+ max_episode_length=max_episode_length,
+ )
+ be = learn_env.business_engine
+ assert isinstance(be, MISBusinessEngine)
+ self._device = be._device
+ self._diversity_reward_coef = diversity_reward_coef
+ self._reward_normalization_base = reward_normalization_base
+
+ self._sample_metrics: List[tuple] = []
+ self._eval_metrics: List[tuple] = []
+
+ def _get_global_and_agent_state(
+ self,
+ event: Any,
+ tick: int = None,
+ ) -> Tuple[Optional[Any], Dict[Any, Union[np.ndarray, list]]]:
+ return self._get_global_and_agent_state_impl(event, tick)
+
+ def _get_global_and_agent_state_impl(
+ self,
+ event: MISDecisionPayload,
+ tick: int = None,
+ ) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]:
+ vertex_states = event.vertex_states.unsqueeze(2).float().cpu()
+ normalized_tick = torch.full(vertex_states.size(), tick / self.env._business_engine._max_tick)
+ state = torch.cat([vertex_states, normalized_tick], dim=2).cpu().detach().numpy()
+ return None, {0: (state, event.graph)}
+
+ def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict:
+ return {k: Action(vertex_states=ndarray_to_tensor(v, self._device)) for k, v in action_dict.items()}
+
+ def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]:
+ be = self.env.business_engine
+ assert isinstance(be, MISBusinessEngine)
+
+ cardinality_record_dict = self.env.metrics[MISEnvMetrics.IncludedNodeCount]
+ hamming_distance_record_dict = self.env.metrics[MISEnvMetrics.HammingDistanceAmongSamples]
+
+ assert (tick - 1) in cardinality_record_dict
+ cardinality_reward = cardinality_record_dict[tick - 1] - cardinality_record_dict.get(tick - 2, 0)
+
+ assert (tick - 1) in hamming_distance_record_dict
+ diversity_reward = hamming_distance_record_dict[tick - 1]
+
+ reward = cardinality_reward + self._diversity_reward_coef * diversity_reward
+ reward /= self._reward_normalization_base
+
+ return {0: reward.cpu().detach().numpy()}
+
+ def _eval_baseline(self) -> Dict[str, float]:
+ assert isinstance(self.env.business_engine, MISBusinessEngine)
+ graph_list: List[Dict[int, List[int]]] = self.env.business_engine._batch_adjs
+ num_samples: int = self.env.business_engine._num_samples
+
+ def best_among_samples(
+ solver: Callable[[Dict[int, List[int]]], List[int]],
+ num_samples: int,
+ graph: Dict[int, List[int]],
+ ) -> int:
+ res = 0
+ for _ in range(num_samples):
+ res = max(res, len(solver(graph)))
+ return res
+
+ graph_size_list = [len(graph) for graph in graph_list]
+ uniform_size_list = [best_among_samples(uniform_mis_solver, num_samples, graph) for graph in graph_list]
+ greedy_size_list = [best_among_samples(greedy_mis_solver, num_samples, graph) for graph in graph_list]
+
+ return {
+ "graph_size": np.mean(graph_size_list),
+ "uniform_size": np.mean(uniform_size_list),
+ "greedy_size": np.mean(greedy_size_list),
+ }
+
+ def sample(self, policy_state: Optional[Dict[str, Dict[str, Any]]] = None, num_steps: Optional[int] = None) -> dict:
+ if policy_state is not None: # Update policy state if necessary
+ self.set_policy_state(policy_state)
+ self._switch_env(self._learn_env) # Init the env
+ self._agent_wrapper.explore() # Collect experience
+
+ # One complete episode in one sample call here.
+ self._reset()
+ experiences: List[GraphBasedExpElement] = []
+
+ while not self._end_of_episode:
+ state = self._agent_state_dict[0][0]
+ graph = self._agent_state_dict[0][1]
+ # Get agent actions and translate them to env actions
+ action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
+ env_action_dict = self._translate_to_env_action(action_dict, self._event)
+ # Update env and get new states (global & agent)
+ self._step(list(env_action_dict.values()))
+ assert self._reward_eval_delay is None
+ reward_dict = self._get_reward(env_action_dict, self._event, self.env.tick)
+
+ experiences.append(
+ GraphBasedExpElement(
+ state=state,
+ action=action_dict[0],
+ reward=reward_dict[0],
+ is_done=self.env.metrics[MISEnvMetrics.IsDoneMasks],
+ graph=graph,
+ )
+ )
+
+ self._total_number_interactions += 1
+ self._current_episode_length += 1
+ self._post_step(None)
+
+ return {
+ "experiences": [experiences],
+ "info": [],
+ }
+
+ def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict:
+ self._switch_env(self._test_env)
+ info_list = []
+
+ for _ in range(num_episodes):
+ self._reset()
+
+ baseline_info = self._eval_baseline()
+ info_list.append(baseline_info)
+
+ if policy_state is not None:
+ self.set_policy_state(policy_state)
+
+ self._agent_wrapper.exploit()
+ while not self._end_of_episode:
+ action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
+ env_action_dict = self._translate_to_env_action(action_dict, self._event)
+ # Update env and get new states (global & agent)
+ self._step(list(env_action_dict.values()))
+
+ self._post_eval_step(None)
+
+ return {"info": info_list}
+
+ def _post_step(self, cache_element: CacheElement) -> None:
+ if not (self._end_of_episode or self.truncated):
+ return
+
+ node_count_record = self.env.metrics[MISEnvMetrics.IncludedNodeCount]
+ num_steps = max(node_count_record.keys()) + 1
+ # Report the mean among samples as the average MIS size in rollout.
+ num_nodes = torch.mean(node_count_record[num_steps - 1]).item()
+
+ self._sample_metrics.append((num_steps, num_nodes))
+
+ def _post_eval_step(self, cache_element: CacheElement) -> None:
+ if not (self._end_of_episode or self.truncated):
+ return
+
+ node_count_record = self.env.metrics[MISEnvMetrics.IncludedNodeCount]
+ num_steps = max(node_count_record.keys()) + 1
+ # Report the maximum among samples as the MIS size in evaluation.
+ num_nodes = torch.mean(torch.max(node_count_record[num_steps - 1], dim=1).values).item()
+
+ self._eval_metrics.append((num_steps, num_nodes))
+
+ def post_collect(self, info_list: list, ep: int) -> None:
+ assert len(self._sample_metrics) == 1, "Exactly one episode per rollout/collection in the current workflow design."
+ num_steps, mis_size = self._sample_metrics[0]
+
+ cur = {
+ "n_steps": num_steps,
+ "avg_mis_size": mis_size,
+ "n_interactions": self._total_number_interactions,
+ }
+ self.metrics.update(cur)
+
+ # clear validation metrics
+ self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")}
+ self._sample_metrics.clear()
+
+ def post_evaluate(self, info_list: list, ep: int) -> None:
+ num_eval = len(self._eval_metrics)
+ assert num_eval > 0, "Number of evaluation rounds must be positive!"
+
+ cur = {
+ "val/num_eval": num_eval,
+ "val/avg_n_steps": np.mean([n for n, _ in self._eval_metrics]),
+ "val/avg_mis_size": np.mean([s for _, s in self._eval_metrics]),
+ "val/std_mis_size": np.std([s for _, s in self._eval_metrics]),
+ "val/avg_graph_size": np.mean([info["graph_size"] for info in info_list]),
+ "val/uniform_size": np.mean([info["uniform_size"] for info in info_list]),
+ "val/uniform_std": np.std([info["uniform_size"] for info in info_list]),
+ "val/greedy_size": np.mean([info["greedy_size"] for info in info_list]),
+ "val/greedy_std": np.std([info["greedy_size"] for info in info_list]),
+ }
+ self.metrics.update(cur)
+ self._eval_metrics.clear()
+
+
+class MISPlottingCallback(Callback):
+ def __init__(self, log_dir: str) -> None:
+ super().__init__()
+ self._log_dir = log_dir
+
+ def _plot_trainer(self) -> None:
+ for trainer_name in self.training_manager._trainer_dict.keys():
+ df = pd.read_csv(os.path.join(self._log_dir, f"{trainer_name}.csv"))
+ columns = [
+ "Steps",
+ "Mean Reward",
+ "Mean Return",
+ "Mean Advantage",
+ "0-Action",
+ "1-Action",
+ "2-Action",
+ "Critic Loss",
+ "Actor Loss",
+ ]
+
+ _, axes = plt.subplots(len(columns), sharex=False, figsize=(20, 21))
+
+ for col, ax in zip(columns, axes):
+ data = df[col].dropna().to_numpy()[:-1]
+ ax.plot(data, label=col)
+ ax.legend()
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(self._log_dir, f"{trainer_name}.png"))
+ plt.close()
+ plt.cla()
+ plt.clf()
+
+ @staticmethod
+ def _plot_mean_std(
+ ax: plt.Axes, x: np.ndarray, y_mean: np.ndarray, y_std: np.ndarray, color: str, label: str,
+ ) -> None:
+ ax.plot(x, y_mean, label=label, color=color)
+ ax.fill_between(x, y_mean - y_std, y_mean + y_std, color=color, alpha=0.2)
+
+ def _plot_metrics(self) -> None:
+ df_rollout = pd.read_csv(os.path.join(self._log_dir, "metrics_full.csv"))
+ df_eval = pd.read_csv(os.path.join(self._log_dir, "metrics_valid.csv"))
+
+ n_steps = df_rollout["n_steps"].to_numpy()
+ eval_n_steps = df_eval["val/avg_n_steps"].to_numpy()
+
+ eval_ep = df_eval["ep"].to_numpy()
+ graph_size = df_eval["val/avg_graph_size"].to_numpy()
+
+ mis_size = df_rollout["avg_mis_size"].to_numpy()
+ eval_mis_size, eval_size_std = df_eval["val/avg_mis_size"].to_numpy(), df_eval["val/std_mis_size"].to_numpy()
+ uniform_size, uniform_std = df_eval["val/uniform_size"].to_numpy(), df_eval["val/uniform_std"].to_numpy()
+ greedy_size, greedy_std = df_eval["val/greedy_size"].to_numpy(), df_eval["val/greedy_std"].to_numpy()
+
+ fig, (ax_t, ax_gsize, ax_mis) = plt.subplots(3, sharex=False, figsize=(20, 15))
+
+ color_map = {
+ "rollout": "cornflowerblue",
+ "eval": "orange",
+ "uniform": "green",
+ "greedy": "firebrick",
+ }
+
+ ax_t.plot(n_steps, label="n_steps", color=color_map["rollout"])
+ ax_t.plot(eval_ep, eval_n_steps, label="val/n_steps", color=color_map["eval"])
+
+ ax_gsize.plot(eval_ep, graph_size, label="val/avg_graph_size", color=color_map["eval"])
+
+ ax_mis.plot(mis_size, label="avg_mis_size", color=color_map["rollout"])
+ self._plot_mean_std(ax_mis, eval_ep, eval_mis_size, eval_size_std, color_map["eval"], "val/mis_size")
+ self._plot_mean_std(ax_mis, eval_ep, uniform_size, uniform_std, color_map["uniform"], "val/uniform_size")
+ self._plot_mean_std(ax_mis, eval_ep, greedy_size, greedy_std, color_map["greedy"], "val/greedy_size")
+
+ for ax in fig.get_axes():
+ ax.legend()
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(self._log_dir, "metrics.png"))
+ plt.close()
+ plt.cla()
+ plt.clf()
+
+ def on_validation_end(self, ep: int) -> None:
+ self._plot_trainer()
+ self._plot_metrics()
diff --git a/examples/mis/lwd/figures/fig1_mdp.png b/examples/mis/lwd/figures/fig1_mdp.png
new file mode 100644
index 000000000..2e6a3e5c2
Binary files /dev/null and b/examples/mis/lwd/figures/fig1_mdp.png differ
diff --git a/examples/mis/lwd/figures/fig2_transition.png b/examples/mis/lwd/figures/fig2_transition.png
new file mode 100644
index 000000000..6e1963536
Binary files /dev/null and b/examples/mis/lwd/figures/fig2_transition.png differ
diff --git a/examples/mis/lwd/figures/fig3_diversity_reward.png b/examples/mis/lwd/figures/fig3_diversity_reward.png
new file mode 100644
index 000000000..ac071b9c1
Binary files /dev/null and b/examples/mis/lwd/figures/fig3_diversity_reward.png differ
diff --git a/examples/mis/lwd/ppo/__init__.py b/examples/mis/lwd/ppo/__init__.py
new file mode 100644
index 000000000..607979975
--- /dev/null
+++ b/examples/mis/lwd/ppo/__init__.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from maro.rl.training.algorithms.ppo import PPOParams
+from maro.rl.utils.common import get_env
+
+from examples.mis.lwd.ppo.model import GraphBasedPolicyNet, GraphBasedVNet
+from examples.mis.lwd.ppo.ppo import GraphBasedPPOPolicy, GraphBasedPPOTrainer
+
+
+def get_ppo_policy(
+ name: str,
+ state_dim: int,
+ action_num: int,
+ hidden_dim: int,
+ num_layers: int,
+ init_lr: float,
+) -> GraphBasedPPOPolicy:
+ return GraphBasedPPOPolicy(
+ name=name,
+ policy_net=GraphBasedPolicyNet(
+ state_dim=state_dim,
+ action_num=action_num,
+ hidden_dim=hidden_dim,
+ num_layers=num_layers,
+ init_lr=init_lr,
+ ),
+ )
+
+
+def get_ppo_trainer(
+ name: str,
+ state_dim: int,
+ hidden_dim: int,
+ num_layers: int,
+ init_lr: float,
+ clip_ratio: float,
+ max_tick: int,
+ batch_size: int,
+ reward_discount: float,
+ graph_batch_size: int,
+ graph_num_samples: int,
+ num_train_epochs: int,
+ norm_base: float,
+) -> GraphBasedPPOTrainer:
+ return GraphBasedPPOTrainer(
+ name=name,
+ params=PPOParams(
+ get_v_critic_net_func=lambda: GraphBasedVNet(state_dim, hidden_dim, num_layers, init_lr, norm_base),
+ grad_iters=1,
+ lam=None, # GAE not used here.
+ clip_ratio=clip_ratio,
+ ),
+ replay_memory_capacity=max_tick,
+ batch_size=batch_size,
+ reward_discount=reward_discount,
+ graph_batch_size=graph_batch_size,
+ graph_num_samples=graph_num_samples,
+ input_feature_size=state_dim,
+ num_train_epochs=num_train_epochs,
+ log_dir=get_env("LOG_PATH", required=False, default=None),
+ )
+
+
+__all__ = ["get_ppo_policy", "get_ppo_trainer"]
diff --git a/examples/mis/lwd/ppo/graph_net.py b/examples/mis/lwd/ppo/graph_net.py
new file mode 100644
index 000000000..e4c34d1a4
--- /dev/null
+++ b/examples/mis/lwd/ppo/graph_net.py
@@ -0,0 +1,124 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import dgl
+import dgl.function as fn
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GraphConvLayer(nn.Module):
+ def __init__(
+ self, in_feats: int, out_feats: int, norm: bool = True, jump: bool = True, bias: bool = True, activation=None,
+ ) -> None:
+ """The GraphConvLayer that can forward based on the given graph and graph node features.
+
+ Args:
+ in_feats (int): The dimension of input feature.
+ out_feats (int): The dimension of the output tensor.
+ norm (bool): Add feature normalization operation or not. Defaults to True.
+ jump (bool): Add skip connections of the input feature to the aggregation or not. Defaults to True.
+ bias (bool): Add a learnable bias layer or not. Defaults to True.
+ activation (torch.nn.functional): The output activation function to use. Defaults to None.
+ """
+ super(GraphConvLayer, self).__init__()
+ self._in_feats = in_feats
+ self._out_feats = out_feats
+ self._norm = norm
+ self._jump = jump
+ self._activation = activation
+
+ if jump:
+ self.weight = nn.Parameter(torch.Tensor(2 * in_feats, out_feats))
+ else:
+ self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
+
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_feats))
+ else:
+ self.register_parameter("bias", None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ torch.nn.init.xavier_uniform_(self.weight)
+ if self.bias is not None:
+ torch.nn.init.zeros_(self.bias)
+
+ def forward(self, feat: torch.Tensor, graph: dgl.DGLGraph, mask=None) -> torch.Tensor:
+ if self._jump:
+ _feat = feat
+
+ if self._norm:
+ if mask is None:
+ norm = torch.pow(graph.in_degrees().float(), -0.5)
+ norm.masked_fill_(graph.in_degrees() == 0, 1.0)
+ shp = norm.shape + (1,) * (feat.dim() - 1)
+ norm = torch.reshape(norm, shp).to(feat.device)
+ feat = feat * norm.unsqueeze(1)
+ else:
+ graph.ndata["h"] = mask.float()
+ graph.update_all(fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h"))
+ masked_deg = graph.ndata.pop("h")
+ norm = torch.pow(masked_deg, -0.5)
+ norm.masked_fill_(masked_deg == 0, 1.0)
+ feat = feat * norm.unsqueeze(-1)
+
+ if mask is not None:
+ feat = mask.float().unsqueeze(-1) * feat
+
+ graph.ndata["h"] = feat
+ graph.update_all(fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h"))
+ rst = graph.ndata.pop("h")
+
+ if self._norm:
+ rst = rst * norm.unsqueeze(-1)
+
+ if self._jump:
+ rst = torch.cat([rst, _feat], dim=-1)
+
+ rst = torch.matmul(rst, self.weight)
+
+ if self.bias is not None:
+ rst = rst + self.bias
+
+ if self._activation is not None:
+ rst = self._activation(rst)
+
+ return rst
+
+
+class GraphConvNet(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ hidden_dim: int,
+ num_layers: int,
+ activation=F.relu,
+ out_activation=None,
+ ) -> None:
+ """The GraphConvNet constructed with multiple GraphConvLayers.
+
+ Args:
+ input_dim (int): The dimension of the input feature.
+ output_dim (int): The dimension of the output tensor.
+ hidden_dim (int): The dimension of the hidden layers.
+ num_layers (int): The total number of layers in this GraphConvNet, including the input and output layers. Must be >= 2.
+ activation (torch.nn.functional): The activation function used in input layer and hidden layers. Defaults to
+ torch.nn.functional.relu.
+ out_activation (torch.nn.functional): The output activation function to use. Defaults to None.
+ """
+ super(GraphConvNet, self).__init__()
+
+ self.layers = nn.ModuleList(
+ [GraphConvLayer(input_dim, hidden_dim, activation=activation)]
+ + [GraphConvLayer(hidden_dim, hidden_dim, activation=activation) for _ in range(num_layers - 2)]
+ + [GraphConvLayer(hidden_dim, output_dim, activation=out_activation)]
+ )
+
+ def forward(self, h: torch.Tensor, graph: dgl.DGLGraph, mask=None) -> torch.Tensor:
+ for layer in self.layers:
+ h = layer(h, graph, mask=mask)
+ return h
diff --git a/examples/mis/lwd/ppo/model.py b/examples/mis/lwd/ppo/model.py
new file mode 100644
index 000000000..f6fab8b99
--- /dev/null
+++ b/examples/mis/lwd/ppo/model.py
@@ -0,0 +1,209 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Tuple
+
+import dgl
+import torch
+from torch.distributions import Categorical
+from torch.optim import Adam
+
+from maro.rl.model import VNet, DiscretePolicyNet
+
+from examples.mis.lwd.ppo.graph_net import GraphConvNet
+from examples.mis.lwd.simulator import VertexState
+
+
+VertexStateIndex = 0
+
+
+def get_masks_idxs_subgraph_h(obs: torch.Tensor, graph: dgl.DGLGraph) -> Tuple[
+ torch.Tensor, torch.Tensor, torch.Tensor, dgl.DGLGraph, torch.Tensor, int, int,
+]:
+ """Extract masks, indices, the induced subgraph, and input features for the deferred vertices.
+
+ Args:
+ obs (torch.Tensor): The observation tensor with shape (num_nodes, num_samples, feature_size).
+ graph (dgl.DGLGraph): The input graph.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ - undecided_node_mask, with shape (num_nodes, num_samples)
+ - subgraph_mask, with shape (num_nodes)
+ - subgraph_node_mask, with shape (num_nodes, num_samples)
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ - flatten_node_idxs, with shape (num_nodes * num_samples)
+ - flatten_subgraph_idxs, with shape (num_nodes)
+ - flatten_subgraph_node_idxs, with shape (num_nodes * num_samples)
+ dgl.DGLGraph:
+ torch.Tensor: input tensor with shape (num_nodes, num_samples, 2)
+ """
+ # Mask tensor with shape (num_nodes, num_samples)
+ undecided_node_mask = obs.select(2, VertexStateIndex).long() == VertexState.Deferred
+ # Flatten index tensor with shape (num_nodes * num_samples)
+ flatten_node_idxs = undecided_node_mask.reshape(-1).nonzero().squeeze(1)
+
+ # Mask tensor with shape (num_nodes)
+ subgraph_mask = undecided_node_mask.any(dim=1)
+ # Flatten index tensor with shape (num_nodes)
+ flatten_subgraph_idxs = subgraph_mask.nonzero().squeeze(1)
+
+ # Mask tensor with shape (num_nodes, num_samples)
+ subgraph_node_mask = undecided_node_mask.index_select(0, flatten_subgraph_idxs)
+ # Flatten index tensor with shape (num_nodes * num_samples)
+ flatten_subgraph_node_idxs = subgraph_node_mask.view(-1).nonzero().squeeze(1)
+
+ # Extract a subgraph with only node in flatten_subgraph_idxs, batch_size -> 1
+ subgraph = graph.subgraph(flatten_subgraph_idxs)
+
+ # The observation of the deferred vertexes.
+ h = obs.index_select(0, flatten_subgraph_idxs)
+
+ num_nodes, num_samples = obs.size(0), obs.size(1)
+ return subgraph_node_mask, flatten_node_idxs, flatten_subgraph_node_idxs, subgraph, h, num_nodes, num_samples
+
+
+class GraphBasedPolicyNet(DiscretePolicyNet):
+ def __init__(
+ self,
+ state_dim: int,
+ action_num: int,
+ hidden_dim: int,
+ num_layers: int,
+ init_lr: float,
+ ) -> None:
+ """A discrete policy net implemented with a graph as input.
+
+ Args:
+ state_dim (int): The dimension of the input state for this policy net.
+ action_num (int): The number of pre-defined discrete actions, i.e., the size of the discrete action space.
+ hidden_dim (int): The dimension of the hidden layers used in the GraphConvNet of the actor.
+ num_layers (int): The number of layers of the GraphConvNet of the actor.
+ init_lr (float): The initial learning rate of the optimizer.
+ """
+ super(GraphBasedPolicyNet, self).__init__(state_dim, action_num)
+
+ self._actor = GraphConvNet(
+ input_dim=state_dim,
+ output_dim=action_num,
+ hidden_dim=hidden_dim,
+ num_layers=num_layers,
+ )
+ with torch.no_grad():
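+ # Bias the final layer toward action index 2 (VertexState.Deferred) so that the freshly initialized policy tends to defer most vertices.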
+ self._actor.layers[-1].bias[2].add_(3.0)
+
+ self._optim = Adam(self._actor.parameters(), lr=init_lr)
+
+ def _get_action_probs_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
+ raise NotImplementedError
+
+ def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
+ action, _ = self._get_actions_with_logps_impl(states, exploring, **kwargs)
+ return action
+
+ def _get_actions_with_probs_impl(
+ self, states: torch.Tensor, exploring: bool, **kwargs
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError
+
+ def _get_actions_with_logps_impl(
+ self, states: torch.Tensor, exploring: bool, **kwargs
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert "graph" in kwargs, f"graph is required to given in kwargs"
+ graph = kwargs["graph"]
+ subg_mask, node_idxs, subg_node_idxs, subg, h, num_nodes, num_samples = get_masks_idxs_subgraph_h(states, graph)
+
+ # Compute logits to get action, logits: shape (num_nodes * num_samples, 3)
+ logits = self._actor(h, subg, mask=subg_mask).view(-1, self.action_num).index_select(0, subg_node_idxs)
+
+ action = torch.zeros(num_nodes * num_samples, dtype=torch.long, device=self._device)
+ action_log_probs = torch.zeros(num_nodes * num_samples, device=self._device)
+
+ # NOTE: here we do not distinguish between exploration mode and exploitation mode.
+ # The main reason is that the LwD model is trained to explore well:
+ # the final result is chosen from the sampled trajectories.
+ m = Categorical(logits=logits)
+ action[node_idxs] = m.sample()
+ action_log_probs[node_idxs] = m.log_prob(action.index_select(0, node_idxs))
+
+ action = action.view(-1, num_samples)
+ action_log_probs = action_log_probs.view(-1, num_samples)
+
+ return action, action_log_probs
+
+ def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
+ raise NotImplementedError
+
+ def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
+ assert "graph" in kwargs, f"graph is required to given in kwargs"
+ graph = kwargs["graph"]
+ subg_mask, node_idxs, subg_node_idxs, subg, h, num_nodes, num_samples = get_masks_idxs_subgraph_h(states, graph)
+
+ # compute logits to get action
+ logits = self._actor(h, subg, mask=subg_mask).view(-1, self.action_num).index_select(0, subg_node_idxs)
+
+ try:
+ m = Categorical(logits=logits)
+ except Exception:
+ print(f"[GraphBasedPolicyNet] flatten_subgraph_node_idxs with shape {subg_node_idxs.shape}")
+ print(f"[GraphBasedPolicyNet] logits with shape {logits.shape}")
+ return None
+
+ # compute log probability of actions per node
+ actions = actions.reshape(-1)
+ action_log_probs = torch.zeros(num_nodes * num_samples, device=self._device)
+ action_log_probs[node_idxs] = m.log_prob(actions.index_select(0, node_idxs))
+ action_log_probs = action_log_probs.view(-1, num_samples)
+
+ return action_log_probs
+
+ def get_actions(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor:
+ actions = self._get_actions_impl(states, exploring, **kwargs)
+ return actions
+
+ def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
+ logps = self._get_states_actions_logps_impl(states, actions, **kwargs)
+ return logps
+
+
+class GraphBasedVNet(VNet):
+ def __init__(self, state_dim: int, hidden_dim: int, num_layers: int, init_lr: float, norm_base: float) -> None:
+ """A value net implemented with a graph as input.
+
+ Args:
+ state_dim (int): The dimension of the input state for this value network.
+ hidden_dim (int): The dimension of the hidden layers used in the GraphConvNet of the critic.
+ num_layers (int): The number of layers of the GraphConvNet of the critic.
+ init_lr (float): The initial learning rate of the optimizer.
+ norm_base (float): The normalization base for the predicted value. The critic network will predict the value
+ of each node, and the returned v value is defined as `Sum(predicted_node_values) / normalization_base`.
+ """
+ super(GraphBasedVNet, self).__init__(state_dim)
+ self._critic = GraphConvNet(
+ input_dim=state_dim,
+ output_dim=1,
+ hidden_dim=hidden_dim,
+ num_layers=num_layers,
+ )
+ self._optim = Adam(self._critic.parameters(), lr=init_lr)
+ self._normalization_base = norm_base
+
+ def _get_v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
+ assert "graph" in kwargs, f"graph is required to given in kwargs"
+ graph = kwargs["graph"]
+ subg_mask, node_idxs, subg_node_idxs, subg, h, num_nodes, num_samples = get_masks_idxs_subgraph_h(states, graph)
+
+ values = self._critic(h, subg, mask=subg_mask).view(-1).index_select(0, subg_node_idxs)
+ # Init node value prediction, shape (num_nodes * num_samples)
+ node_value_preds = torch.zeros(num_nodes * num_samples, device=self._device)
+ node_value_preds[node_idxs] = values
+
+ graph.ndata["h"] = node_value_preds.view(-1, num_samples)
+ value_pred = dgl.sum_nodes(graph, "h") / self._normalization_base
+ graph.ndata.pop("h")
+
+ return value_pred
+
+ def v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
+ v = self._get_v_values(states, **kwargs)
+ return v
diff --git a/examples/mis/lwd/ppo/ppo.py b/examples/mis/lwd/ppo/ppo.py
new file mode 100644
index 000000000..673fcf79a
--- /dev/null
+++ b/examples/mis/lwd/ppo/ppo.py
@@ -0,0 +1,186 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import math
+from typing import List, Tuple, cast
+
+import dgl
+import torch
+
+from maro.rl.model import DiscretePolicyNet
+from maro.rl.policy import DiscretePolicyGradient, RLPolicy
+from maro.rl.training.algorithms.base import ACBasedOps, ACBasedParams, ACBasedTrainer
+from maro.rl.utils import ndarray_to_tensor
+from maro.utils import LogFormat, Logger
+
+from examples.mis.lwd.ppo.replay_memory import BatchGraphReplayMemory, GraphBasedExpElement, GraphBasedTransitionBatch
+
+
+class GraphBasedPPOTrainOps(ACBasedOps):
+ def __init__(
+ self,
+ name: str,
+ policy: RLPolicy,
+ params: ACBasedParams,
+ reward_discount: float,
+ parallelism: int = 1,
+ ) -> None:
+ super().__init__(name, policy, params, reward_discount, parallelism)
+
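+ # Clip in log space: per-node log-prob ratios are clamped to [log(1 - clip_ratio), log(1 + clip_ratio)]
+ # before being summed over each graph and exponentiated in _get_actor_loss.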
+ self._clip_lower_bound = math.log(1.0 - self._clip_ratio)
+ self._clip_upper_bound = math.log(1.0 + self._clip_ratio)
+
+ def _get_critic_loss(self, batch: GraphBasedTransitionBatch) -> torch.Tensor:
+ states = ndarray_to_tensor(batch.states, device=self._device).permute(1, 0, 2)
+ returns = ndarray_to_tensor(batch.returns, device=self._device)
+
+ self._v_critic_net.train()
+ kwargs = {"graph": batch.graph}
+ value_preds = self._v_critic_net.v_values(states, **kwargs).permute(1, 0)
+ critic_loss = 0.5 * (value_preds - returns).pow(2).mean()
+ return critic_loss
+
+ def _get_actor_loss(self, batch: GraphBasedTransitionBatch) -> Tuple[torch.Tensor, bool]:
+ graph = batch.graph
+ kwargs = {"graph": batch.graph}
+
+ states = ndarray_to_tensor(batch.states, device=self._device).permute(1, 0, 2)
+ actions = ndarray_to_tensor(batch.actions, device=self._device).permute(1, 0)
+ advantages = ndarray_to_tensor(batch.advantages, device=self._device)
+ logps_old = ndarray_to_tensor(batch.old_logps, device=self._device)
+ if self._is_discrete_action:
+ actions = actions.long()
+
+ self._policy.train()
+ logps = self._policy.get_states_actions_logps(states, actions, **kwargs).permute(1, 0)
+ diff = logps - logps_old
+ clamped_diff = torch.clamp(diff, self._clip_lower_bound, self._clip_upper_bound)
+ stacked_diff = torch.stack([diff, clamped_diff], dim=2)
+
+ graph.ndata["h"] = stacked_diff.permute(1, 0, 2)
+ h = dgl.sum_nodes(graph, "h").permute(1, 0, 2)
+ graph.ndata.pop("h")
+ ratio = torch.exp(h.select(2, 0))
+ clamped_ratio = torch.exp(h.select(2, 1))
+
+ actor_loss = -torch.min(ratio * advantages, clamped_ratio * advantages).mean()
+ return actor_loss, False # TODO: add early-stop logic if needed
+
+
+class GraphBasedPPOTrainer(ACBasedTrainer):
+ def __init__(
+ self,
+ name: str,
+ params: ACBasedParams,
+ replay_memory_capacity: int = 32,
+ batch_size: int = 16,
+ data_parallelism: int = 1,
+ reward_discount: float = 0.9,
+ graph_batch_size: int = 4,
+ graph_num_samples: int = 2,
+ input_feature_size: int = 2,
+ num_train_epochs: int = 4,
+ log_dir: str = None,
+ ) -> None:
+ super().__init__(name, params, replay_memory_capacity, batch_size, data_parallelism, reward_discount)
+
+ self._graph_batch_size = graph_batch_size
+ self._graph_num_samples = graph_num_samples
+ self._input_feature_size = input_feature_size
+ self._num_train_epochs = num_train_epochs
+
+ self._trainer_logger = None
+ if log_dir is not None:
+ self._trainer_logger = Logger(
+ tag=self.name,
+ format_=LogFormat.none,
+ dump_folder=log_dir,
+ dump_mode="w",
+ extension_name="csv",
+ auto_timestamp=False,
+ stdout_level="INFO",
+ )
+ self._trainer_logger.debug(
+ "Steps,Mean Reward,Mean Return,Mean Advantage,0-Action,1-Action,2-Action,Critic Loss,Actor Loss"
+ )
+
+ def build(self) -> None:
+ self._ops = cast(GraphBasedPPOTrainOps, self.get_ops())
+ self._replay_memory = BatchGraphReplayMemory(
+ max_t=self._replay_memory_capacity,
+ graph_batch_size=self._graph_batch_size,
+ num_samples=self._graph_num_samples,
+ feature_size=self._input_feature_size,
+ )
+
+ def record_multiple(self, env_idx: int, exp_elements: List[GraphBasedExpElement]) -> None:
+ self._replay_memory.reset()
+
+ self._ops._v_critic_net.eval()
+ self._ops._policy.eval()
+
+ for exp in exp_elements:
+ state = ndarray_to_tensor(exp.state, self._ops._device)
+ action = ndarray_to_tensor(exp.action, self._ops._device)
+ value_pred = self._ops._v_critic_net.v_values(state, graph=exp.graph).cpu().detach().numpy()
+ logps = self._ops._policy.get_states_actions_logps(state, action, graph=exp.graph).cpu().detach().numpy()
+ self._replay_memory.add_transition(exp, value_pred, logps)
+
+ self._ops._v_critic_net.train()
+ self._ops._policy.train()
+
+ def get_local_ops(self) -> GraphBasedPPOTrainOps:
+ return GraphBasedPPOTrainOps(
+ name=self._policy.name,
+ policy=self._policy,
+ parallelism=self._data_parallelism,
+ reward_discount=self._reward_discount,
+ params=self._params,
+ )
+
+ def train_step(self) -> None:
+ assert isinstance(self._ops, GraphBasedPPOTrainOps)
+
+ data_loader = self._replay_memory.build_update_sampler(
+ batch_size=self._batch_size,
+ num_train_epochs=self._num_train_epochs,
+ gamma=self._reward_discount,
+ )
+
+ statistics = self._replay_memory.get_statistics()
+
+ avg_critic_loss, avg_actor_loss = 0, 0
+ for batch in data_loader:
+ for _ in range(self._params.grad_iters):
+ critic_loss = self._ops.update_critic(batch)
+ actor_loss, _ = self._ops.update_actor(batch)
+ avg_critic_loss += critic_loss
+ avg_actor_loss += actor_loss
+ avg_critic_loss /= self._params.grad_iters
+ avg_actor_loss /= self._params.grad_iters
+
+ if self._trainer_logger is not None:
+ self._trainer_logger.debug(
+ f"{statistics['step_t']},"
+ f"{statistics['reward']},"
+ f"{statistics['return']},"
+ f"{statistics['advantage']},"
+ f"{statistics['action_0']},"
+ f"{statistics['action_1']},"
+ f"{statistics['action_2']},"
+ f"{avg_critic_loss},"
+ f"{avg_actor_loss}"
+ )
+
+
+class GraphBasedPPOPolicy(DiscretePolicyGradient):
+ def __init__(self, name: str, policy_net: DiscretePolicyNet, trainable: bool = True, warmup: int = 0) -> None:
+ super(GraphBasedPPOPolicy, self).__init__(name, policy_net, trainable, warmup)
+
+ def get_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor:
+ actions = self._get_actions_impl(states, **kwargs)
+ return actions
+
+ def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor:
+ logps = self._get_states_actions_logps_impl(states, actions, **kwargs)
+ return logps
diff --git a/examples/mis/lwd/ppo/replay_memory.py b/examples/mis/lwd/ppo/replay_memory.py
new file mode 100644
index 000000000..fc6f2a185
--- /dev/null
+++ b/examples/mis/lwd/ppo/replay_memory.py
@@ -0,0 +1,150 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Dict, Iterator
+
+import dgl
+import numpy as np
+from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
+
+
+@dataclass
+class GraphBasedExpElement:
+ state: np.ndarray
+ action: np.ndarray
+ reward: np.ndarray
+ is_done: np.ndarray
+ graph: dgl.DGLGraph
+
+ def split_contents_by_trainer(self, agent2trainer: Dict[Any, str]) -> Dict[str, GraphBasedExpElement]:
+ """Split the ExpElement's contents by trainer.
+
+ Args:
+ agent2trainer (Dict[Any, str]): Mapping of agent name and trainer name.
+
+ Returns:
+ Dict[str, GraphBasedExpElement]: A dict that contains the ExpElement of each trainer. The key of this
+ dict is the trainer name.
+ """
+ return {trainer_name: self for trainer_name in agent2trainer.values()}
+
+
+@dataclass
+class GraphBasedTransitionBatch:
+ states: np.ndarray
+ actions: np.ndarray
+ returns: np.ndarray
+ graph: dgl.DGLGraph
+ advantages: np.ndarray
+ old_logps: np.ndarray
+
+
+class BatchGraphReplayMemory:
+ def __init__(
+ self,
+ max_t: int,
+ graph_batch_size: int,
+ num_samples: int,
+ feature_size: int,
+ ) -> None:
+ self.max_t = max_t
+ self.graph_batch_size: int = graph_batch_size
+ self.num_nodes: int = None
+ self.num_samples: int = num_samples
+ self.feature_size: int = feature_size
+ self._t = 0
+
+ self.graph: dgl.DGLGraph = None
+ self.states: np.ndarray = None # shape [max_t + 1, num_nodes, num_samples, feature_size]
+ self.actions: np.ndarray = None # shape [max_t, num_nodes, num_samples]
+ self.action_logps: np.ndarray = None
+
+ array_size = (self.max_t + 1, graph_batch_size, num_samples)
+ self.rewards: np.ndarray = np.zeros(array_size, dtype=np.float32)
+ self.is_done: np.ndarray = np.ones(array_size, dtype=np.int8)
+ self.returns: np.ndarray = np.zeros(array_size, dtype=np.float32)
+ self.advantages: np.ndarray = np.zeros(array_size, dtype=np.float32)
+ self.value_preds: np.ndarray = np.zeros(array_size, dtype=np.float32)
+
+ def _init_storage(self, graph: dgl.DGLGraph) -> None:
+ self.graph = graph
+ self.num_nodes = graph.num_nodes()
+
+ array_size = (self.max_t + 1, self.num_nodes, self.num_samples)
+ self.states = np.zeros((*array_size, self.feature_size), dtype=np.float32)
+ self.actions = np.zeros(array_size, dtype=np.float32)
+ self.action_logps = np.zeros(array_size, dtype=np.float32)
+
+ def add_transition(self, exp_element: GraphBasedExpElement, value_pred: np.ndarray, logps: np.ndarray) -> None:
+ if self._t == 0:
+ assert exp_element.graph is not None
+ self._init_storage(exp_element.graph)
+
+ self.states[self._t] = exp_element.state
+ self.actions[self._t] = exp_element.action
+ self.rewards[self._t] = exp_element.reward
+ self.value_preds[self._t] = value_pred
+ self.action_logps[self._t] = logps
+ self.is_done[self._t + 1] = exp_element.is_done
+
+ self._t += 1
+
+ def build_update_sampler(self, batch_size: int, num_train_epochs: int, gamma: float) -> Iterator[GraphBasedTransitionBatch]:
+ for t in reversed(range(self.max_t)):
+ self.returns[t] = self.rewards[t] + gamma * (1 - self.is_done[t + 1]) * self.returns[t + 1]
+
+ advantages = self.returns[:-1] - self.value_preds[:-1]
+ self.advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
+
+ flat_states = np.transpose(self.states[: self._t], (0, 2, 1, 3)).reshape(-1, self.num_nodes, self.feature_size)
+ flat_actions = np.transpose(self.actions[: self._t], (0, 2, 1)).reshape(-1, self.num_nodes)
+ flat_returns = np.transpose(self.returns[: self._t], (0, 2, 1)).reshape(-1, self.graph_batch_size)
+ flat_advantages = np.transpose(self.advantages[: self._t], (0, 2, 1)).reshape(-1, self.graph_batch_size)
+ flat_action_logps = np.transpose(self.action_logps[: self._t], (0, 2, 1)).reshape(-1, self.num_nodes)
+
+ flat_dim = flat_states.shape[0]
+
+ sampler = BatchSampler(
+ sampler=SubsetRandomSampler(range(flat_dim)),
+ batch_size=min(flat_dim, batch_size),
+ drop_last=False,
+ )
+
+ sampler_t = 0
+ while sampler_t < num_train_epochs:
+ for idx in sampler:
+ yield GraphBasedTransitionBatch(
+ states=flat_states[idx],
+ actions=flat_actions[idx],
+ returns=flat_returns[idx],
+ graph=self.graph,
+ advantages=flat_advantages[idx],
+ old_logps=flat_action_logps[idx],
+ )
+
+ sampler_t += 1
+ if sampler_t == num_train_epochs:
+ break
+
+ def reset(self) -> None:
+ self.is_done[0] = 0
+ self._t = 0
+
+ def get_statistics(self) -> Dict[str, float]:
+ action_0_count = np.count_nonzero(self.actions[: self._t] == 0)
+ action_1_count = np.count_nonzero(self.actions[: self._t] == 1)
+ action_2_count = np.count_nonzero(self.actions[: self._t] == 2)
+ action_count = action_0_count + action_1_count + action_2_count
+ assert action_count == self.actions[: self._t].size
+ return {
+ "step_t": self._t,
+ "action_0": action_0_count / action_count,
+ "action_1": action_1_count / action_count,
+ "action_2": action_2_count / action_count,
+ "reward": self.rewards[: self._t].mean(),
+ "return": self.returns[: self._t].mean(),
+ "advantage": self.advantages[: self._t].mean(),
+ }
diff --git a/examples/mis/lwd/simulator/__init__.py b/examples/mis/lwd/simulator/__init__.py
new file mode 100644
index 000000000..ac949bd50
--- /dev/null
+++ b/examples/mis/lwd/simulator/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from examples.mis.lwd.simulator.common import Action, MISDecisionPayload, MISEnvMetrics, VertexState
+from examples.mis.lwd.simulator.mis_business_engine import MISBusinessEngine
+
+
+__all__ = ["Action", "MISBusinessEngine", "MISDecisionPayload", "MISEnvMetrics", "VertexState"]
diff --git a/examples/mis/lwd/simulator/common.py b/examples/mis/lwd/simulator/common.py
new file mode 100644
index 000000000..175a4f63b
--- /dev/null
+++ b/examples/mis/lwd/simulator/common.py
@@ -0,0 +1,32 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from enum import Enum, IntEnum
+
+import dgl
+import torch
+
+from maro.common import BaseAction, BaseDecisionEvent
+
+
+class VertexState(IntEnum):
+ Excluded = 0
+ Included = 1
+ Deferred = 2
+
+
+class MISEnvMetrics(Enum):
+ IncludedNodeCount = "Included Node Count"
+ HammingDistanceAmongSamples = "Hamming Distance Among Two Samples"
+ IsDoneMasks = "Is Done Masks"
+
+
+class Action(BaseAction):
+ def __init__(self, vertex_states: torch.Tensor) -> None:
+ self.vertex_states = vertex_states
+
+
+class MISDecisionPayload(BaseDecisionEvent):
+ def __init__(self, graph: dgl.DGLGraph, vertex_states: torch.Tensor) -> None:
+ self.graph = graph
+ self.vertex_states = vertex_states
diff --git a/examples/mis/lwd/simulator/mis_business_engine.py b/examples/mis/lwd/simulator/mis_business_engine.py
new file mode 100644
index 000000000..0a20d1b0b
--- /dev/null
+++ b/examples/mis/lwd/simulator/mis_business_engine.py
@@ -0,0 +1,274 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import math
+import random
+from typing import Dict, List, Optional, Tuple
+
+import dgl
+import dgl.function as fn
+import torch
+
+from maro.backends.frame import FrameBase, SnapshotList
+from maro.event_buffer import CascadeEvent, EventBuffer, MaroEvents
+from maro.simulator.scenarios import AbsBusinessEngine
+
+from examples.mis.lwd.simulator.common import Action, MISDecisionPayload, MISEnvMetrics, VertexState
+
+
+class MISBusinessEngine(AbsBusinessEngine):
+ def __init__(
+ self,
+ event_buffer: EventBuffer,
+ topology: Optional[str],
+ start_tick: int,
+ max_tick: int,
+ snapshot_resolution: int,
+ max_snapshots: Optional[int],
+ additional_options: dict = None,
+ ) -> None:
+ super(MISBusinessEngine, self).__init__(
+ scenario_name="MaximumIndependentSet",
+ event_buffer=event_buffer,
+ topology=topology,
+ start_tick=start_tick,
+ max_tick=max_tick,
+ snapshot_resolution=snapshot_resolution,
+ max_snapshots=max_snapshots,
+ additional_options=additional_options,
+ )
+ self._config = self._parse_additional_options(additional_options)
+
+ # NOTE: not used now
+ self._frame: FrameBase = FrameBase()
+ self._snapshots = self._frame.snapshots
+
+ self._register_events()
+
+ self._batch_adjs: List[Dict[int, List[int]]] = []
+ self._batch_graphs: dgl.DGLGraph = None
+ self._vertex_states: torch.Tensor = None
+ self._is_done_mask: torch.Tensor = None
+
+ self._num_included_record: Dict[int, torch.Tensor] = None
+ self._hamming_distance_among_samples_record: Dict[int, torch.Tensor] = None
+
+ self.reset()
+
+ def _parse_additional_options(self, additional_options: dict) -> dict:
+ required_keys = [
+ "graph_batch_size",
+ "num_samples",
+ "device",
+ "num_node_lower_bound",
+ "num_node_upper_bound",
+ "node_sample_probability",
+ ]
+ for key in required_keys:
+ assert key in additional_options, f"Parameter {key} is required in additional_options!"
+
+ self._graph_batch_size = additional_options["graph_batch_size"]
+ self._num_samples = additional_options["num_samples"]
+ self._device = additional_options["device"]
+ self._num_node_lower_bound = additional_options["num_node_lower_bound"]
+ self._num_node_upper_bound = additional_options["num_node_upper_bound"]
+ self._node_sample_probability = additional_options["node_sample_probability"]
+
+ return {key: additional_options[key] for key in required_keys}
+
+ @property
+ def configs(self) -> dict:
+ return self._config
+
+ @property
+ def frame(self) -> FrameBase:
+ return self._frame
+
+ @property
+ def snapshots(self) -> SnapshotList:
+ return self._snapshots
+
+ def _register_events(self) -> None:
+ self._event_buffer.register_event_handler(MaroEvents.TAKE_ACTION, self._on_action_received)
+
+ def get_agent_idx_list(self) -> List[int]:
+ return [0]
+
+ def set_seed(self, seed: int) -> None:
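+ # Randomness comes from Python's `random` and torch; explicit seeding is not wired up in this example.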
+ pass
+
+ def _calculate_and_record_metrics(self, tick: int, undecided_before_mask: torch.Tensor) -> None:
+ # Calculate and record the number of included vertexes for cardinality reward.
+ included_mask = self._vertex_states == VertexState.Included
+ self._batch_graphs.ndata["h"] = included_mask.float()
+ node_count = dgl.sum_nodes(self._batch_graphs, "h")
+ self._batch_graphs.ndata.pop("h")
+ self._num_included_record[tick] = node_count
+
+ # Calculate and record the diversity for diversity reward.
+ if self._num_samples == 2:
+ undecided_before_mask_left, undecided_before_mask_right = undecided_before_mask.split(1, dim=1)
+
+ states_left, states_right = self._vertex_states.split(1, dim=1)
+ undecided_mask_left = states_left == VertexState.Deferred
+ undecided_mask_right = states_right == VertexState.Deferred
+
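+ # Per-vertex |state_left - state_right|, counted exactly once per vertex, at the tick it becomes
+ # decided in both samples: zero it where either sample is still deferred, and where both samples
+ # had already been decided before this step.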
+ hamming_distance = torch.abs(states_left.float() - states_right.float())
+ hamming_distance[undecided_mask_left | undecided_mask_right] = 0.0
+ hamming_distance[~undecided_before_mask_left & ~undecided_before_mask_right] = 0.0
+
+ self._batch_graphs.ndata["h"] = hamming_distance
+ distance = dgl.sum_nodes(self._batch_graphs, "h").expand_as(node_count)
+ self._batch_graphs.ndata.pop("h")
+ self._hamming_distance_among_samples_record[tick] = distance
+
+ return
+
+ def _on_action_received(self, event: CascadeEvent) -> None:
+ actions = event.payload
+ assert isinstance(actions, list)
+
+ undecided_before_mask = self._vertex_states == VertexState.Deferred
+
+ # Update Phase
+ for action in actions:
+ assert isinstance(action, Action)
+ undecided_mask = self._vertex_states == VertexState.Deferred
+ self._vertex_states[undecided_mask] = action.vertex_states[undecided_mask]
+
+ # Clean-Up Phase: Set clashed node pairs to Deferred
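+ # One round of message passing (copy_u + sum) counts each node's Included neighbors;
+ # a nonzero count means the node has at least one Included neighbor.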
+ included_mask = self._vertex_states == VertexState.Included
+ self._batch_graphs.ndata["h"] = included_mask.float()
+ self._batch_graphs.update_all(fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h"))
+ neighbor_included_mask = self._batch_graphs.ndata.pop("h").bool()
+
+ # Clashed: if the node and its neighbor are both set to included
+ clashed_mask = included_mask & neighbor_included_mask
+ self._vertex_states[clashed_mask] = VertexState.Deferred
+ neighbor_included_mask[clashed_mask] = False
+
+ # Clean-Up Phase: exclude any deferred vertex that neighbors an included one.
+ undecided_mask = self._vertex_states == VertexState.Deferred
+ self._vertex_states[undecided_mask & neighbor_included_mask] = VertexState.Excluded
+
+ # Timeout handling
+ if event.tick + 1 == self._max_tick:
+ undecided_mask = self._vertex_states == VertexState.Deferred
+ self._vertex_states[undecided_mask] = VertexState.Excluded
+
+ self._calculate_and_record_metrics(event.tick, undecided_before_mask)
+ self._update_is_done_mask()
+
+ return
+
+ def step(self, tick: int) -> None:
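+ # Each tick, emit a decision event carrying the batched graph and a copy of the vertex states,
+ # so the policy cannot mutate the engine's state in place.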
+ decision_payload = MISDecisionPayload(
+ graph=self._batch_graphs,
+ vertex_states=self._vertex_states.clone(),
+ )
+ decision_event = self._event_buffer.gen_decision_event(tick, decision_payload)
+ self._event_buffer.insert_event(decision_event)
+
+ def _generate_er_graph(self) -> Tuple[dgl.DGLGraph, Dict[int, List[int]]]:
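+ # Sample an Erdős-Rényi G(n, p) graph with geometric edge skipping (the same idea as
+ # networkx's fast_gnp_random_graph): jump ahead by geometrically distributed gaps so that
+ # only the edges that are actually present get enumerated.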
+ num_nodes = random.randint(self._num_node_lower_bound, self._num_node_upper_bound)
+ adj = {node: [] for node in range(num_nodes)}
+
+ w = -1
+ lp = math.log(1.0 - self._node_sample_probability)
+
+ # Nodes are numbered 0 to num_nodes - 1; v starts at 1 so every candidate edge (v, w) satisfies w < v.
+ v = 1
+ u_list, v_list = [], []
+ while v < num_nodes:
+ lr = math.log(1.0 - random.random())
+ w = w + 1 + int(lr / lp)
+ while w >= v and v < num_nodes:
+ w = w - v
+ v = v + 1
+ if v < num_nodes:
+ u_list.extend([v, w])
+ v_list.extend([w, v])
+ adj[v].append(w)
+ adj[w].append(v)
+
+ graph = dgl.graph((u_list, v_list), num_nodes=num_nodes)
+
+ return graph, adj
+
+ def _reset_batch_graph(self) -> None:
+ graph_list = []
+ self._batch_adjs = []
+ for _ in range(self._graph_batch_size):
+ graph, adj = self._generate_er_graph()
+ graph_list.append(graph)
+ self._batch_adjs.append(adj)
+
+ self._batch_graphs = dgl.batch(graph_list)
+ self._batch_graphs.set_n_initializer(dgl.init.zero_initializer)
+ self._batch_graphs = self._batch_graphs.to(self._device)
+ return
+
+ def reset(self, keep_seed: bool = False) -> None:
+ self._reset_batch_graph()
+
+ tensor_size = (self._batch_graphs.num_nodes(), self._num_samples)
+ self._vertex_states = torch.full(size=tensor_size, fill_value=VertexState.Deferred, device=self._device)
+ self._update_is_done_mask()
+
+ self._num_included_record = {}
+ self._hamming_distance_among_samples_record = {}
+ return
+
+ def _update_is_done_mask(self) -> None:
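+ # The done mask has shape (graph_batch_size, num_samples): True where that (graph, sample) pair
+ # has no Deferred vertex left.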
+ undecided_mask = self._vertex_states == VertexState.Deferred
+ self._batch_graphs.ndata["h"] = undecided_mask.float()
+ num_undecided = dgl.sum_nodes(self._batch_graphs, "h")
+ self._batch_graphs.ndata.pop("h")
+ self._is_done_mask = (num_undecided == 0)
+ return
+
+ def post_step(self, tick: int) -> bool:
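+ # The episode terminates once the tick limit is reached or every (graph, sample) pair is done.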
+ if tick + 1 == self._max_tick:
+ return True
+
+ return torch.all(self._is_done_mask).item()
+
+ def get_metrics(self) -> dict:
+ return {
+ MISEnvMetrics.IncludedNodeCount: self._num_included_record,
+ MISEnvMetrics.HammingDistanceAmongSamples: self._hamming_distance_among_samples_record,
+ MISEnvMetrics.IsDoneMasks: self._is_done_mask.cpu().detach().numpy(),
+ }
+
+
+if __name__ == "__main__":
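+ # Minimal smoke test: a random policy includes each still-deferred vertex with probability 0.8
+ # and lets the engine's clean-up phase repair any conflicts.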
+ from maro.simulator import Env
+
+ # Fall back to CPU when no CUDA device is available so the demo still runs.
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ env = Env(
+ business_engine_cls=MISBusinessEngine,
+ options={
+ "graph_batch_size": 4,
+ "num_samples": 2,
+ "device": device,
+ "num_node_lower_bound": 15,
+ "num_node_upper_bound": 20,
+ "node_sample_probability": 0.15,
+ },
+ )
+
+ env.reset()
+ metrics, decision_event, done = env.step(None)
+ while not done:
+ assert isinstance(decision_event, MISDecisionPayload)
+ vertex_state = decision_event.vertex_states
+ undecided_mask = vertex_state == VertexState.Deferred
+ random_mask = torch.rand(vertex_state.size(), device=device) < 0.8
+ vertex_state[undecided_mask & random_mask] = VertexState.Included
+ action = Action(vertex_state)
+ metrics, decision_event, done = env.step(action)
+
+ for key in [MISEnvMetrics.IncludedNodeCount]:
+ print(f"[{env.tick - 1:02d}] {key:28s} {metrics[key][env.tick - 1].reshape(-1)}")
+ for key in [MISEnvMetrics.IsDoneMasks]:
+ print(f"[{env.tick - 1:02d}] {key:28s} {metrics[key].reshape(-1)}")