🎉 SocialJax has been accepted at ICLR 2026!
A suite of sequential social dilemma environments for multi-agent reinforcement learning in JAX
Common Rewards: a scenario where all agents share a single, unified reward signal. This aligns all agents toward the same objective, promoting collaboration and coordination among them.
Individual Rewards: each agent is assigned its own reward, inherently encouraging selfish behavior.
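As a minimal plain-Python sketch (not SocialJax's JAX code), the two modes differ only in how per-step rewards are handed back to the agents. It assumes, in line with the IPPO reward-mode description, that common mode gives every agent the sum of all agents' rewards; the function names are illustrative.

```python
from typing import List


def common_rewards(individual: List[float]) -> List[float]:
    """Common mode: every agent receives the sum of all agents' rewards."""
    total = sum(individual)
    return [total] * len(individual)


def individual_rewards(individual: List[float]) -> List[float]:
    """Individual mode: each agent keeps only its own reward."""
    return list(individual)


# Agent 0 earned 1.0 this step, agent 1 nothing, agent 2 lost 0.5.
step_rewards = [1.0, 0.0, -0.5]
print(common_rewards(step_rewards))      # [0.5, 0.5, 0.5]
print(individual_rewards(step_rewards))  # [1.0, 0.0, -0.5]
```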
SocialJax leverages JAX's high-performance GPU capabilities to accelerate multi-agent reinforcement learning in sequential social dilemmas. Our aim is to provide an efficient and diverse suite of environments for studying social dilemmas. SocialJax includes JAX implementations of environments such as Coins, Commons Harvest: Open, Commons Harvest: Closed, Clean Up, Territory, and Coop Mining, which are derived from Melting Pot 2.0 and feature commonly studied mixed incentives.
Our blog presents more details and analysis of the agents' policies and performance.
First: Clone the repository

```bash
git clone https://github.com/cooperativex/SocialJax.git
cd SocialJax
```

Second: Environment Setup.
Option one: Poetry (make sure you have Python 3.10)
- Install Poetry

```bash
curl -sSL https://install.python-poetry.org | python3 -
export PATH="$HOME/.local/bin:$PATH"
```

- Install requirements

```bash
poetry install --no-root
export PYTHONPATH=$PWD:$PYTHONPATH
```

- Run code

```bash
poetry run python algorithms/train.py --algo IPPO --env coins
```
Option two: conda with requirements.txt
- Conda

```bash
conda create -n SocialJax python=3.10
conda activate SocialJax
```

- Install requirements

```bash
pip install -r requirements.txt
export PYTHONPATH=$PWD:$PYTHONPATH
```

- Run code

```bash
python algorithms/train.py --algo IPPO --env coins
```
Option three: conda with environment.yml
- Install requirements

```bash
conda env create -f environment.yml
export PYTHONPATH=$PWD:$PYTHONPATH
```

- Run code

```bash
python algorithms/train.py --algo IPPO --env coins
```
All training is launched through a single entry point that dispatches on --algo and --env:

```bash
python algorithms/train.py --algo <ALGO> --env <ENV> [HYDRA_OVERRIDES...]
```

--algo selects the algorithm family; --env selects the per-env config (resolving to algorithms/<ALGO>/config/<algo>_cnn_<env>.yaml). Anything after these two flags is forwarded verbatim to Hydra as key=value overrides.
| --algo | Description |
|---|---|
| IPPO | Independent PPO |
| SVO | Social Value Orientation (PPO with SVO reward shaping) |
| MAPPO | Multi-Agent PPO (centralised critic) |
| TRANSFER | Self-interest reward exchange |
| VDN | Value Decomposition Networks (Q-learning) |
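For the value-based baseline, VDN's central assumption is that the team action-value decomposes additively, Q_tot = Σ_i Q_i(o_i, a_i). The sketch below illustrates only that decomposition in plain Python; the function names and Q-values are hypothetical, and the repository's VDN learns each Q_i with a JAX network rather than reading it from a table.

```python
from typing import List, Sequence, Tuple


def vdn_joint_q(per_agent_q: Sequence[Sequence[float]], actions: Sequence[int]) -> float:
    """Q_tot = sum_i Q_i(o_i, a_i): VDN's additive value decomposition."""
    return sum(q[a] for q, a in zip(per_agent_q, actions))


def greedy_joint_action(per_agent_q: Sequence[Sequence[float]]) -> Tuple[List[int], float]:
    """Each agent picks its own argmax; because Q_tot is a sum of independent
    terms, this joint action also maximises Q_tot."""
    actions = [max(range(len(q)), key=q.__getitem__) for q in per_agent_q]
    return actions, vdn_joint_q(per_agent_q, actions)


# Hypothetical per-agent Q-values (2 agents, 3 actions each) for one observation.
q_values = [[0.0, 2.0, 1.0], [1.5, 0.5, 3.0]]
print(greedy_joint_action(q_values))  # ([1, 2], 5.0)
print(vdn_joint_q(q_values, [0, 0]))  # 1.5
```

Because a sum is maximised by maximising each term separately, every agent can act greedily on its own Q_i at execution time, which is what makes this decomposition convenient for decentralised execution.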
Env names are unified across all algorithms (the per-env yamls live in
algorithms/<ALGO>/config/<algo>_cnn_<env>.yaml):
| Environment | --env |
|---|---|
| Coins | coins |
| Clean Up | cleanup |
| Coop Mining | coop_mining |
| Gift | gift |
| Mushrooms | mushrooms |
| Harvest: Open | harvest_open |
| Harvest: Closed | harvest_closed |
| Harvest: Partnership | harvest_partnership |
| PD Arena | pd_arena |
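A hypothetical sketch of the dispatch described above, assuming <ALGO> is the upper-case flag value and <algo> its lower-case form (the actual train.py may be organised differently). It uses argparse.parse_known_args so that any token the parser does not recognise, such as SEED=42, is passed through untouched for Hydra; parse_cli and config_path are illustrative names.

```python
import argparse
from typing import List, Tuple

# Algorithm and environment keys taken from the two tables above.
ALGOS = ["IPPO", "SVO", "MAPPO", "TRANSFER", "VDN"]
ENVS = ["coins", "cleanup", "coop_mining", "gift", "mushrooms",
        "harvest_open", "harvest_closed", "harvest_partnership", "pd_arena"]


def parse_cli(argv: List[str]) -> Tuple[str, str, List[str]]:
    """Split argv into (algo, env, hydra_overrides)."""
    parser = argparse.ArgumentParser(prog="train.py")
    parser.add_argument("--algo", required=True, choices=ALGOS)
    parser.add_argument("--env", required=True, choices=ENVS)
    # parse_known_args returns unrecognised tokens (e.g. SEED=42) as a list.
    args, overrides = parser.parse_known_args(argv)
    return args.algo, args.env, overrides


def config_path(algo: str, env: str) -> str:
    """Resolve algorithms/<ALGO>/config/<algo>_cnn_<env>.yaml."""
    return f"algorithms/{algo}/config/{algo.lower()}_cnn_{env}.yaml"


algo, env, overrides = parse_cli(["--algo", "IPPO", "--env", "coins", "SEED=42", "LR=1e-4"])
print(config_path(algo, env))  # algorithms/IPPO/config/ippo_cnn_coins.yaml
print(overrides)               # ['SEED=42', 'LR=1e-4']
```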
IPPO supports two reward modes: common (all agents share one summed reward) and individual (each agent gets only its own reward, the selfish baseline). Select the mode through the reward Hydra group; the checkpoint and wandb run names automatically get a _reward_<mode> suffix, so both variants can coexist.

```bash
python algorithms/train.py --algo IPPO --env coins reward=common
python algorithms/train.py --algo IPPO --env coins reward=individual
```

SVO trains on individual rewards but shapes them toward a target orientation. The shaping strength (svo_w) and ideal angle (svo_ideal_angle_degrees) live under ENV_KWARGS.
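The exact shaping formula is not given here. One widely used form, from McKee et al. (2020), compares a target SVO angle with the angle of the reward vector (own reward, mean reward of the other agents) and subtracts svo_w times the gap. The sketch below assumes SocialJax follows that form; it is plain Python for illustration, the function name is hypothetical, and it uses atan2 so that an own reward of zero is well defined (how the repository handles that edge case is an assumption).

```python
import math
from typing import List


def svo_shaped_rewards(rewards: List[float], svo_w: float,
                       svo_ideal_angle_degrees: float) -> List[float]:
    """Return r_i - svo_w * |ideal angle - reward angle| for every agent i.

    The reward angle is atan2(mean reward of the other agents, own reward).
    Requires at least two agents.
    """
    ideal = math.radians(svo_ideal_angle_degrees)
    n = len(rewards)
    total = sum(rewards)
    shaped = []
    for own in rewards:
        others_mean = (total - own) / (n - 1)
        # Roughly: 0 = only this agent was rewarded, pi/4 = equal positive
        # rewards, pi/2 = only the other agents were rewarded.
        angle = math.atan2(others_mean, own)
        shaped.append(own - svo_w * abs(ideal - angle))
    return shaped


# Equal rewards match a 45-degree target, so they are left (almost) unchanged.
print(svo_shaped_rewards([1.0, 1.0], svo_w=0.5, svo_ideal_angle_degrees=45))
# An unequal split is penalised for both agents.
print(svo_shaped_rewards([1.0, 0.0], svo_w=0.5, svo_ideal_angle_degrees=45))
```

Under this form, a larger svo_w makes deviations from the target angle more costly, and a 45-degree target favours equal outcomes.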
```bash
python algorithms/train.py --algo SVO --env coins
python algorithms/train.py --algo SVO --env coins ENV_KWARGS.svo_w=0.5 ENV_KWARGS.svo_ideal_angle_degrees=45
```

TRANSFER mixes individual rewards using a self-interest weight s_interest (fixed per env in transfer_cnn_<env>.yaml, optionally scheduled over training). It can be overridden inline via ENV_KWARGS.
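The mixing rule itself is not spelled out here. A common self-interest formulation, assumed in this sketch, keeps a fraction s_interest of an agent's own reward and gives it an equal share of the remaining 1 - s_interest of every other agent's reward. The function name is hypothetical, the code is plain Python rather than the repository's JAX implementation, and s_interest is held fixed (no schedule).

```python
from typing import List


def transfer_rewards(rewards: List[float], s_interest: float) -> List[float]:
    """r'_i = s * r_i + (1 - s) / (N - 1) * sum_{j != i} r_j  (requires N >= 2).

    s_interest = 1 recovers purely individual rewards; s_interest = 1 / N gives
    every agent the team mean. The team total is unchanged for any s_interest.
    """
    n = len(rewards)
    total = sum(rewards)
    return [s_interest * r + (1.0 - s_interest) * (total - r) / (n - 1)
            for r in rewards]


# Agent 0 earned the only reward this step.
print(transfer_rewards([1.0, 0.0, 0.0], s_interest=0.4))  # approximately [0.4, 0.3, 0.3]
```

Conserving the team total means this kind of exchange only redistributes reward between agents; it never creates or destroys it.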
```bash
python algorithms/train.py --algo TRANSFER --env coins
python algorithms/train.py --algo TRANSFER --env pd_arena ENV_KWARGS.s_interest=0.4
```

```bash
# Override hyperparameters
python algorithms/train.py --algo IPPO --env coins SEED=42 LR=1e-4 NUM_ENVS=128
# Multi-seed grid (Hydra multirun)
python algorithms/train.py --algo MAPPO --env cleanup -m SEED=42,52,62
# Override nested ENV_KWARGS
python algorithms/train.py --algo SVO --env coins ENV_KWARGS.svo_w=0.8
# VDN's hyperparameters live under an `alg.*` namespace
python algorithms/train.py --algo VDN --env coins alg.NUM_ENVS=32 alg.LR=1e-4
# Turn off wandb (useful for local smoke testing)
python algorithms/train.py --algo IPPO --env coins WANDB_MODE=disabled
```

We introduce the environments and use Schelling diagrams to show whether each environment is a social dilemma.
| Environment | Description | Schelling Diagram Proof |
|---|---|---|
| Coins | Link | ✓ |
| Commons Harvest: Open | Link | ✓ |
| Commons Harvest: Closed | Link | ✓ |
| Commons Harvest: Partnership | Link | ✓ |
| Clean Up | Link | ✓ |
| Territory | Link | ✗ |
| Coop Mining | Link | ✓ |
| Mushrooms | Link | ✓ |
| Gift Refinement | Link | ✓ |
| Prisoners Dilemma: Arena | Link | ✓ |
- Due to algorithmic limitations, agents may not always learn optimal policies. As a result, Schelling diagrams can show that an environment is a social dilemma, but they cannot definitively prove that an environment is not one.
- Territory might not be a social dilemma, but as long as the agents' behaviors are interesting, Territory still holds intrinsic value.
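A Schelling diagram plots the payoff of a cooperator, R_c(k), and of a defector, R_d(k), against the number k of other agents who cooperate. A minimal sketch of the check, assuming the commonly used conditions from Hughes et al. (2018): mutual cooperation beats mutual defection, mutual cooperation beats being exploited, and defecting pays more than cooperating either when few others cooperate (fear) or when many do (greed). As a simplification, fear and greed are tested only at the two endpoints; the function name and the payoffs in the usage lines are hypothetical.

```python
from typing import Sequence


def is_social_dilemma(r_coop: Sequence[float], r_defect: Sequence[float]) -> bool:
    """r_coop[k] / r_defect[k]: average payoff of a cooperator / defector when
    k of the other N - 1 agents cooperate (k = 0 .. N - 1)."""
    n_others = len(r_coop) - 1
    mutual_coop_beats_mutual_defect = r_coop[n_others] > r_defect[0]
    mutual_coop_beats_exploitation = r_coop[n_others] > r_coop[0]
    # Simplification: fear checked only at k = 0, greed only at k = N - 1.
    fear = r_defect[0] > r_coop[0]
    greed = r_defect[n_others] > r_coop[n_others]
    return (mutual_coop_beats_mutual_defect and mutual_coop_beats_exploitation
            and (fear or greed))


# Hypothetical 3-agent payoffs, indexed by the number of other cooperators.
print(is_social_dilemma([0.0, 1.0, 2.0], [0.5, 1.5, 2.5]))  # True
print(is_social_dilemma([1.0, 2.0, 3.0], [0.0, 1.0, 2.0]))  # False
```

In line with the first note above, a False result only says that these particular estimated payoffs do not exhibit a dilemma; it does not prove that the environment is not one.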
SocialJax's interface follows JaxMARL, which takes inspiration from PettingZoo and Gymnax.
You can create an environment using the make function:

```python
from socialjax import make

env = make('clean_up')
```

Find more fixed policy examples.
```python
import jax
from socialjax import make

num_agents = 7
env = make('clean_up', num_agents=num_agents)
num_actions = env.action_space(0).n

rng = jax.random.PRNGKey(259)
rng, reset_key = jax.random.split(rng)
obs, state = env.reset(reset_key)  # initial observations and environment state

for t in range(100):
    # Split off one key for the environment step and one key per agent.
    rng, step_key, *action_keys = jax.random.split(rng, num_agents + 2)
    # Fixed random policy: every agent samples an action uniformly.
    actions = [jax.random.choice(action_keys[a], num_actions) for a in range(num_agents)]
    # Pass the current state in and keep the returned state for the next step.
    obs, state, reward, done, info = env.step_env(step_key, state, actions)
```

You can test the speed of our environments by running speed_test_random.py or using the colab.
If you use SocialJax in your research, please cite:

```bibtex
@inproceedings{guo2025socialjax,
  title={{SocialJax}: An Evaluation Suite for Multi-agent Reinforcement Learning in Sequential Social Dilemmas},
  author={Guo, Zihao and Shi, Shuqing and Willis, Richard and Tomilin, Tristan and Leibo, Joel Z. and Du, Yali},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026},
}
```

JaxMARL: accelerated MARL environments with baselines in JAX.
PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.









