diff --git a/examples/mis/lwd/README.md b/examples/mis/lwd/README.md
new file mode 100644
index 000000000..1d6070137
--- /dev/null
+++ b/examples/mis/lwd/README.md
@@ -0,0 +1,150 @@
+# LWD: Learning What to Defer for Maximum Independent Sets
+
+This example is an application of the [LwD](http://proceedings.mlr.press/v119/ahn20a.html) method (published at ICML 2020).
+Some of the code is adapted from the original GitHub [repository](https://github.com/sungsoo-ahn/learning_what_to_defer)
+maintained by the authors.
+
+## Algorithm Introduction
+
+### Deferred Markov Decision Process
+
+<div align=center>
+ + ![MDP Illustration](./figures/fig1_mdp.png) + +
+ +#### State + +Each state of the MDP is represented as a *vertex-state* vector: + +
+ +$s = [s_i: i \in V] \in \{0, 1, *\}^V$ + +
+
+where *0*, *1*, and *\** indicate that vertex *i* is *excluded*, *included*,
+or *deferred* (i.e., the determination is expected to be made in a later iteration), respectively.
+The MDP is *initialized* with all vertex-states deferred, i.e., $s_i = *, \forall i \in V$,
+and *terminated* when (a) there is no deferred vertex-state left, or (b) the time limit is reached.
+
+#### Action
+
+Actions correspond to new assignments for the next states of the vertices, and are defined only on the deferred vertices:
+
+<div align=center>
+ +$a_* = [a_i: i \in V_*] \in \{0, 1, *\}^{V_*}$ + +
+
+where $V_* = \{i: i \in V, s_i = *\}$.
+
+#### Transition
+
+The transition $P_{a_*}(s, s')$ consists of two deterministic phases:
+
+- *update phase*: takes the action $a_*$ to get an intermediate vertex-state $\hat{s}$,
+i.e., $\hat{s}_i = a_i$ if $i \in V_*$ and $\hat{s}_i = s_i$ otherwise.
+- *clean-up phase*: modifies $\hat{s}$ to yield a valid vertex-state $s'$.
+
+  - Whenever a pair of included vertices are adjacent to each other,
+  they are both mapped back to the deferred vertex-state.
+  - Any deferred vertex neighboring an included vertex is excluded.
+
+Here is an illustration of the transition function; a code sketch of the two phases follows the figure:
+
+<div align=center>
+ +![Transition](./figures/fig2_transition.png) + +
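+
+In this example, the three vertex-states are encoded as integers in the simulator
+(`VertexState.Excluded = 0`, `VertexState.Included = 1`, `VertexState.Deferred = 2`, see *simulator/common.py*).
+The snippet below is a minimal, non-batched sketch of the two transition phases on a plain adjacency list;
+the function name `transition` is only for illustration, while the actual batched implementation on DGL graphs
+lives in `MISBusinessEngine._on_action_received`.
+
+```python
+from typing import Dict, List
+
+EXCLUDED, INCLUDED, DEFERRED = 0, 1, 2  # same encoding as VertexState in simulator/common.py
+
+
+def transition(graph: Dict[int, List[int]], state: Dict[int, int], action: Dict[int, int]) -> Dict[int, int]:
+    """One deferred-MDP transition on an adjacency-list graph (illustrative sketch)."""
+    # Update phase: only deferred vertices take the newly assigned state.
+    s_hat = {v: (action[v] if state[v] == DEFERRED else state[v]) for v in graph}
+
+    # Clean-up phase (1): adjacent included pairs are both mapped back to deferred.
+    clashed = {v for v in graph if s_hat[v] == INCLUDED and any(s_hat[u] == INCLUDED for u in graph[v])}
+    for v in clashed:
+        s_hat[v] = DEFERRED
+
+    # Clean-up phase (2): a deferred vertex adjacent to an included vertex is excluded.
+    for v in graph:
+        if s_hat[v] == DEFERRED and any(s_hat[u] == INCLUDED for u in graph[v]):
+            s_hat[v] = EXCLUDED
+    return s_hat
+```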
+ +#### Reward + +A *cardinality reward* is defined here: + +
+ +$R(s, s') = \sum_{i \in V_* \setminus V_*'}{s_i'}$ + +
+
+where $V_*$ and $V_*'$ are the sets of vertices with deferred vertex-states in $s$ and $s'$, respectively.
+By doing so, the overall reward of the MDP corresponds to the cardinality of the returned independent set.
+
+### Diversification Reward
+
+Couple two copies of the MDP defined on an identical graph $G$ into a new MDP.
+The new MDP is then associated with a pair of distinct vertex-state vectors $(s, \bar{s})$;
+let the resulting solutions be $(x, \bar{x})$.
+We directly reward the deviation between the coupled solutions in terms of the $l_1$-norm, i.e., $||x-\bar{x}||_1$.
+To be specific, the deviation is decomposed into per-iteration rewards of the MDP, defined by:
+
+<div align=center>
+ +$R_{div}(s, s', \bar{s}, \bar{s}') = \sum_{i \in \hat{V}}|s_i'-\bar{s}_i'|$, where $\hat{V}=(V_* \setminus V_*')\cup(\bar{V}_* \setminus \bar{V}_*')$ + +
+ +Here is an example of the diversity reward: + +
+ +![Diversity Reward](./figures/fig3_diversity_reward.png) + +
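+
+Putting the two parts together, the per-step training reward in this example is
+`cardinality_reward + diversity_reward_coef * diversity_reward`, divided by a normalization base
+(see `diversity_reward_coef` and `reward_normalization_base` in *config.py*, and `MISEnvSampler._get_reward`).
+The sketch below mirrors the two formulas above for a single pair of coupled samples; the function name
+`step_reward` and the list-based states are just for illustration, while the real computation is implemented in
+batched form in `MISBusinessEngine._calculate_and_record_metrics`.
+
+```python
+from typing import List
+
+DEFERRED = 2  # vertex-state encoding used by the simulator
+
+
+def step_reward(s: List[int], s_next: List[int], s_bar: List[int], s_bar_next: List[int],
+                coef: float = 0.1, norm_base: float = 20.0) -> float:
+    """Cardinality + weighted diversity reward for one step of two coupled samples."""
+    # Cardinality reward: vertices newly fixed to "included" (state 1) in this step.
+    cardinality = sum(1 for i in range(len(s)) if s[i] == DEFERRED and s_next[i] == 1)
+
+    # Diversity reward: Hamming distance over vertices determined in this step in either copy,
+    # counted only where both copies are already determined.
+    newly_determined = {
+        i for i in range(len(s))
+        if (s[i] == DEFERRED and s_next[i] != DEFERRED) or (s_bar[i] == DEFERRED and s_bar_next[i] != DEFERRED)
+    }
+    diversity = sum(
+        abs(s_next[i] - s_bar_next[i])
+        for i in newly_determined
+        if s_next[i] != DEFERRED and s_bar_next[i] != DEFERRED
+    )
+    return (cardinality + coef * diversity) / norm_base
+```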
+
+*Entropy regularization plays a similar role to the diversity reward introduced above.
+Note, however, that entropy regularization only encourages diverse trajectories of the same MDP,
+which does not necessarily lead to diverse final solutions,
+since many different trajectories can result in the same solution.*
+
+### Design of the Neural Network
+
+The policy network $\pi(a|s)$ and the value network $V(s)$ are designed to follow the
+[GraphSAGE](https://proceedings.neurips.cc/paper/2017/hash/5dd9db5e033da9c6fb5ba83c7a7ebea9-Abstract.html) architecture,
+which is a general inductive framework that leverages node feature information
+to efficiently generate node embeddings by sampling and aggregating features from a node's local neighborhood.
+Each network consists of multiple layers $h^{(n)}$ with $n = 1, ..., N$,
+where the $n$-th layer with weights $W_1^{(n)}$ and $W_2^{(n)}$ performs the following transformation on its input $H$:
+
+<div align=center>
+
+$h^{(n)} = \mathrm{ReLU}(HW_1^{(n)}+D^{-\frac{1}{2}}BD^{-\frac{1}{2}}HW_2^{(n)})$.
+
+</div>
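+
+As a concrete (non-batched) reading of this update, the dense PyTorch sketch below applies one such layer to a
+toy graph, with `B` the adjacency matrix and `D` the degree matrix as defined right below. The helper name
+`gcn_layer` is only for illustration; the example's actual layer (`GraphConvLayer` in *ppo/graph_net.py*) performs
+the same normalized aggregation via DGL message passing, with optional masking and a skip connection.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def gcn_layer(H: torch.Tensor, B: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor) -> torch.Tensor:
+    """h = ReLU(H @ W1 + D^(-1/2) B D^(-1/2) H @ W2), with D built from B."""
+    deg = B.sum(dim=1)
+    d_inv_sqrt = torch.pow(deg.clamp(min=1.0), -0.5)  # isolated nodes get norm 1, like GraphConvLayer's zero-degree handling
+    B_norm = d_inv_sqrt.unsqueeze(1) * B * d_inv_sqrt.unsqueeze(0)  # D^(-1/2) B D^(-1/2)
+    return F.relu(H @ W1 + B_norm @ H @ W2)
+
+
+# Toy usage: a 4-node path graph with 2-dimensional input features (matching input_dim=2 in config.py).
+B = torch.tensor([
+    [0.0, 1.0, 0.0, 0.0],
+    [1.0, 0.0, 1.0, 0.0],
+    [0.0, 1.0, 0.0, 1.0],
+    [0.0, 0.0, 1.0, 0.0],
+])
+H = torch.randn(4, 2)
+W1, W2 = torch.randn(2, 8), torch.randn(2, 8)
+out = gcn_layer(H, B, W1, W2)  # shape: (4, 8)
+```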
+ +Here $B$ and $D$ corresponds to adjacency and degree matrix of the graph $G$, respectively. At the final layer, +the policy and value networks apply softmax function and graph readout function with sum pooling instead of ReLU +to generate actions and value estimates, respectively. + +### Input of the Neural Network + +- The subgraph that is induced on the deferred vertices $V_*$ as the input of the networks +since the determined part of the graph no longer affects the future rewards of the MDP. +- Input features: + + - Vertex degrees; + - The current iteration-index of the MDP, normalized by the maximum number of iterations. + +### Training Algorithm + +The Proximal Policy Optimization (PPO) is used in this solution. + +## Quick Start + +Please make sure the environment is correctly set up, refer to +[MARO](https://github.com/microsoft/maro#install-maro-from-source) for more installation guidance. +To try the example code, you can simply run: + +```sh +python examples/rl/run.py examples/mis/lwd/config.yml +``` + +The default log path is set to *examples/mis/lwd/log/test*, the recorded metrics and training curves can be found here. + +To adjust the configurations of the training workflow, go to file: *examples/mis/lwd/config.yml*, +To adjust the problem formulation, network setting and some other detailed configurations, +go to file *examples/mis/lwd/config.py*. diff --git a/examples/mis/lwd/__init__.py b/examples/mis/lwd/__init__.py new file mode 100644 index 000000000..a1517b73c --- /dev/null +++ b/examples/mis/lwd/__init__.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch + +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.utils.common import get_env +from maro.simulator import Env + +from examples.mis.lwd.config import Config +from examples.mis.lwd.env_sampler.mis_env_sampler import MISEnvSampler, MISPlottingCallback +from examples.mis.lwd.simulator.mis_business_engine import MISBusinessEngine +from examples.mis.lwd.ppo import get_ppo_policy, get_ppo_trainer + + +config = Config() + +# Environments +learn_env = Env( + business_engine_cls=MISBusinessEngine, + durations=config.max_tick, + options={ + "graph_batch_size": config.train_graph_batch_size, + "num_samples": config.train_num_samples, + "device": torch.device(config.device), + "num_node_lower_bound": config.num_node_lower_bound, + "num_node_upper_bound": config.num_node_upper_bound, + "node_sample_probability": config.node_sample_probability, + }, +) + +test_env = Env( + business_engine_cls=MISBusinessEngine, + durations=config.max_tick, + options={ + "graph_batch_size": config.eval_graph_batch_size, + "num_samples": config.eval_num_samples, + "device": torch.device(config.device), + "num_node_lower_bound": config.num_node_lower_bound, + "num_node_upper_bound": config.num_node_upper_bound, + "node_sample_probability": config.node_sample_probability, + }, +) + +# Agent, policy, and trainers +agent2policy = {agent: f"ppo_{agent}.policy" for agent in learn_env.agent_idx_list} + +policies = [ + get_ppo_policy( + name=f"ppo_{agent}.policy", + state_dim=config.input_dim, + action_num=config.output_dim, + hidden_dim=config.hidden_dim, + num_layers=config.num_layers, + init_lr=config.init_lr, + ) + for agent in learn_env.agent_idx_list +] + +trainers = [ + get_ppo_trainer( + name=f"ppo_{agent}", + state_dim=config.input_dim, + hidden_dim=config.hidden_dim, + num_layers=config.num_layers, + init_lr=config.init_lr, + 
clip_ratio=config.clip_ratio, + max_tick=config.max_tick, + batch_size=config.batch_size, + reward_discount=config.reward_discount, + graph_batch_size=config.train_graph_batch_size, + graph_num_samples=config.train_num_samples, + num_train_epochs=config.num_train_epochs, + norm_base=config.reward_normalization_base, + ) + for agent in learn_env.agent_idx_list +] + +device_mapping = {f"ppo_{agent}.policy": config.device for agent in learn_env.agent_idx_list} + +# Build RLComponentBundle +rl_component_bundle = RLComponentBundle( + env_sampler=MISEnvSampler( + learn_env=learn_env, + test_env=test_env, + policies=policies, + agent2policy=agent2policy, + diversity_reward_coef=config.diversity_reward_coef, + reward_normalization_base=config.reward_normalization_base, + ), + agent2policy=agent2policy, + policies=policies, + trainers=trainers, + device_mapping=device_mapping, + customized_callbacks=[MISPlottingCallback(log_dir=get_env("LOG_PATH", required=False, default="./"))], +) + + +__all__ = ["rl_component_bundle"] diff --git a/examples/mis/lwd/config.py b/examples/mis/lwd/config.py new file mode 100644 index 000000000..eb59ae8ef --- /dev/null +++ b/examples/mis/lwd/config.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +class Config(object): + device: str = "cuda:0" + + # Configuration for graph batch size + train_graph_batch_size = 32 + eval_graph_batch_size = 32 + + # Configuration for num_samples + train_num_samples = 2 + eval_num_samples = 10 + + # Configuration for the MISEnv + max_tick = 32 # Once the max_tick reached, the timeout processing will set all deferred nodes to excluded + num_node_lower_bound: int = 40 + num_node_upper_bound: int = 50 + node_sample_probability: float = 0.15 + + # Configuration for the reward definition + diversity_reward_coef = 0.1 # reward = cardinality reward + coef * diversity Reward + reward_normalization_base = 20 + + # Configuration for the GraphBasedActorCritic + input_dim = 2 + output_dim = 3 + hidden_dim = 128 + num_layers = 5 + + # Configuration for PPO update + init_lr = 1e-4 + clip_ratio = 0.2 + reward_discount = 1.0 + + # Configuration for main loop + batch_size = 16 + num_train_epochs = 4 diff --git a/examples/mis/lwd/config.yml b/examples/mis/lwd/config.yml new file mode 100644 index 000000000..b057b2c8b --- /dev/null +++ b/examples/mis/lwd/config.yml @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +job: mis_lwd +scenario_path: "examples/mis/lwd" +# The log dir where you want to save the training loggings and model checkpoints. +log_path: "examples/mis/lwd/log/test_40_50" +main: + # Number of episodes to run. Each episode is one cycle of roll-out and training. + num_episodes: 1000 + # This can be an integer or a list of integers. An integer indicates the interval at which policies are evaluated. + # A list indicates the episodes at the end of which policies are to be evaluated. Note that episode indexes are + # 1-based. + eval_schedule: [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 600, 700, 800, 900, 1000] + # Number of Episodes to run in evaluation. 
+ num_eval_episodes: 5 + min_n_sample: 1 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG +training: + mode: simple + load_path: null + load_episode: null + checkpointing: + path: null + # Interval at which trained policies / models are persisted to disk. + interval: 200 + logging: + stdout: INFO + file: DEBUG diff --git a/examples/mis/lwd/env_sampler/baseline.py b/examples/mis/lwd/env_sampler/baseline.py new file mode 100644 index 000000000..3282d1a88 --- /dev/null +++ b/examples/mis/lwd/env_sampler/baseline.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import random +from typing import Dict, List + + +def _choose_by_weight(graph: Dict[int, List[int]], node2weight: Dict[int, float]) -> List[int]: + """Choose node in the order of descending weight if not blocked.. + + Args: + graph (Dict[int, List[int]]): The adjacent matrix of the target graph. The key is the node id of each node. The + value is a list of each node's neighbor nodes. + node2weight (Dict[int, float]): The node to weight dictionary with node id as key and node weight as value. + + Returns: + List[int]: A list of chosen node id. + """ + node_weight_list = [(node, weight) for node, weight in node2weight.items()] + # Shuffle the candidates to get random result in the case there are nodes sharing the same weight. + random.shuffle(node_weight_list) + # Sort node candidates with descending weight. + sorted_nodes = sorted(node_weight_list, key=lambda x: x[1], reverse=True) + + chosen_node_id_set: set = set() + blocked_node_id_set: set = set() + # Choose node in the order of descending weight if it is not blocked yet by the chosen nodes. + for node, _ in sorted_nodes: + if node in blocked_node_id_set: + continue + chosen_node_id_set.add(node) + for neighbor_node in graph[node]: + blocked_node_id_set.add(neighbor_node) + + chosen_node_ids = [node for node in chosen_node_id_set] + return chosen_node_ids + +def uniform_mis_solver(graph: Dict[int, List[int]]) -> List[int]: + node2weight: Dict[int, float] = {node: 1 for node in graph.keys()} + chosen_node_list = _choose_by_weight(graph, node2weight) + return chosen_node_list + +def greedy_mis_solver(graph: Dict[int, List[int]]) -> List[int]: + node2weight: Dict[int, float] = {node: 1 / (1 + len(neighbor_list)) for node, neighbor_list in graph.items()} + chosen_node_list = _choose_by_weight(graph, node2weight) + return chosen_node_list diff --git a/examples/mis/lwd/env_sampler/mis_env_sampler.py b/examples/mis/lwd/env_sampler/mis_env_sampler.py new file mode 100644 index 000000000..c0aaabe60 --- /dev/null +++ b/examples/mis/lwd/env_sampler/mis_env_sampler.py @@ -0,0 +1,343 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
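+"""Environment sampler for the LwD MIS example.
+
+Defines the agent wrapper that feeds (vertex-state, graph) pairs to the graph-based policy,
+the environment sampler that translates simulator events into states, actions, and rewards
+and compares against the uniform / greedy baselines during evaluation, and a plotting
+callback that renders the recorded training and validation metrics.
+"""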
+ +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch + +from maro.rl.policy.abs_policy import AbsPolicy +from maro.rl.rollout import SimpleAgentWrapper, AbsEnvSampler, CacheElement +from maro.rl.utils import ndarray_to_tensor +from maro.rl.workflows.callback import Callback +from maro.simulator.core import Env + +from examples.mis.lwd.env_sampler.baseline import greedy_mis_solver, uniform_mis_solver +from examples.mis.lwd.ppo.ppo import GraphBasedPPOPolicy +from examples.mis.lwd.ppo.replay_memory import GraphBasedExpElement +from examples.mis.lwd.simulator import Action, MISDecisionPayload, MISEnvMetrics, MISBusinessEngine + + +class MISAgentWrapper(SimpleAgentWrapper): + def __init__(self, policy_dict: Dict[str, AbsPolicy], agent2policy: Dict[Any, str]) -> None: + super().__init__(policy_dict, agent2policy) + + def _choose_actions_impl(self, state_by_agent: Dict[Any, torch.Tensor]) -> Dict[Any, np.ndarray]: + assert len(state_by_agent) == 1 + for agent_name, state in state_by_agent.items(): + break + + policy_name = self._agent2policy[agent_name] + policy = self._policy_dict[policy_name] + assert isinstance(policy, GraphBasedPPOPolicy) + + assert isinstance(state, Tuple) + assert len(state) == 2 + action = policy.get_actions(state[0], graph=state[1]) + return {agent_name: action} + + +class MISEnvSampler(AbsEnvSampler): + def __init__( + self, + learn_env: Env, + test_env: Env, + policies: List[AbsPolicy], + agent2policy: Dict[Any, str], + trainable_policies: List[str] = None, + reward_eval_delay: int = None, + max_episode_length: int = None, + diversity_reward_coef: float = 0.1, + reward_normalization_base: float = None + ) -> None: + super(MISEnvSampler, self).__init__( + learn_env=learn_env, + test_env=test_env, + policies=policies, + agent2policy=agent2policy, + trainable_policies=trainable_policies, + agent_wrapper_cls=MISAgentWrapper, + reward_eval_delay=reward_eval_delay, + max_episode_length=max_episode_length, + ) + be = learn_env.business_engine + assert isinstance(be, MISBusinessEngine) + self._device = be._device + self._diversity_reward_coef = diversity_reward_coef + self._reward_normalization_base = reward_normalization_base + + self._sample_metrics: List[tuple] = [] + self._eval_metrics: List[tuple] = [] + + def _get_global_and_agent_state( + self, + event: Any, + tick: int = None, + ) -> Tuple[Optional[Any], Dict[Any, Union[np.ndarray, list]]]: + return self._get_global_and_agent_state_impl(event, tick) + + def _get_global_and_agent_state_impl( + self, + event: MISDecisionPayload, + tick: int = None, + ) -> Tuple[Union[None, np.ndarray, list], Dict[Any, Union[np.ndarray, list]]]: + vertex_states = event.vertex_states.unsqueeze(2).float().cpu() + normalized_tick = torch.full(vertex_states.size(), tick / self.env._business_engine._max_tick) + state = torch.cat([vertex_states, normalized_tick], dim=2).cpu().detach().numpy() + return None, {0: (state, event.graph)} + + def _translate_to_env_action(self, action_dict: dict, event: Any) -> dict: + return {k: Action(vertex_states=ndarray_to_tensor(v, self._device)) for k, v in action_dict.items()} + + def _get_reward(self, env_action_dict: dict, event: Any, tick: int) -> Dict[Any, float]: + be = self.env.business_engine + assert isinstance(be, MISBusinessEngine) + + cardinality_record_dict = self.env.metrics[MISEnvMetrics.IncludedNodeCount] + hamming_distance_record_dict = 
self.env.metrics[MISEnvMetrics.HammingDistanceAmongSamples] + + assert (tick - 1) in cardinality_record_dict + cardinality_reward = cardinality_record_dict[tick - 1] - cardinality_record_dict.get(tick - 2, 0) + + assert (tick - 1) in hamming_distance_record_dict + diversity_reward = hamming_distance_record_dict[tick - 1] + + reward = cardinality_reward + self._diversity_reward_coef * diversity_reward + reward /= self._reward_normalization_base + + return {0: reward.cpu().detach().numpy()} + + def _eval_baseline(self) -> Dict[str, float]: + assert isinstance(self.env.business_engine, MISBusinessEngine) + graph_list: List[Dict[int, List[int]]] = self.env.business_engine._batch_adjs + num_samples: int = self.env.business_engine._num_samples + + def best_among_samples( + solver: Callable[[Dict[int, List[int]]], List[int]], + num_samples: int, + graph: Dict[int, List[int]], + ) -> int: + res = 0 + for _ in range(num_samples): + res = max(res, len(solver(graph))) + return res + + graph_size_list = [len(graph) for graph in graph_list] + uniform_size_list = [best_among_samples(uniform_mis_solver, num_samples, graph) for graph in graph_list] + greedy_size_list = [best_among_samples(greedy_mis_solver, num_samples, graph) for graph in graph_list] + + return { + "graph_size": np.mean(graph_size_list), + "uniform_size": np.mean(uniform_size_list), + "greedy_size": np.mean(greedy_size_list), + } + + def sample(self, policy_state: Optional[Dict[str, Dict[str, Any]]] = None, num_steps: Optional[int] = None) -> dict: + if policy_state is not None: # Update policy state if necessary + self.set_policy_state(policy_state) + self._switch_env(self._learn_env) # Init the env + self._agent_wrapper.explore() # Collect experience + + # One complete episode in one sample call here. 
+ self._reset() + experiences: List[GraphBasedExpElement] = [] + + while not self._end_of_episode: + state = self._agent_state_dict[0][0] + graph = self._agent_state_dict[0][1] + # Get agent actions and translate them to env actions + action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict) + env_action_dict = self._translate_to_env_action(action_dict, self._event) + # Update env and get new states (global & agent) + self._step(list(env_action_dict.values())) + assert self._reward_eval_delay is None + reward_dict = self._get_reward(env_action_dict, self._event, self.env.tick) + + experiences.append( + GraphBasedExpElement( + state=state, + action=action_dict[0], + reward=reward_dict[0], + is_done=self.env.metrics[MISEnvMetrics.IsDoneMasks], + graph=graph, + ) + ) + + self._total_number_interactions += 1 + self._current_episode_length += 1 + self._post_step(None) + + return { + "experiences": [experiences], + "info": [], + } + + def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict: + self._switch_env(self._test_env) + info_list = [] + + for _ in range(num_episodes): + self._reset() + + baseline_info = self._eval_baseline() + info_list.append(baseline_info) + + if policy_state is not None: + self.set_policy_state(policy_state) + + self._agent_wrapper.exploit() + while not self._end_of_episode: + action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict) + env_action_dict = self._translate_to_env_action(action_dict, self._event) + # Update env and get new states (global & agent) + self._step(list(env_action_dict.values())) + + self._post_eval_step(None) + + return {"info": info_list} + + def _post_step(self, cache_element: CacheElement) -> None: + if not (self._end_of_episode or self.truncated): + return + + node_count_record = self.env.metrics[MISEnvMetrics.IncludedNodeCount] + num_steps = max(node_count_record.keys()) + 1 + # Report the mean among samples as the average MIS size in rollout. + num_nodes = torch.mean(node_count_record[num_steps - 1]).item() + + self._sample_metrics.append((num_steps, num_nodes)) + + def _post_eval_step(self, cache_element: CacheElement) -> None: + if not (self._end_of_episode or self.truncated): + return + + node_count_record = self.env.metrics[MISEnvMetrics.IncludedNodeCount] + num_steps = max(node_count_record.keys()) + 1 + # Report the maximum among samples as the MIS size in evaluation. + num_nodes = torch.mean(torch.max(node_count_record[num_steps - 1], dim=1).values).item() + + self._eval_metrics.append((num_steps, num_nodes)) + + def post_collect(self, info_list: list, ep: int) -> None: + assert len(self._sample_metrics) == 1, f"One Episode for One Rollout/Collection in Current Workflow Design." + num_steps, mis_size = self._sample_metrics[0] + + cur = { + "n_steps": num_steps, + "avg_mis_size": mis_size, + "n_interactions": self._total_number_interactions, + } + self.metrics.update(cur) + + # clear validation metrics + self.metrics = {k: v for k, v in self.metrics.items() if not k.startswith("val/")} + self._sample_metrics.clear() + + def post_evaluate(self, info_list: list, ep: int) -> None: + num_eval = len(self._eval_metrics) + assert num_eval > 0, f"Num evaluation rounds much be positive!" 
+ + cur = { + "val/num_eval": num_eval, + "val/avg_n_steps": np.mean([n for n, _ in self._eval_metrics]), + "val/avg_mis_size": np.mean([s for _, s in self._eval_metrics]), + "val/std_mis_size": np.std([s for _, s in self._eval_metrics]), + "val/avg_graph_size": np.mean([info["graph_size"] for info in info_list]), + "val/uniform_size": np.mean([info["uniform_size"] for info in info_list]), + "val/uniform_std": np.std([info["uniform_size"] for info in info_list]), + "val/greedy_size": np.mean([info["greedy_size"] for info in info_list]), + "val/greedy_std": np.std([info["greedy_size"] for info in info_list]), + } + self.metrics.update(cur) + self._eval_metrics.clear() + + +class MISPlottingCallback(Callback): + def __init__(self, log_dir: str) -> None: + super().__init__() + self._log_dir = log_dir + + def _plot_trainer(self) -> None: + for trainer_name in self.training_manager._trainer_dict.keys(): + df = pd.read_csv(os.path.join(self._log_dir, f"{trainer_name}.csv")) + columns = [ + "Steps", + "Mean Reward", + "Mean Return", + "Mean Advantage", + "0-Action", + "1-Action", + "2-Action", + "Critic Loss", + "Actor Loss", + ] + + _, axes = plt.subplots(len(columns), sharex=False, figsize=(20, 21)) + + for col, ax in zip(columns, axes): + data = df[col].dropna().to_numpy()[:-1] + ax.plot(data, label=col) + ax.legend() + + plt.tight_layout() + plt.savefig(os.path.join(self._log_dir, f"{trainer_name}.png")) + plt.close() + plt.cla() + plt.clf() + + @staticmethod + def _plot_mean_std( + ax: plt.Axes, x: np.ndarray, y_mean: np.ndarray, y_std: np.ndarray, color: str, label: str, + ) -> None: + ax.plot(x, y_mean, label=label, color=color) + ax.fill_between(x, y_mean - y_std, y_mean + y_std, color=color, alpha=0.2) + + def _plot_metrics(self) -> None: + df_rollout = pd.read_csv(os.path.join(self._log_dir, "metrics_full.csv")) + df_eval = pd.read_csv(os.path.join(self._log_dir, "metrics_valid.csv")) + + n_steps = df_rollout["n_steps"].to_numpy() + eval_n_steps = df_eval["val/avg_n_steps"].to_numpy() + + eval_ep = df_eval["ep"].to_numpy() + graph_size = df_eval["val/avg_graph_size"].to_numpy() + + mis_size = df_rollout["avg_mis_size"].to_numpy() + eval_mis_size, eval_size_std = df_eval["val/avg_mis_size"].to_numpy(), df_eval["val/std_mis_size"].to_numpy() + uniform_size, uniform_std = df_eval["val/uniform_size"].to_numpy(), df_eval["val/uniform_std"].to_numpy() + greedy_size, greedy_std = df_eval["val/greedy_size"].to_numpy(), df_eval["val/greedy_std"].to_numpy() + + fig, (ax_t, ax_gsize, ax_mis) = plt.subplots(3, sharex=False, figsize=(20, 15)) + + color_map = { + "rollout": "cornflowerblue", + "eval": "orange", + "uniform": "green", + "greedy": "firebrick", + } + + ax_t.plot(n_steps, label="n_steps", color=color_map["rollout"]) + ax_t.plot(eval_ep, eval_n_steps, label="val/n_steps", color=color_map["eval"]) + + ax_gsize.plot(eval_ep, graph_size, label="val/avg_graph_size", color=color_map["eval"]) + + ax_mis.plot(mis_size, label="avg_mis_size", color=color_map["rollout"]) + self._plot_mean_std(ax_mis, eval_ep, eval_mis_size, eval_size_std, color_map["eval"], "val/mis_size") + self._plot_mean_std(ax_mis, eval_ep, uniform_size, uniform_std, color_map["uniform"], "val/uniform_size") + self._plot_mean_std(ax_mis, eval_ep, greedy_size, greedy_std, color_map["greedy"], "val/greedy_size") + + for ax in fig.get_axes(): + ax.legend() + + plt.tight_layout() + plt.savefig(os.path.join(self._log_dir, "metrics.png")) + plt.close() + plt.cla() + plt.clf() + + def on_validation_end(self, ep: int) -> None: + 
self._plot_trainer() + self._plot_metrics() diff --git a/examples/mis/lwd/figures/fig1_mdp.png b/examples/mis/lwd/figures/fig1_mdp.png new file mode 100644 index 000000000..2e6a3e5c2 Binary files /dev/null and b/examples/mis/lwd/figures/fig1_mdp.png differ diff --git a/examples/mis/lwd/figures/fig2_transition.png b/examples/mis/lwd/figures/fig2_transition.png new file mode 100644 index 000000000..6e1963536 Binary files /dev/null and b/examples/mis/lwd/figures/fig2_transition.png differ diff --git a/examples/mis/lwd/figures/fig3_diversity_reward.png b/examples/mis/lwd/figures/fig3_diversity_reward.png new file mode 100644 index 000000000..ac071b9c1 Binary files /dev/null and b/examples/mis/lwd/figures/fig3_diversity_reward.png differ diff --git a/examples/mis/lwd/ppo/__init__.py b/examples/mis/lwd/ppo/__init__.py new file mode 100644 index 000000000..607979975 --- /dev/null +++ b/examples/mis/lwd/ppo/__init__.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from maro.rl.training.algorithms.ppo import PPOParams +from maro.rl.utils.common import get_env + +from examples.mis.lwd.ppo.model import GraphBasedPolicyNet, GraphBasedVNet +from examples.mis.lwd.ppo.ppo import GraphBasedPPOPolicy, GraphBasedPPOTrainer + + +def get_ppo_policy( + name: str, + state_dim: int, + action_num: int, + hidden_dim: int, + num_layers: int, + init_lr: float, +) -> GraphBasedPPOPolicy: + return GraphBasedPPOPolicy( + name=name, + policy_net=GraphBasedPolicyNet( + state_dim=state_dim, + action_num=action_num, + hidden_dim=hidden_dim, + num_layers=num_layers, + init_lr=init_lr, + ), + ) + + +def get_ppo_trainer( + name: str, + state_dim: int, + hidden_dim: int, + num_layers: int, + init_lr: float, + clip_ratio: float, + max_tick: int, + batch_size: int, + reward_discount: float, + graph_batch_size: int, + graph_num_samples: int, + num_train_epochs: int, + norm_base: float, +) -> GraphBasedPPOTrainer: + return GraphBasedPPOTrainer( + name=name, + params=PPOParams( + get_v_critic_net_func=lambda: GraphBasedVNet(state_dim, hidden_dim, num_layers, init_lr, norm_base), + grad_iters=1, + lam=None, # GAE not used here. + clip_ratio=clip_ratio, + ), + replay_memory_capacity=max_tick, + batch_size=batch_size, + reward_discount=reward_discount, + graph_batch_size=graph_batch_size, + graph_num_samples=graph_num_samples, + input_feature_size=state_dim, + num_train_epochs=num_train_epochs, + log_dir=get_env("LOG_PATH", required=False, default=None), + ) + + +__all__ = ["get_ppo_policy", "get_ppo_trainer"] diff --git a/examples/mis/lwd/ppo/graph_net.py b/examples/mis/lwd/ppo/graph_net.py new file mode 100644 index 000000000..e4c34d1a4 --- /dev/null +++ b/examples/mis/lwd/ppo/graph_net.py @@ -0,0 +1,124 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import dgl +import dgl.function as fn +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GraphConvLayer(nn.Module): + def __init__( + self, in_feats: int, out_feats: int, norm: bool=True, jump: bool=True, bias: bool=True, activation=None, + ) -> None: + """The GraphConvLayer that can forward based on the given graph and graph node features. + + Args: + in_feats (int): The dimension of input feature. + out_feats (int): The dimension of the output tensor. + norm (bool): Add feature normalization operation or not. Defaults to True. + jump (bool): Add skip connections of the input feature to the aggregation or not. Defaults to True. 
+ bias (bool): Add a learnable bias layer or not. Defaults to True. + activation (torch.nn.functional): The output activation function to use. Defaults to None. + """ + super(GraphConvLayer, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._norm = norm + self._jump = jump + self._activation = activation + + if jump: + self.weight = nn.Parameter(torch.Tensor(2 * in_feats, out_feats)) + else: + self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats)) + + if bias: + self.bias = nn.Parameter(torch.Tensor(out_feats)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + torch.nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + torch.nn.init.zeros_(self.bias) + + def forward(self, feat: torch.Tensor, graph: dgl.DGLGraph, mask=None) -> torch.Tensor: + if self._jump: + _feat = feat + + if self._norm: + if mask is None: + norm = torch.pow(graph.in_degrees().float(), -0.5) + norm.masked_fill_(graph.in_degrees() == 0, 1.0) + shp = norm.shape + (1,) * (feat.dim() - 1) + norm = torch.reshape(norm, shp).to(feat.device) + feat = feat * norm.unsqueeze(1) + else: + graph.ndata["h"] = mask.float() + graph.update_all(fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")) + masked_deg = graph.ndata.pop("h") + norm = torch.pow(masked_deg, -0.5) + norm.masked_fill_(masked_deg == 0, 1.0) + feat = feat * norm.unsqueeze(-1) + + if mask is not None: + feat = mask.float().unsqueeze(-1) * feat + + graph.ndata["h"] = feat + graph.update_all(fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")) + rst = graph.ndata.pop("h") + + if self._norm: + rst = rst * norm.unsqueeze(-1) + + if self._jump: + rst = torch.cat([rst, _feat], dim=-1) + + rst = torch.matmul(rst, self.weight) + + if self.bias is not None: + rst = rst + self.bias + + if self._activation is not None: + rst = self._activation(rst) + + return rst + + +class GraphConvNet(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dim: int, + num_layers: int, + activation=F.relu, + out_activation=None, + ) -> None: + """The GraphConvNet constructed with multiple GraphConvLayers. + + Args: + input_dim (int): The dimension of the input feature. + output_dim (int): The dimension of the output tensor. + hidden_dim (int): The dimension of the hidden layers. + num_layers (int): How many layers in this GraphConvNet in total, including the input and output layer. >= 2. + activation (torch.nn.functional): The activation function used in input layer and hidden layers. Defaults to + torch.nn.functional.relu. + out_activation (torch.nn.functional): The output activation function to use. Defaults to None. + """ + super(GraphConvNet, self).__init__() + + self.layers = nn.ModuleList( + [GraphConvLayer(input_dim, hidden_dim, activation=activation)] + + [GraphConvLayer(hidden_dim, hidden_dim, activation=activation) for _ in range(num_layers - 2)] + + [GraphConvLayer(hidden_dim, output_dim, activation=out_activation)] + ) + + def forward(self, h: torch.Tensor, graph: dgl.DGLGraph, mask=None) -> torch.Tensor: + for layer in self.layers: + h = layer(h, graph, mask=mask) + return h diff --git a/examples/mis/lwd/ppo/model.py b/examples/mis/lwd/ppo/model.py new file mode 100644 index 000000000..f6fab8b99 --- /dev/null +++ b/examples/mis/lwd/ppo/model.py @@ -0,0 +1,209 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
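+"""Policy and value networks for the LwD MIS example.
+
+Both networks run a GraphConvNet on the subgraph induced by the deferred vertices:
+the policy net samples a per-vertex categorical action (exclude / include / defer),
+and the value net sums per-vertex value predictions over each graph and divides the
+result by a normalization base.
+"""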
+ +from typing import Tuple + +import dgl +import torch +from torch.distributions import Categorical +from torch.optim import Adam + +from maro.rl.model import VNet, DiscretePolicyNet + +from examples.mis.lwd.ppo.graph_net import GraphConvNet +from examples.mis.lwd.simulator import VertexState + + +VertexStateIndex = 0 + + +def get_masks_idxs_subgraph_h(obs: torch.Tensor, graph: dgl.DGLGraph) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, dgl.DGLGraph, torch.Tensor, +]: + """Extract masks and input feature information for the deferred vertexes. + + Args: + obs (torch.Tensor): The observation tensor with shape (num_nodes, num_samples, feature_size). + graph (dgl.DGLGraph): The input graph. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - undecided_node_mask, with shape (num_nodes, num_samples) + - subgraph_mask, with shape (num_nodes) + - subgraph_node_mask, with shape (num_nodes, num_samples) + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - flatten_node_idxs, with shape (num_nodes * num_samples) + - flatten_subgraph_idxs, with shape (num_nodes) + - flatten_subgraph_node_idxs, with shape (num_nodes * num_samples) + dgl.DGLGraph: + torch.Tensor: input tensor with shape (num_nodes, num_samples, 2) + """ + # Mask tensor with shape (num_nodes, num_samples) + undecided_node_mask = obs.select(2, VertexStateIndex).long() == VertexState.Deferred + # Flatten index tensor with shape (num_nodes * num_samples) + flatten_node_idxs = undecided_node_mask.reshape(-1).nonzero().squeeze(1) + + # Mask tensor with shape (num_nodes) + subgraph_mask = undecided_node_mask.any(dim=1) + # Flatten index tensor with shape (num_nodes) + flatten_subgraph_idxs = subgraph_mask.nonzero().squeeze(1) + + # Mask tensor with shape (num_nodes, num_samples) + subgraph_node_mask = undecided_node_mask.index_select(0, flatten_subgraph_idxs) + # Flatten index tensor with shape (num_nodes * num_samples) + flatten_subgraph_node_idxs = subgraph_node_mask.view(-1).nonzero().squeeze(1) + + # Extract a subgraph with only node in flatten_subgraph_idxs, batch_size -> 1 + subgraph = graph.subgraph(flatten_subgraph_idxs) + + # The observation of the deferred vertexes. + h = obs.index_select(0, flatten_subgraph_idxs) + + num_nodes, num_samples = obs.size(0), obs.size(1) + return subgraph_node_mask, flatten_node_idxs, flatten_subgraph_node_idxs, subgraph, h, num_nodes, num_samples + + +class GraphBasedPolicyNet(DiscretePolicyNet): + def __init__( + self, + state_dim: int, + action_num: int, + hidden_dim: int, + num_layers: int, + init_lr: float, + ) -> None: + """A discrete policy net implemented with a graph as input. + + Args: + state_dim (int): The dimension of the input state for this policy net. + action_num (int): The number of pre-defined discrete actions, i.e., the size of the discrete action space. + hidden_dim (int): The dimension of the hidden layers used in the GraphConvNet of the actor. + num_layers (int): The number of layers of the GraphConvNet of the actor. + init_lr (float): The initial learning rate of the optimizer. 
+ """ + super(GraphBasedPolicyNet, self).__init__(state_dim, action_num) + + self._actor = GraphConvNet( + input_dim=state_dim, + output_dim=action_num, + hidden_dim=hidden_dim, + num_layers=num_layers, + ) + with torch.no_grad(): + self._actor.layers[-1].bias[2].add_(3.0) + + self._optim = Adam(self._actor.parameters(), lr=init_lr) + + def _get_action_probs_impl(self, states: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def _get_actions_impl(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor: + action, _ = self._get_actions_with_logps_impl(states, exploring, **kwargs) + return action + + def _get_actions_with_probs_impl( + self, states: torch.Tensor, exploring: bool, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def _get_actions_with_logps_impl( + self, states: torch.Tensor, exploring: bool, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert "graph" in kwargs, f"graph is required to given in kwargs" + graph = kwargs["graph"] + subg_mask, node_idxs, subg_node_idxs, subg, h, num_nodes, num_samples = get_masks_idxs_subgraph_h(states, graph) + + # Compute logits to get action, logits: shape (num_nodes * num_samples, 3) + logits = self._actor(h, subg, mask=subg_mask).view(-1, self.action_num).index_select(0, subg_node_idxs) + + action = torch.zeros(num_nodes * num_samples, dtype=torch.long, device=self._device) + action_log_probs = torch.zeros(num_nodes * num_samples, device=self._device) + + # NOTE: here we do not distinguish exploration mode and exploitation mode. + # The main reason here for doing so is that the LwD modeling is learnt to better exploration, + # the final result is chosen from the sampled trajectories. + m = Categorical(logits=logits) + action[node_idxs] = m.sample() + action_log_probs[node_idxs] = m.log_prob(action.index_select(0, node_idxs)) + + action = action.view(-1, num_samples) + action_log_probs = action_log_probs.view(-1, num_samples) + + return action, action_log_probs + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor: + assert "graph" in kwargs, f"graph is required to given in kwargs" + graph = kwargs["graph"] + subg_mask, node_idxs, subg_node_idxs, subg, h, num_nodes, num_samples = get_masks_idxs_subgraph_h(states, graph) + + # compute logits to get action + logits = self._actor(h, subg, mask=subg_mask).view(-1, self.action_num).index_select(0, subg_node_idxs) + + try: + m = Categorical(logits=logits) + except Exception: + print(f"[GraphBasedPolicyNet] flatten_subgraph_node_idxs with shape {subg_node_idxs.shape}") + print(f"[GraphBasedPolicyNet] logits with shape {logits.shape}") + return None + + # compute log probability of actions per node + actions = actions.reshape(-1) + action_log_probs = torch.zeros(num_nodes * num_samples, device=self._device) + action_log_probs[node_idxs] = m.log_prob(actions.index_select(0, node_idxs)) + action_log_probs = action_log_probs.view(-1, num_samples) + + return action_log_probs + + def get_actions(self, states: torch.Tensor, exploring: bool, **kwargs) -> torch.Tensor: + actions = self._get_actions_impl(states, exploring, **kwargs) + return actions + + def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor: + logps = self._get_states_actions_logps_impl(states, actions, 
**kwargs) + return logps + + +class GraphBasedVNet(VNet): + def __init__(self, state_dim: int, hidden_dim: int, num_layers: int, init_lr: float, norm_base: float) -> None: + """A value net implemented with a graph as input. + + Args: + state_dim (int): The dimension of the input state for this value network. + hidden_dim (int): The dimension of the hidden layers used in the GraphConvNet of the critic. + num_layers (int): The number of layers of the GraphConvNet of the critic. + init_lr (float): The initial learning rate of the optimizer. + norm_base (float): The normalization base for the predicted value. The critic network will predict the value + of each node, and the returned v value is defined as `Sum(predicted_node_values) / normalization_base`. + """ + super(GraphBasedVNet, self).__init__(state_dim) + self._critic = GraphConvNet( + input_dim=state_dim, + output_dim=1, + hidden_dim=hidden_dim, + num_layers=num_layers, + ) + self._optim = Adam(self._critic.parameters(), lr=init_lr) + self._normalization_base = norm_base + + def _get_v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor: + assert "graph" in kwargs, f"graph is required to given in kwargs" + graph = kwargs["graph"] + subg_mask, node_idxs, subg_node_idxs, subg, h, num_nodes, num_samples = get_masks_idxs_subgraph_h(states, graph) + + values = self._critic(h, subg, mask=subg_mask).view(-1).index_select(0, subg_node_idxs) + # Init node value prediction, shape (num_nodes * num_samples) + node_value_preds = torch.zeros(num_nodes * num_samples, device=self._device) + node_value_preds[node_idxs] = values + + graph.ndata["h"] = node_value_preds.view(-1, num_samples) + value_pred = dgl.sum_nodes(graph, "h") / self._normalization_base + graph.ndata.pop("h") + + return value_pred + + def v_values(self, states: torch.Tensor, **kwargs) -> torch.Tensor: + v = self._get_v_values(states, **kwargs) + return v diff --git a/examples/mis/lwd/ppo/ppo.py b/examples/mis/lwd/ppo/ppo.py new file mode 100644 index 000000000..673fcf79a --- /dev/null +++ b/examples/mis/lwd/ppo/ppo.py @@ -0,0 +1,186 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
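+"""Graph-based PPO components for the LwD MIS example.
+
+Implements the train ops (a clipped actor loss where per-vertex log-probability ratios are
+clamped and then aggregated per graph, plus a squared-error critic loss), a trainer built on
+the batch-graph replay memory with optional CSV logging, and the graph-based PPO policy.
+"""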
+ +import math +from typing import List, Tuple, cast + +import dgl +import torch + +from maro.rl.model import DiscretePolicyNet +from maro.rl.policy import DiscretePolicyGradient, RLPolicy +from maro.rl.training.algorithms.base import ACBasedOps, ACBasedParams, ACBasedTrainer +from maro.rl.utils import ndarray_to_tensor +from maro.utils import LogFormat, Logger + +from examples.mis.lwd.ppo.replay_memory import BatchGraphReplayMemory, GraphBasedExpElement, GraphBasedTransitionBatch + + +class GraphBasedPPOTrainOps(ACBasedOps): + def __init__( + self, + name: str, + policy: RLPolicy, + params: ACBasedParams, + reward_discount: float, + parallelism: int = 1, + ) -> None: + super().__init__(name, policy, params, reward_discount, parallelism) + + self._clip_lower_bound = math.log(1.0 - self._clip_ratio) + self._clip_upper_bound = math.log(1.0 + self._clip_ratio) + + def _get_critic_loss(self, batch: GraphBasedTransitionBatch) -> torch.Tensor: + states = ndarray_to_tensor(batch.states, device=self._device).permute(1, 0, 2) + returns = ndarray_to_tensor(batch.returns, device=self._device) + + self._v_critic_net.train() + kwargs = {"graph": batch.graph} + value_preds = self._v_critic_net.v_values(states, **kwargs).permute(1, 0) + critic_loss = 0.5 * (value_preds - returns).pow(2).mean() + return critic_loss + + def _get_actor_loss(self, batch: GraphBasedTransitionBatch) -> Tuple[torch.Tensor, bool]: + graph = batch.graph + kwargs = {"graph": batch.graph} + + states = ndarray_to_tensor(batch.states, device=self._device).permute(1, 0, 2) + actions = ndarray_to_tensor(batch.actions, device=self._device).permute(1, 0) + advantages = ndarray_to_tensor(batch.advantages, device=self._device) + logps_old = ndarray_to_tensor(batch.old_logps, device=self._device) + if self._is_discrete_action: + actions = actions.long() + + self._policy.train() + logps = self._policy.get_states_actions_logps(states, actions, **kwargs).permute(1, 0) + diff = logps - logps_old + clamped_diff = torch.clamp(diff, self._clip_lower_bound, self._clip_upper_bound) + stacked_diff = torch.stack([diff, clamped_diff], dim=2) + + graph.ndata["h"] = stacked_diff.permute(1, 0, 2) + h = dgl.sum_nodes(graph, "h").permute(1, 0, 2) + graph.ndata.pop("h") + ratio = torch.exp(h.select(2, 0)) + clamped_ratio = torch.exp(h.select(2, 1)) + + actor_loss = -torch.min(ratio * advantages, clamped_ratio * advantages).mean() + return actor_loss, False # TODO: add early-stop logic if needed + + +class GraphBasedPPOTrainer(ACBasedTrainer): + def __init__( + self, + name: str, + params: ACBasedParams, + replay_memory_capacity: int = 32, + batch_size: int = 16, + data_parallelism: int = 1, + reward_discount: float = 0.9, + graph_batch_size: int = 4, + graph_num_samples: int = 2, + input_feature_size: int = 2, + num_train_epochs: int = 4, + log_dir: str = None, + ) -> None: + super().__init__(name, params, replay_memory_capacity, batch_size, data_parallelism, reward_discount) + + self._graph_batch_size = graph_batch_size + self._graph_num_samples = graph_num_samples + self._input_feature_size = input_feature_size + self._num_train_epochs = num_train_epochs + + self._trainer_logger = None + if log_dir is not None: + self._trainer_logger = Logger( + tag=self.name, + format_=LogFormat.none, + dump_folder=log_dir, + dump_mode="w", + extension_name="csv", + auto_timestamp=False, + stdout_level="INFO", + ) + self._trainer_logger.debug( + "Steps,Mean Reward,Mean Return,Mean Advantage,0-Action,1-Action,2-Action,Critic Loss,Actor Loss" + ) + + def build(self) -> 
None: + self._ops = cast(GraphBasedPPOTrainOps, self.get_ops()) + self._replay_memory = BatchGraphReplayMemory( + max_t=self._replay_memory_capacity, + graph_batch_size=self._graph_batch_size, + num_samples=self._graph_num_samples, + feature_size=self._input_feature_size, + ) + + def record_multiple(self, env_idx: int, exp_elements: List[GraphBasedExpElement]) -> None: + self._replay_memory.reset() + + self._ops._v_critic_net.eval() + self._ops._policy.eval() + + for exp in exp_elements: + state = ndarray_to_tensor(exp.state, self._ops._device) + action = ndarray_to_tensor(exp.action, self._ops._device) + value_pred = self._ops._v_critic_net.v_values(state, graph=exp.graph).cpu().detach().numpy() + logps = self._ops._policy.get_states_actions_logps(state, action, graph=exp.graph).cpu().detach().numpy() + self._replay_memory.add_transition(exp, value_pred, logps) + + self._ops._v_critic_net.train() + self._ops._policy.train() + + def get_local_ops(self) -> GraphBasedPPOTrainOps: + return GraphBasedPPOTrainOps( + name=self._policy.name, + policy=self._policy, + parallelism=self._data_parallelism, + reward_discount=self._reward_discount, + params=self._params, + ) + + def train_step(self) -> None: + assert isinstance(self._ops, GraphBasedPPOTrainOps) + + data_loader = self._replay_memory.build_update_sampler( + batch_size=self._batch_size, + num_train_epochs=self._num_train_epochs, + gamma=self._reward_discount, + ) + + statistics = self._replay_memory.get_statistics() + + avg_critic_loss, avg_actor_loss = 0, 0 + for batch in data_loader: + for _ in range(self._params.grad_iters): + critic_loss = self._ops.update_critic(batch) + actor_loss, _ = self._ops.update_actor(batch) + avg_critic_loss += critic_loss + avg_actor_loss += actor_loss + avg_critic_loss /= self._params.grad_iters + avg_actor_loss /= self._params.grad_iters + + if self._trainer_logger is not None: + self._trainer_logger.debug( + f"{statistics['step_t']}," + f"{statistics['reward']}," + f"{statistics['return']}," + f"{statistics['advantage']}," + f"{statistics['action_0']}," + f"{statistics['action_1']}," + f"{statistics['action_2']}," + f"{avg_critic_loss}," + f"{avg_actor_loss}" + ) + + +class GraphBasedPPOPolicy(DiscretePolicyGradient): + def __init__(self, name: str, policy_net: DiscretePolicyNet, trainable: bool = True, warmup: int = 0) -> None: + super(GraphBasedPPOPolicy, self).__init__(name, policy_net, trainable, warmup) + + def get_actions_tensor(self, states: torch.Tensor, **kwargs) -> torch.Tensor: + actions = self._get_actions_impl(states, **kwargs) + return actions + + def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor, **kwargs) -> torch.Tensor: + logps = self._get_states_actions_logps_impl(states, actions, **kwargs) + return logps diff --git a/examples/mis/lwd/ppo/replay_memory.py b/examples/mis/lwd/ppo/replay_memory.py new file mode 100644 index 000000000..fc6f2a185 --- /dev/null +++ b/examples/mis/lwd/ppo/replay_memory.py @@ -0,0 +1,150 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
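+"""Replay memory for graph-based PPO in the LwD MIS example.
+
+Stores one rollout episode of batched-graph transitions, computes discounted returns and
+normalized advantages, and yields shuffled mini-batches for the PPO update, together with
+simple action / reward statistics for logging.
+"""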
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict + +import dgl +import numpy as np +from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler + + +@dataclass +class GraphBasedExpElement: + state: np.ndarray + action: np.ndarray + reward: np.ndarray + is_done: np.ndarray + graph: dgl.DGLGraph + + def split_contents_by_trainer(self, agent2trainer: Dict[Any, str]) -> Dict[str, GraphBasedExpElement]: + """Split the ExpElement's contents by trainer. + + Args: + agent2trainer (Dict[Any, str]): Mapping of agent name and trainer name. + + Returns: + Contents (Dict[str, ExpElement]): A dict that contains the ExpElements of all trainers. The key of this + dict is the trainer name. + """ + return {trainer_name: self for trainer_name in agent2trainer.values()} + + +@dataclass +class GraphBasedTransitionBatch: + states: np.ndarray + actions: np.ndarray + returns: np.ndarray + graph: dgl.DGLGraph + advantages: np.ndarray + old_logps: np.ndarray + + +class BatchGraphReplayMemory: + def __init__( + self, + max_t: int, + graph_batch_size: int, + num_samples: int, + feature_size: int, + ) -> None: + self.max_t = max_t + self.graph_batch_size: int = graph_batch_size + self.num_nodes: int = None + self.num_samples: int = num_samples + self.feature_size: int = feature_size + self._t = 0 + + self.graph: dgl.DGLGraph = None + self.states: np.ndarray = None # shape [max_t + 1, num_nodes, num_samples, feature_size] + self.actions: np.ndarray = None # shape [max_t, num_nodes, num_samples] + self.action_logps: np.ndarray = None + + array_size = (self.max_t + 1, graph_batch_size, num_samples) + self.rewards: np.ndarray = np.zeros(array_size, dtype=np.float32) + self.is_done: np.ndarray = np.ones(array_size, dtype=np.int8) + self.returns: np.ndarray = np.zeros(array_size, dtype=np.float32) + self.advantages: np.ndarray = np.zeros(array_size, dtype=np.float32) + self.value_preds: np.ndarray = np.zeros(array_size, dtype=np.float32) + + def _init_storage(self, graph: dgl.DGLGraph) -> None: + self.graph = graph + self.num_nodes = graph.num_nodes() + + array_size = (self.max_t + 1, self.num_nodes, self.num_samples) + self.states = np.zeros((*array_size, self.feature_size), dtype=np.float32) + self.actions = np.zeros(array_size, dtype=np.float32) + self.action_logps = np.zeros(array_size, dtype=np.float32) + + def add_transition(self, exp_element: GraphBasedExpElement, value_pred: np.ndarray, logps: np.ndarray) -> None: + if self._t == 0: + assert exp_element.graph is not None + self._init_storage(exp_element.graph) + + self.states[self._t] = exp_element.state + self.actions[self._t] = exp_element.action + self.rewards[self._t] = exp_element.reward + self.value_preds[self._t] = value_pred + self.action_logps[self._t] = logps + self.is_done[self._t + 1] = exp_element.is_done + + self._t += 1 + + def build_update_sampler(self, batch_size: int, num_train_epochs: int, gamma: float) -> GraphBasedTransitionBatch: + for t in reversed(range(self.max_t)): + self.returns[t] = self.rewards[t] + gamma * (1 - self.is_done[t + 1]) * self.returns[t + 1] + + advantages = self.returns[:-1] - self.value_preds[:-1] + self.advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5) + + flat_states = np.transpose(self.states[: self._t], (0, 2, 1, 3)).reshape(-1, self.num_nodes, self.feature_size) + flat_actions = np.transpose(self.actions[: self._t], (0, 2, 1)).reshape(-1, self.num_nodes) + flat_returns = np.transpose(self.returns[: self._t], (0, 2, 
1)).reshape(-1, self.graph_batch_size) + flat_advantages = np.transpose(self.advantages[: self._t], (0, 2, 1)).reshape(-1, self.graph_batch_size) + flat_action_logps = np.transpose(self.action_logps[: self._t], (0, 2, 1)).reshape(-1, self.num_nodes) + + flat_dim = flat_states.shape[0] + + sampler = BatchSampler( + sampler=SubsetRandomSampler(range(flat_dim)), + batch_size=min(flat_dim, batch_size), + drop_last=False, + ) + + sampler_t = 0 + while sampler_t < num_train_epochs: + for idx in sampler: + yield GraphBasedTransitionBatch( + states=flat_states[idx], + actions=flat_actions[idx], + returns=flat_returns[idx], + graph=self.graph, + advantages=flat_advantages[idx], + old_logps=flat_action_logps[idx], + ) + + sampler_t += 1 + if sampler_t == num_train_epochs: + break + + def reset(self) -> None: + self.is_done[0] = 0 + self._t = 0 + + def get_statistics(self) -> Dict[str, float]: + action_0_count = np.count_nonzero(self.actions[: self._t] == 0) + action_1_count = np.count_nonzero(self.actions[: self._t] == 1) + action_2_count = np.count_nonzero(self.actions[: self._t] == 2) + action_count = action_0_count + action_1_count + action_2_count + assert action_count == self.actions[: self._t].size + return { + "step_t": self._t, + "action_0": action_0_count / action_count, + "action_1": action_1_count / action_count, + "action_2": action_2_count / action_count, + "reward": self.rewards[: self._t].mean(), + "return": self.returns[: self._t].mean(), + "advantage": self.advantages[: self._t].mean(), + } diff --git a/examples/mis/lwd/simulator/__init__.py b/examples/mis/lwd/simulator/__init__.py new file mode 100644 index 000000000..ac949bd50 --- /dev/null +++ b/examples/mis/lwd/simulator/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from examples.mis.lwd.simulator.common import Action, MISDecisionPayload, MISEnvMetrics, VertexState +from examples.mis.lwd.simulator.mis_business_engine import MISBusinessEngine + + +__all__ = ["Action", "MISBusinessEngine", "MISDecisionPayload", "MISEnvMetrics", "VertexState"] diff --git a/examples/mis/lwd/simulator/common.py b/examples/mis/lwd/simulator/common.py new file mode 100644 index 000000000..175a4f63b --- /dev/null +++ b/examples/mis/lwd/simulator/common.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from enum import Enum, IntEnum + +import dgl +import torch + +from maro.common import BaseAction, BaseDecisionEvent + + +class VertexState(IntEnum): + Excluded = 0 + Included = 1 + Deferred = 2 + + +class MISEnvMetrics(Enum): + IncludedNodeCount = "Included Node Count" + HammingDistanceAmongSamples = "Hamming Distance Among Two Samples" + IsDoneMasks = "Is Done Masks" + + +class Action(BaseAction): + def __init__(self, vertex_states: torch.Tensor) -> None: + self.vertex_states = vertex_states + + +class MISDecisionPayload(BaseDecisionEvent): + def __init__(self, graph: dgl.DGLGraph, vertex_states: torch.Tensor) -> None: + self.graph = graph + self.vertex_states = vertex_states diff --git a/examples/mis/lwd/simulator/mis_business_engine.py b/examples/mis/lwd/simulator/mis_business_engine.py new file mode 100644 index 000000000..0a20d1b0b --- /dev/null +++ b/examples/mis/lwd/simulator/mis_business_engine.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
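+"""MARO business engine for the Maximum Independent Set (MIS) scenario.
+
+Maintains a batch of randomly generated Erdos-Renyi graphs with a vertex-state tensor per
+sample, issues a decision event at every tick, applies the received actions through the
+update and clean-up phases of the deferred MDP, and records the included-node counts and
+the Hamming distance between coupled samples used for reward computation.
+"""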
+ +import math +import random +from typing import Dict, List, Optional, Tuple + +import dgl +import dgl.function as fn +import torch + +from maro.backends.frame import FrameBase, SnapshotList +from maro.event_buffer import CascadeEvent, EventBuffer, MaroEvents +from maro.simulator.scenarios import AbsBusinessEngine + +from examples.mis.lwd.simulator.common import Action, MISDecisionPayload, MISEnvMetrics, VertexState + + +class MISBusinessEngine(AbsBusinessEngine): + def __init__( + self, + event_buffer: EventBuffer, + topology: Optional[str], + start_tick: int, + max_tick: int, + snapshot_resolution: int, + max_snapshots: Optional[int], + additional_options: dict = None, + ) -> None: + super(MISBusinessEngine, self).__init__( + scenario_name="MaximumIndependentSet", + event_buffer=event_buffer, + topology=topology, + start_tick=start_tick, + max_tick=max_tick, + snapshot_resolution=snapshot_resolution, + max_snapshots=max_snapshots, + additional_options=additional_options, + ) + self._config = self._parse_additional_options(additional_options) + + # NOTE: not used now + self._frame: FrameBase = FrameBase() + self._snapshots = self._frame.snapshots + + self._register_events() + + self._batch_adjs: List[Dict[int, List[int]]] = [] + self._batch_graphs: dgl.DGLGraph = None + self._vertex_states: torch.Tensor = None + self._is_done_mask: torch.Tensor = None + + self._num_included_record: Dict[int, torch.Tensor] = None + self._hamming_distance_among_samples_record: Dict[int, torch.Tensor] = None + + self.reset() + + def _parse_additional_options(self, additional_options: dict) -> dict: + required_keys = [ + "graph_batch_size", + "num_samples", + "device", + "num_node_lower_bound", + "num_node_upper_bound", + "node_sample_probability", + ] + for key in required_keys: + assert key in additional_options, f"Parameter {key} is required in additional_options!" + + self._graph_batch_size = additional_options["graph_batch_size"] + self._num_samples = additional_options["num_samples"] + self._device = additional_options["device"] + self._num_node_lower_bound = additional_options["num_node_lower_bound"] + self._num_node_upper_bound = additional_options["num_node_upper_bound"] + self._node_sample_probability = additional_options["node_sample_probability"] + + return {key: additional_options[key] for key in required_keys} + + @property + def configs(self) -> dict: + return self._config + + @property + def frame(self) -> FrameBase: + return self._frame + + @property + def snapshots(self) -> SnapshotList: + return self._snapshots + + def _register_events(self) -> None: + self._event_buffer.register_event_handler(MaroEvents.TAKE_ACTION, self._on_action_received) + + def get_agent_idx_list(self) -> List[int]: + return [0] + + def set_seed(self, seed: int) -> None: + pass + + def _calculate_and_record_metrics(self, tick: int, undecided_before_mask: torch.Tensor) -> None: + # Calculate and record the number of included vertexes for cardinality reward. + included_mask = self._vertex_states == VertexState.Included + self._batch_graphs.ndata["h"] = included_mask.float() + node_count = dgl.sum_nodes(self._batch_graphs, "h") + self._batch_graphs.ndata.pop("h") + self._num_included_record[tick] = node_count + + # Calculate and record the diversity for diversity reward. 
+        if self._num_samples == 2:
+            undecided_before_mask_left, undecided_before_mask_right = undecided_before_mask.split(1, dim=1)
+
+            states_left, states_right = self._vertex_states.split(1, dim=1)
+            undecided_mask_left = states_left == VertexState.Deferred
+            undecided_mask_right = states_right == VertexState.Deferred
+
+            hamming_distance = torch.abs(states_left.float() - states_right.float())
+            hamming_distance[undecided_mask_left | undecided_mask_right] = 0.0
+            hamming_distance[~undecided_before_mask_left & ~undecided_before_mask_right] = 0.0
+
+            self._batch_graphs.ndata["h"] = hamming_distance
+            distance = dgl.sum_nodes(self._batch_graphs, "h").expand_as(node_count)
+            self._batch_graphs.ndata.pop("h")
+            self._hamming_distance_among_samples_record[tick] = distance
+
+        return
+
+    def _on_action_received(self, event: CascadeEvent) -> None:
+        actions = event.payload
+        assert isinstance(actions, List)
+
+        undecided_before_mask = self._vertex_states == VertexState.Deferred
+
+        # Update Phase: apply the agent's new assignments to the deferred vertices only.
+        for action in actions:
+            assert isinstance(action, Action)
+            undecided_mask = self._vertex_states == VertexState.Deferred
+            self._vertex_states[undecided_mask] = action.vertex_states[undecided_mask]
+
+        # Clean-Up Phase (1): map clashed node pairs back to Deferred.
+        included_mask = self._vertex_states == VertexState.Included
+        self._batch_graphs.ndata["h"] = included_mask.float()
+        self._batch_graphs.update_all(fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h"))
+        neighbor_included_mask = self._batch_graphs.ndata.pop("h").bool()
+
+        # A clash occurs when a node and one of its neighbors are both set to Included.
+        clashed_mask = included_mask & neighbor_included_mask
+        self._vertex_states[clashed_mask] = VertexState.Deferred
+        neighbor_included_mask[clashed_mask] = False
+
+        # Clean-Up Phase (2): exclude any deferred vertex adjacent to an included one.
+        undecided_mask = self._vertex_states == VertexState.Deferred
+        self._vertex_states[undecided_mask & neighbor_included_mask] = VertexState.Excluded
+
+        # Timeout handling: exclude all vertices that are still deferred at the last tick.
+        if event.tick + 1 == self._max_tick:
+            undecided_mask = self._vertex_states == VertexState.Deferred
+            self._vertex_states[undecided_mask] = VertexState.Excluded
+
+        self._calculate_and_record_metrics(event.tick, undecided_before_mask)
+        self._update_is_done_mask()
+
+        return
+
+    def step(self, tick: int) -> None:
+        decision_payload = MISDecisionPayload(
+            graph=self._batch_graphs,
+            vertex_states=self._vertex_states.clone(),
+        )
+        decision_event = self._event_buffer.gen_decision_event(tick, decision_payload)
+        self._event_buffer.insert_event(decision_event)
+
+    def _generate_er_graph(self) -> Tuple[dgl.DGLGraph, Dict[int, List[int]]]:
+        num_nodes = random.randint(self._num_node_lower_bound, self._num_node_upper_bound)
+        adj = {node: [] for node in range(num_nodes)}
+
+        w = -1
+        lp = math.log(1.0 - self._node_sample_probability)
+
+        # Node indices range from 0 to num_nodes - 1; v is the current node and w the candidate neighbor (w < v).
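+        # Candidate edges are enumerated implicitly: the gap between two sampled edges follows a geometric
+        # distribution with parameter node_sample_probability, which is equivalent to testing every pair
+        # independently as in the G(n, p) random-graph model, but runs in time proportional to the edge count.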
+        v = 1
+        u_list, v_list = [], []
+        while v < num_nodes:
+            lr = math.log(1.0 - random.random())
+            w = w + 1 + int(lr / lp)
+            while w >= v and v < num_nodes:
+                w = w - v
+                v = v + 1
+            if v < num_nodes:
+                u_list.extend([v, w])
+                v_list.extend([w, v])
+                adj[v].append(w)
+                adj[w].append(v)
+
+        graph = dgl.graph((u_list, v_list), num_nodes=num_nodes)
+
+        return graph, adj
+
+    def _reset_batch_graph(self) -> None:
+        graph_list = []
+        self._batch_adjs = []
+        for _ in range(self._graph_batch_size):
+            graph, adj = self._generate_er_graph()
+            graph_list.append(graph)
+            self._batch_adjs.append(adj)
+
+        self._batch_graphs = dgl.batch(graph_list)
+        self._batch_graphs.set_n_initializer(dgl.init.zero_initializer)
+        self._batch_graphs = self._batch_graphs.to(self._device)
+        return
+
+    def reset(self, keep_seed: bool = False) -> None:
+        self._reset_batch_graph()
+
+        tensor_size = (self._batch_graphs.num_nodes(), self._num_samples)
+        self._vertex_states = torch.full(size=tensor_size, fill_value=VertexState.Deferred, device=self._device)
+        self._update_is_done_mask()
+
+        self._num_included_record = {}
+        self._hamming_distance_among_samples_record = {}
+        return
+
+    def _update_is_done_mask(self) -> None:
+        undecided_mask = self._vertex_states == VertexState.Deferred
+        self._batch_graphs.ndata["h"] = undecided_mask.float()
+        num_undecided = dgl.sum_nodes(self._batch_graphs, "h")
+        self._batch_graphs.ndata.pop("h")
+        self._is_done_mask = (num_undecided == 0)
+        return
+
+    def post_step(self, tick: int) -> bool:
+        if tick + 1 == self._max_tick:
+            return True
+
+        return torch.all(self._is_done_mask).item()
+
+    def get_metrics(self) -> dict:
+        return {
+            MISEnvMetrics.IncludedNodeCount: self._num_included_record,
+            MISEnvMetrics.HammingDistanceAmongSamples: self._hamming_distance_among_samples_record,
+            MISEnvMetrics.IsDoneMasks: self._is_done_mask.cpu().detach().numpy(),
+        }
+
+
+if __name__ == "__main__":
+    from maro.simulator import Env
+    device = torch.device("cuda:0")
+
+    env = Env(
+        business_engine_cls=MISBusinessEngine,
+        options={
+            "graph_batch_size": 4,
+            "num_samples": 2,
+            "device": device,
+            "num_node_lower_bound": 15,
+            "num_node_upper_bound": 20,
+            "node_sample_probability": 0.15,
+        },
+    )
+
+    env.reset()
+    metrics, decision_event, done = env.step(None)
+    # Random rollout policy: try to include each still-deferred vertex with probability 0.8 at every tick.
+    while not done:
+        assert isinstance(decision_event, MISDecisionPayload)
+        vertex_state = decision_event.vertex_states
+        undecided_mask = vertex_state == VertexState.Deferred
+        random_mask = torch.rand(vertex_state.size(), device=device) < 0.8
+        vertex_state[undecided_mask & random_mask] = VertexState.Included
+        action = Action(vertex_state)
+        metrics, decision_event, done = env.step(action)
+
+    # Print the per-graph included-node counts recorded at the last tick, and the final done masks.
+    for key in [MISEnvMetrics.IncludedNodeCount]:
+        print(f"[{env.tick - 1:02d}] {key:28s} {metrics[key][env.tick - 1].reshape(-1)}")
+    for key in [MISEnvMetrics.IsDoneMasks]:
+        print(f"[{env.tick - 1:02d}] {key:28s} {metrics[key].reshape(-1)}")