## Main scripts

- `algorithm/marl_ppo.py` for training multi-agent PPO on the target MPE environment.
  - Note: run this script as a python module, i.e. `python -m algorithm.marl_ppo`, for the imports to work properly.
- `envs/target_mpe_env.py`. This is the main class that defines the target MPE environment; see the usage sketch after this list.
  - Also look at `envs/wrapper.py` for the env wrappers.
- `config/mappo_config.py`. This is the one and only file for changing config values to run experiments. It uses python classes instead of a yaml file to get autocompletion, type checking, and easier refactoring when accessing or changing the structure of the config; a sketch of this pattern follows the usage sketch below.
- `visualize_actor.py` for visualizing the trained actor in a local environment.
- `model/actor_critic_rnn.py` has all the flax linen networks used in the PPO.
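As a rough orientation, interacting with the environment plausibly looks like the sketch below, assuming a JaxMARL-style functional interface. The class name `TargetMPEEnvironment`, the constructor argument, and the `reset`/`step`/`action_space` signatures are all assumptions, not the file's confirmed API.

```python
# Hypothetical usage sketch, assuming a JaxMARL-style functional env API;
# the actual names in envs/target_mpe_env.py may differ.
import jax

from envs.target_mpe_env import TargetMPEEnvironment  # class name is an assumption

env = TargetMPEEnvironment(num_agents=3)  # constructor args are assumptions

key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
obs, state = env.reset(reset_key)  # functional reset: returns obs dict and env state

for _ in range(10):
    key, act_key, step_key = jax.random.split(key, 3)
    # sample a random action per agent; action_space(agent) is an assumption
    actions = {agent: env.action_space(agent).sample(act_key) for agent in env.agents}
    obs, state, rewards, dones, infos = env.step(step_key, state, actions)
```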
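The config-as-python-classes pattern is sketched below. `WandbConfig` and its `mode` field are referenced elsewhere in this README; every other class and field name here is illustrative, not the actual schema of `config/mappo_config.py`.

```python
# Minimal sketch of the class-based config pattern. Only WandbConfig.mode is
# documented in this README; the other fields are hypothetical.
from dataclasses import dataclass, field


@dataclass
class WandbConfig:
    project: str = "JaxInforMARL"  # hypothetical field
    mode: str = "disabled"  # change to "online" to enable wandb logging


@dataclass
class MAPPOConfig:
    num_envs: int = 64  # hypothetical field
    lr: float = 3e-4  # hypothetical field
    wandb: WandbConfig = field(default_factory=WandbConfig)


config = MAPPOConfig()
# Unlike raw yaml dicts, attribute access gets IDE autocompletion and type
# checking, and renaming a field refactors every use site.
config.wandb.mode = "online"
```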
## Typical train and test flow

- Run the `train_with_gpu.ipynb` notebook in a colab with a gpu.
  - Remember to set up the config in `WandbConfig` in `config/mappo_config.py` and change the mode to `online` to get wandb logging.
  - The artifacts are saved under the name "PPO_RNN_Runner_State".
- Visualize the actor with `visualize_actor.py` after changing the `artifact_version` variable in the `if __name__ == "__main__"` block; see the sketch after this list.
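For orientation, a minimal sketch of what that `__main__` block might look like, assuming the saved runner state is fetched from wandb: only `artifact_version` and the "PPO_RNN_Runner_State" artifact name come from this README; the project name and the download/restore steps are assumptions about the script's internals.

```python
# Hypothetical sketch of visualize_actor.py's entry point; only artifact_version
# and the "PPO_RNN_Runner_State" artifact name are documented in this README.
import wandb

if __name__ == "__main__":
    artifact_version = "v0"  # set this to the artifact version you want to visualize

    run = wandb.init(project="JaxInforMARL", job_type="visualization")  # project name is an assumption
    artifact = run.use_artifact(f"PPO_RNN_Runner_State:{artifact_version}")
    checkpoint_dir = artifact.download()

    # ...restore the actor params from checkpoint_dir and roll out an episode locally...
```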
It is recommended to install either `requirements_jax_cpu.txt` or `requirements_jax_cuda.txt` before `requirements.txt` (e.g. `pip install -r requirements_jax_cuda.txt` followed by `pip install -r requirements.txt`), since the packages in `requirements.txt` would otherwise pull in their own jax version for you.
## Citation

If you use JaxInforMARL in your work, please cite it as follows:

```bibtex
@software{JaxInforMARL,
  title = {JaxInforMARL: Multi-Agent Target MPE RL Environments with GNNs in JAX},
  author = {Joseph Selvaraaj},
  year = {2025},
  url = {https://github.com/jselvaraaj/JaxInforMARL},
  version = {1.0.0}
}
```

