From 2d5ed96c9770b96ddc2df49f406e981ea75bdd6a Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Mon, 27 Jun 2022 16:02:32 +0100 Subject: [PATCH 01/29] begin adding centralized learning --- pax/independent_learners.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pax/independent_learners.py b/pax/independent_learners.py index 840d7cd3..3b584fb2 100644 --- a/pax/independent_learners.py +++ b/pax/independent_learners.py @@ -23,7 +23,8 @@ def update( actions: List[jnp.ndarray], timesteps: List[TimeStep], ) -> None: - # might have to add some centralised training to this + # TODO: Add centralized training for LOLA-esque algorithms + # (They require the parameters of the other agent at timestep i) for agent, t, action, t_1 in zip( self.agents, old_timesteps, actions, timesteps ): From ec3ce6abbf5b49a468085887bdf785134479ad02 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Tue, 5 Jul 2022 15:27:22 +0100 Subject: [PATCH 02/29] first commit. begin adding centralized training for LOLA --- pax/dqn/agent.py | 1 + pax/experiment.py | 1 + pax/independent_learners.py | 7 ++++--- pax/ppo/ppo.py | 8 +++++++- pax/ppo/ppo_gru.py | 8 +++++++- pax/runner.py | 2 -- 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/pax/dqn/agent.py b/pax/dqn/agent.py index a11e27e7..6fef9d63 100644 --- a/pax/dqn/agent.py +++ b/pax/dqn/agent.py @@ -183,6 +183,7 @@ def update( timestep: dm_env.TimeStep, action: jnp.array, new_timestep: dm_env.TimeStep, + other_agents=None, ): self._replay.add_batch( diff --git a/pax/experiment.py b/pax/experiment.py index 67262080..ce74cf60 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -9,6 +9,7 @@ BattleOfTheSexes, Chicken, ) +from pax.centralized_learners import CentralizedLearners from pax.independent_learners import IndependentLearners from pax.ppo.ppo import make_agent from pax.ppo.ppo_gru import make_gru_agent diff --git a/pax/independent_learners.py b/pax/independent_learners.py index 3b584fb2..91b24431 100644 --- a/pax/independent_learners.py +++ b/pax/independent_learners.py @@ -5,7 +5,9 @@ class IndependentLearners: - "Interface for a set of batched agents to work with environment" + """Interface for a set of batched agents to work with environment + Performs independent learning + """ def __init__(self, agents: list): self.num_agents: int = len(agents) @@ -23,8 +25,7 @@ def update( actions: List[jnp.ndarray], timesteps: List[TimeStep], ) -> None: - # TODO: Add centralized training for LOLA-esque algorithms - # (They require the parameters of the other agent at timestep i) + for agent, t, action, t_1 in zip( self.agents, old_timesteps, actions, timesteps ): diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index 7b39db07..d68efed6 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -413,7 +413,13 @@ def select_action(self, t: TimeStep): ) return utils.to_numpy(actions) - def update(self, t: TimeStep, actions: np.array, t_prime: TimeStep): + def update( + self, + t: TimeStep, + actions: np.array, + t_prime: TimeStep, + other_agents=None, + ): # Adds agent and environment info to buffer self._rollouts( buffer=self._trajectory_buffer, diff --git a/pax/ppo/ppo_gru.py b/pax/ppo/ppo_gru.py index 1d7eeada..6c655378 100644 --- a/pax/ppo/ppo_gru.py +++ b/pax/ppo/ppo_gru.py @@ -445,7 +445,13 @@ def select_action(self, t: TimeStep): ) return utils.to_numpy(actions) - def update(self, t: TimeStep, actions: np.array, t_prime: TimeStep): + def update( + self, + t: TimeStep, + actions: np.array, + t_prime: TimeStep, + other_agents=None, + ): # Adds agent and 
environment info to buffer self._rollouts( buffer=self._trajectory_buffer, diff --git a/pax/runner.py b/pax/runner.py index 0de025b0..5890974c 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -33,8 +33,6 @@ def train_loop(self, env, agents, num_episodes, watchers): rewards_0.append(r_0) rewards_1.append(r_1) - # train model - # agents.update(t, actions, infos, t_prime) agents.update(t, actions, t_prime) self.train_steps += 1 From 2664eead9901847f8cc537d52379754e3542c1f6 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Tue, 5 Jul 2022 16:31:09 +0100 Subject: [PATCH 03/29] add base lola --- pax/lola/__init__.py | 0 pax/lola/lola.py | 102 +++++++++++++++++++++++++++++++++++++++++++ pax/lola/network.py | 0 3 files changed, 102 insertions(+) create mode 100644 pax/lola/__init__.py create mode 100644 pax/lola/lola.py create mode 100644 pax/lola/network.py diff --git a/pax/lola/__init__.py b/pax/lola/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pax/lola/lola.py b/pax/lola/lola.py new file mode 100644 index 00000000..bfa0b450 --- /dev/null +++ b/pax/lola/lola.py @@ -0,0 +1,102 @@ +# Learning with Opponent-Learning Awareness (LOLA) implementation in JAX +# https://arxiv.org/pdf/1709.04326.pdf + +from typing import NamedTuple, Any + +from dm_env import TimeStep +import haiku as hk +import jax +import jax.numpy as jnp + + +class TrainingState(NamedTuple): + params: hk.Params + random_key: jnp.ndarray + + +class LOLA: + """Implements LOLA with exact value functions""" + + def __init__(self, network: hk.Params, random_key: jnp.ndarray): + def policy( + params: hk.Params, observation: jnp.ndarray, state: TrainingState + ): + """Determines how to choose an action""" + key, subkey = jax.random.split(state.random_key) + logits = network.apply(params, observation) + print("Logits from policy", logits) + print("Softmax of logits from policy", jax.nn.softmax(logits)) + actions = jax.random.categorical(subkey, logits) + state = state._replace(random_key=key) + return int(actions), state + + def loss(): + """Loss function""" + pass + + def sgd(): + """Stochastic gradient descent""" + pass + + def make_initial_state(key: jnp.ndarray) -> TrainingState: + """Make initial training state for LOLA""" + key, subkey = jax.random.split(key) + dummy_obs = jnp.zeros(shape=(1, 5)) + params = network.init(subkey, dummy_obs) + return TrainingState(params=params, random_key=key) + + self.state = make_initial_state(random_key) + self._policy = policy + + def select_action(self, t: TimeStep): + """Select action based on observation""" + # Unpack + params = self.state.params + state = self.state + action, self.state = self._policy(params, t.observation, state) + return action + + def update( + self, + t: TimeStep, + actions: jnp.ndarray, + t_prime: TimeStep, + other_agents=None, + ): + """Update agent""" + # an sgd step requires the parameters of the other agent. 
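
For context: the update being stubbed out here is the core of LOLA — each agent's gradient step differentiates through the other agent's own learning step, which is why it needs the other agent's parameters. A minimal JAX sketch of that idea with exact value functions, for reference only (this is not the PR's implementation; V1, V2, lr_in and lr_out are illustrative names):

import jax

def lola_step(params1, params2, V1, V2, lr_in=0.3, lr_out=0.2):
    # V1(params1, params2) and V2(params1, params2) are assumed to return each
    # agent's differentiable expected return (exact-gradient IPD setting).
    def lookahead_value(p1):
        # Agent 2's anticipated naive-learning step, kept as a function of p1
        # so the outer gradient flows through it (the LOLA correction term).
        g2 = jax.grad(V2, argnums=1)(p1, params2)
        p2_next = jax.tree_util.tree_map(lambda p, g: p + lr_in * g, params2, g2)
        return V1(p1, p2_next)

    g1 = jax.grad(lookahead_value)(params1)
    return jax.tree_util.tree_map(lambda p, g: p + lr_out * g, params1, g1)

The policy-gradient variant added later in this series replaces the exact values with estimates from sampled trajectories (the DiCE objective that appears, commented out, in a later commit).
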
+ # currently, the runner file doesn't have access to the other agent's gradients + # we could put the parameters of the agent inside the timestep + pass + + +def make_lola(seed: int) -> LOLA: + """ "Instantiate LOLA""" + random_key = jax.random.PRNGKey(seed) + + def forward(inputs): + """Forward pass for LOLA exact""" + values = hk.Linear(1, with_bias=False) + return values(inputs) + + network = hk.without_apply_rng(hk.transform(forward)) + + return LOLA(network=network, random_key=random_key) + + +if __name__ == "__main__": + lola = make_lola(seed=0) + print(f"LOLA state: {lola.state}") + timestep = TimeStep( + step_type=0, + reward=1, + discount=1, + observation=jnp.array([[1, 0, 0, 0, 0]]), + ) + action = lola.select_action(timestep) + print("Action", action) + timestep = TimeStep( + step_type=0, reward=1, discount=1, observation=jnp.zeros(shape=(1, 5)) + ) + action = lola.select_action(timestep) + print("Action", action) diff --git a/pax/lola/network.py b/pax/lola/network.py new file mode 100644 index 00000000..e69de29b From bcb4833f6a2d4a592e2a1b4d19bda5b534f29596 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Tue, 5 Jul 2022 16:39:58 +0100 Subject: [PATCH 04/29] add centralized learner --- pax/centralized_learners.py | 43 +++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 pax/centralized_learners.py diff --git a/pax/centralized_learners.py b/pax/centralized_learners.py new file mode 100644 index 00000000..67f5ed30 --- /dev/null +++ b/pax/centralized_learners.py @@ -0,0 +1,43 @@ +from typing import Callable, List + +from dm_env import TimeStep +import jax.numpy as jnp + + +class CentralizedLearners: + """Interface for a set of batched agents to work with environment + Performs centralized training""" + + def __init__(self, agents: list): + self.num_agents: int = len(agents) + self.agents: list = agents + + def select_action(self, timesteps: List[TimeStep]) -> List[jnp.ndarray]: + assert len(timesteps) == self.num_agents + return [ + agent.select_action(t) for agent, t in zip(self.agents, timesteps) + ] + + def update( + self, + old_timesteps: List[TimeStep], + actions: List[jnp.ndarray], + timesteps: List[TimeStep], + ) -> None: + counter = 0 + for agent, t, action, t_1 in zip( + self.agents, old_timesteps, actions, timesteps + ): + # All other agents in a list + # i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...] 
+ other_agents = self.agents[:counter] + self.agents[counter + 1 :] + agent.update(t, action, t_1, other_agents) + counter += 1 + + def log(self, metrics: List[Callable]) -> None: + for metric, agent in zip(metrics, self.agents): + metric(agent) + + def eval(self, set_flag: bool) -> None: + for agent in self.agents: + agent.eval = set_flag From 6c17d02e222e0e2f5f9608e042d47744fcf83649 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Tue, 5 Jul 2022 18:32:43 +0100 Subject: [PATCH 05/29] add lola machinery to experiments.py --- pax/experiment.py | 10 ++++++++++ pax/lola/lola.py | 4 +--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pax/experiment.py b/pax/experiment.py index 518aa514..34a7637d 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -13,6 +13,7 @@ from pax.independent_learners import IndependentLearners from pax.ppo.ppo import make_agent from pax.ppo.ppo_gru import make_gru_agent +from pax.lola.lola import make_lola from pax.runner import Runner from pax.sac.agent import SAC from pax.strategies import ( @@ -147,6 +148,10 @@ def get_PPO_agent(seed, player_id): ) return ppo_agent + def get_LOLA_agent(seed, player_id): + lola_agent = make_lola(seed) + return lola_agent + strategies = { "TitForTat": TitForTat, "Defect": Defect, @@ -158,6 +163,7 @@ def get_PPO_agent(seed, player_id): "SAC": get_SAC_agent, "DQN": get_DQN_agent, "PPO": get_PPO_agent, + "LOLA": get_LOLA_agent, } assert args.agent1 in strategies @@ -178,6 +184,9 @@ def get_PPO_agent(seed, player_id): logger.info(f"Agent Pair: {args.agent1} | {args.agent2}") logger.info(f"Agent seeds: {seeds[0]} | {seeds[1]}") + if args.centralized: + return CentralizedLearners([agent_0, agent_1]) + return IndependentLearners([agent_0, agent_1]) @@ -226,6 +235,7 @@ def dumb_log(agent, *args): "SAC": sac_log, "DQN": dqn_log, "PPO": ppo_log, + "LOLA": dumb_log, } assert args.agent1 in strategies diff --git a/pax/lola/lola.py b/pax/lola/lola.py index bfa0b450..3fb2ddef 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -24,11 +24,9 @@ def policy( """Determines how to choose an action""" key, subkey = jax.random.split(state.random_key) logits = network.apply(params, observation) - print("Logits from policy", logits) - print("Softmax of logits from policy", jax.nn.softmax(logits)) actions = jax.random.categorical(subkey, logits) state = state._replace(random_key=key) - return int(actions), state + return actions, state def loss(): """Loss function""" From 8fab1679a6b1b089221ff4f47fefacb8c79b9e4e Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Wed, 6 Jul 2022 15:51:27 +0100 Subject: [PATCH 06/29] fix entropy annealing --- pax/conf/config.yaml | 11 ++++++----- pax/experiment.py | 1 - pax/lola/lola.py | 13 ++++++------- pax/ppo/buffer.py | 2 ++ pax/ppo/networks.py | 6 ++++-- pax/ppo/ppo.py | 10 ++++++---- pax/ppo/ppo_gru.py | 3 +++ pax/watchers.py | 18 ++++++++++-------- 8 files changed, 37 insertions(+), 27 deletions(-) diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 64822826..4f4a330b 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -13,13 +13,14 @@ seed: 0 save_dir: "./exp/${wandb.group}/${wandb.name}" # Agents -agent1: 'PPO' -agent2: 'TitForTat' +agent1: 'LOLA' +agent2: 'PPO' # Environment env_id: ipd game: ipd payoff: +centralized: True # Training hyperparameters num_envs: 100 @@ -54,8 +55,8 @@ ppo: clip_value: True max_gradient_norm: 0.5 anneal_entropy: True - entropy_coeff_start: 0.1 - entropy_coeff_horizon: 200_000_000 + entropy_coeff_start: 0.2 + entropy_coeff_horizon: 500_000 
entropy_coeff_end: 0.01 lr_scheduling: True learning_rate: 2.5e-2 @@ -70,6 +71,6 @@ ppo: wandb: entity: "ucl-dark" project: ipd - group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-final' + group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-v3' name: run-seed-${seed} log: True diff --git a/pax/experiment.py b/pax/experiment.py index 34a7637d..9550d354 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -23,7 +23,6 @@ Random, Human, GrimTrigger, - # ZDExtortion, ) from pax.utils import Section from pax.watchers import ( diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 3fb2ddef..3731cf5d 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -59,22 +59,21 @@ def update( t: TimeStep, actions: jnp.ndarray, t_prime: TimeStep, - other_agents=None, + other_agents: list = None, ): """Update agent""" - # an sgd step requires the parameters of the other agent. - # currently, the runner file doesn't have access to the other agent's gradients - # we could put the parameters of the agent inside the timestep + # for agent in other_agents: + # other_agent_obs = agent._trajectory_buffer.observations pass def make_lola(seed: int) -> LOLA: - """ "Instantiate LOLA""" + """Instantiate LOLA""" random_key = jax.random.PRNGKey(seed) def forward(inputs): - """Forward pass for LOLA exact""" - values = hk.Linear(1, with_bias=False) + """Forward pass for LOLA""" + values = hk.Linear(2, with_bias=False) return values(inputs) network = hk.without_apply_rng(hk.transform(forward)) diff --git a/pax/ppo/buffer.py b/pax/ppo/buffer.py index 1452dfda..f161a348 100644 --- a/pax/ppo/buffer.py +++ b/pax/ppo/buffer.py @@ -153,6 +153,8 @@ def reset(self): (self._num_envs, self._num_steps, self.gru_dim) ) + self.parameters = jnp.zeros((self._num_envs, self._num_steps)) + if __name__ == "__main__": pass diff --git a/pax/ppo/networks.py b/pax/ppo/networks.py index de13f2f5..0a43a13a 100644 --- a/pax/ppo/networks.py +++ b/pax/ppo/networks.py @@ -19,12 +19,14 @@ def __init__( super().__init__(name=name) self._logit_layer = hk.Linear( num_values, - w_init=hk.initializers.Orthogonal(0.01), # baseline + # w_init=hk.initializers.Orthogonal(0.01), # baseline + w_init=hk.initializers.Constant(0.5), with_bias=False, ) self._value_layer = hk.Linear( 1, - w_init=hk.initializers.Orthogonal(1.0), # baseline + # w_init=hk.initializers.Orthogonal(1.0), # baseline + w_init=hk.initializers.Constant(0.5), with_bias=False, ) diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index 193c878c..452440b6 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -184,9 +184,10 @@ def loss( fraction * entropy_coeff_start + (1 - fraction) * entropy_coeff_end ) + # Constant Entropy term - else: - entropy_cost = entropy_coeff_start + # else: + # entropy_cost = entropy_coeff_start entropy_loss = -jnp.mean(entropy) # Total loss: Minimize policy and value loss; maximize entropy @@ -201,6 +202,7 @@ def loss( "loss_policy": policy_loss, "loss_value": value_loss, "loss_entropy": entropy_loss, + "entropy_cost": entropy_cost, } @jax.jit @@ -371,8 +373,6 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: dummy_obs = utils.add_batch_dim(dummy_obs) initial_params = network.init(subkey, dummy_obs) initial_opt_state = optimizer.init(initial_params) - # for dict_key in initial_params.keys(): - # print(initial_params[dict_key]) return TrainingState( params=initial_params, opt_state=initial_opt_state, @@ -401,6 +401,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: "loss_policy": 0, "loss_value": 
0, "loss_entropy": 0, + "entropy_cost": entropy_coeff_start, } # Initialize functions @@ -480,6 +481,7 @@ def update( self._logger.metrics["loss_policy"] = results["loss_policy"] self._logger.metrics["loss_value"] = results["loss_value"] self._logger.metrics["loss_entropy"] = results["loss_entropy"] + self._logger.metrics["entropy_cost"] = results["entropy_cost"] # TODO: seed, and player_id not used in CartPole diff --git a/pax/ppo/ppo_gru.py b/pax/ppo/ppo_gru.py index 476a6ef8..218540fc 100644 --- a/pax/ppo/ppo_gru.py +++ b/pax/ppo/ppo_gru.py @@ -209,6 +209,7 @@ def loss( "loss_policy": policy_loss, "loss_value": value_loss, "loss_entropy": entropy_loss, + "entropy_cost": entropy_cost, } # }, new_rnn_unroll_state @@ -429,6 +430,7 @@ def make_initial_state( "loss_policy": 0, "loss_value": 0, "loss_entropy": 0, + "entropy_cost": entropy_coeff_start, } # Initialize functions @@ -503,6 +505,7 @@ def update( self._logger.metrics["loss_policy"] = results["loss_policy"] self._logger.metrics["loss_value"] = results["loss_value"] self._logger.metrics["loss_entropy"] = results["loss_entropy"] + self._logger.metrics["entropy_cost"] = results["entropy_cost"] # TODO: seed, and player_id not used in CartPole diff --git a/pax/watchers.py b/pax/watchers.py index 53e2afe1..f0740f50 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -6,11 +6,11 @@ # five possible states START = jnp.array([[0, 0, 0, 0, 1]]) CC = jnp.array([[1, 0, 0, 0, 0]]) -CD = jnp.array([[0, 1, 0, 0, 0]]) -DC = jnp.array([[0, 0, 1, 0, 0]]) +DC = jnp.array([[0, 1, 0, 0, 0]]) +CD = jnp.array([[0, 0, 1, 0, 0]]) DD = jnp.array([[0, 0, 0, 1, 0]]) -STATE_NAMES = ["START", "CC", "CD", "DC", "DD"] -ALL_STATES = [START, CC, CD, DC, DD] +STATE_NAMES = ["START", "CC", "DC", "CD", "DD"] +ALL_STATES = [START, CC, DC, CD, DD] def policy_logger(agent) -> None: @@ -119,11 +119,13 @@ def ppo_losses(agent) -> None: loss_policy = agent._logger.metrics["loss_policy"] loss_value = agent._logger.metrics["loss_value"] loss_entropy = agent._logger.metrics["loss_entropy"] + entropy_coefficient = agent._logger.metrics["entropy_cost"] losses = { "sgd_steps": sgd_steps, - "losses/total": loss_total, - "losses/policy": loss_policy, - "losses/value": loss_value, - "losses/entropy": loss_entropy, + "train/total": loss_total, + "train/policy": loss_policy, + "train/value": loss_value, + "train/entropy": loss_entropy, + "train/entropy_coefficient": entropy_coefficient, } return losses From 21dc4fdae3538545058ebd1e6023c75e33dc8509 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Thu, 7 Jul 2022 02:07:34 +0100 Subject: [PATCH 07/29] fix done conditiion in additional rollout step in PPO --- pax/experiment.py | 3 +-- pax/lola/lola.py | 20 +++++++++++++------- pax/ppo/ppo.py | 2 +- pax/ppo/ppo_gru.py | 2 +- pax/watchers.py | 12 ++++++++++++ 5 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pax/experiment.py b/pax/experiment.py index 9550d354..7a12a97e 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -148,7 +148,7 @@ def get_PPO_agent(seed, player_id): return ppo_agent def get_LOLA_agent(seed, player_id): - lola_agent = make_lola(seed) + lola_agent = make_lola(seed, player_id) return lola_agent strategies = { @@ -158,7 +158,6 @@ def get_LOLA_agent(seed, player_id): "Human": Human, "Random": Random, "Grim": GrimTrigger, - # "ZDExtortion": ZDExtortion, "SAC": get_SAC_agent, "DQN": get_DQN_agent, "PPO": get_PPO_agent, diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 3731cf5d..da02efd7 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -17,7 
+17,9 @@ class TrainingState(NamedTuple): class LOLA: """Implements LOLA with exact value functions""" - def __init__(self, network: hk.Params, random_key: jnp.ndarray): + def __init__( + self, network: hk.Params, random_key: jnp.ndarray, player_id: int + ): def policy( params: hk.Params, observation: jnp.ndarray, state: TrainingState ): @@ -45,6 +47,8 @@ def make_initial_state(key: jnp.ndarray) -> TrainingState: self.state = make_initial_state(random_key) self._policy = policy + self.player_id = player_id # for logging + # self. def select_action(self, t: TimeStep): """Select action based on observation""" @@ -62,12 +66,15 @@ def update( other_agents: list = None, ): """Update agent""" - # for agent in other_agents: - # other_agent_obs = agent._trajectory_buffer.observations - pass + for agent in other_agents: + other_agent_parameters = agent._trajectory_buffer.parameters + # print(f"other_agent_parameters: {other_agent_parameters}") + print( + f"other_agent_parameters shape: {other_agent_parameters.shape}" + ) -def make_lola(seed: int) -> LOLA: +def make_lola(seed: int, player_id: int) -> LOLA: """Instantiate LOLA""" random_key = jax.random.PRNGKey(seed) @@ -77,8 +84,7 @@ def forward(inputs): return values(inputs) network = hk.without_apply_rng(hk.transform(forward)) - - return LOLA(network=network, random_key=random_key) + return LOLA(network=network, random_key=random_key, player_id=player_id) if __name__ == "__main__": diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index 452440b6..a82bf3b6 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -465,7 +465,7 @@ def update( ) self._trajectory_buffer.add( - timestep=t, + timestep=t_prime, # this should be t_prime action=0, log_prob=0, value=self._state.extras["values"], diff --git a/pax/ppo/ppo_gru.py b/pax/ppo/ppo_gru.py index 218540fc..27cbc4c4 100644 --- a/pax/ppo/ppo_gru.py +++ b/pax/ppo/ppo_gru.py @@ -488,7 +488,7 @@ def update( ) self._trajectory_buffer.add( - timestep=t, + timestep=t_prime, action=0, log_prob=0, value=self._state.extras["values"], diff --git a/pax/watchers.py b/pax/watchers.py index f0740f50..fd99fdbe 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -129,3 +129,15 @@ def ppo_losses(agent) -> None: "train/entropy_coefficient": entropy_coefficient, } return losses + + +def policy_logger_lola(agent) -> None: + weights = agent._state.params["categorical_value_head/~/linear"]["w"] + pi = nn.softmax(weights) + sgd_steps = agent._total_steps / agent._num_steps + probs = { + f"policy/{str(s)}/{agent.player_id}.cooperate": p[0] + for (s, p) in zip(State, pi) + } + probs.update({"policy/total_steps": sgd_steps}) + return probs From 6fe1d02ddefd0782876e7a27b9d8850002b81b61 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Fri, 8 Jul 2022 17:38:56 +0100 Subject: [PATCH 08/29] minor changes to lola --- pax/conf/config.yaml | 9 +++++---- pax/lola/lola.py | 6 ++++-- pax/ppo/buffer.py | 2 +- pax/ppo/ppo.py | 4 +++- pax/ppo/ppo_gru.py | 4 +++- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 4f4a330b..e9e159e5 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -14,7 +14,7 @@ save_dir: "./exp/${wandb.group}/${wandb.name}" # Agents agent1: 'LOLA' -agent2: 'PPO' +agent2: 'LOLA' # Environment env_id: ipd @@ -24,7 +24,7 @@ centralized: True # Training hyperparameters num_envs: 100 -num_steps: 25 # number of steps per episode +num_steps: 100 # number of steps per episode total_timesteps: 100_000_000 eval_every: 500 # eval every n episodes, not timesteps @@ -48,7 
+48,7 @@ dqn: ppo: num_minibatches: 10 num_epochs: 4 - gamma: 0.99 + gamma: 0.75 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 value_coeff: 0.5 @@ -57,11 +57,12 @@ ppo: anneal_entropy: True entropy_coeff_start: 0.2 entropy_coeff_horizon: 500_000 + # for halfway, the horizon should (1/2) * (total_timesteps / num_envs) entropy_coeff_end: 0.01 lr_scheduling: True learning_rate: 2.5e-2 adam_epsilon: 1e-5 - with_memory: True + with_memory: False # LOLA agent parameters # lola: diff --git a/pax/lola/lola.py b/pax/lola/lola.py index da02efd7..10efb5f2 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -2,7 +2,7 @@ # https://arxiv.org/pdf/1709.04326.pdf from typing import NamedTuple, Any - +from pax.lola.buffer import TrajectoryBuffer from dm_env import TimeStep import haiku as hk import jax @@ -48,7 +48,9 @@ def make_initial_state(key: jnp.ndarray) -> TrainingState: self.state = make_initial_state(random_key) self._policy = policy self.player_id = player_id # for logging - # self. + # self._trajectory_buffer = TrajectoryBuffer( + # num_envs, num_steps, obs_spec + # ) def select_action(self, t: TimeStep): """Select action based on observation""" diff --git a/pax/ppo/buffer.py b/pax/ppo/buffer.py index f161a348..10b6bd0d 100644 --- a/pax/ppo/buffer.py +++ b/pax/ppo/buffer.py @@ -77,7 +77,7 @@ def add( self.behavior_values = jax.lax.stop_gradient( self.behavior_values.at[:, self.ptr].set(value.flatten()) ) - self.dones = self.dones.at[:, self.ptr].set(timestep.step_type) + self.dones = self.dones.at[:, self.ptr].set(new_timestep.step_type) self.rewards = self.rewards.at[:, self.ptr].set(new_timestep.reward) if hidden is not None: diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index a82bf3b6..e62be8e5 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -468,7 +468,9 @@ def update( timestep=t_prime, # this should be t_prime action=0, log_prob=0, - value=self._state.extras["values"], + value=self._state.extras["values"] + if not t_prime.last() + else jnp.zeros_like(self._state.extras["values"]), new_timestep=t_prime, ) diff --git a/pax/ppo/ppo_gru.py b/pax/ppo/ppo_gru.py index 27cbc4c4..e9d51ad5 100644 --- a/pax/ppo/ppo_gru.py +++ b/pax/ppo/ppo_gru.py @@ -491,7 +491,9 @@ def update( timestep=t_prime, action=0, log_prob=0, - value=self._state.extras["values"], + value=self._state.extras["values"] + if not t_prime.last() + else jnp.zeros_like(self._state.extras["values"]), new_timestep=t_prime, ) From b409692d8de9fea6caa49a3661007e4420558f19 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Tue, 12 Jul 2022 14:43:06 +0100 Subject: [PATCH 09/29] minor bug fix --- pax/ppo/ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index 9af4ffb0..305ac1a1 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -182,8 +182,8 @@ def loss( ) # Constant Entropy term - # else: - # entropy_cost = entropy_coeff_start + else: + entropy_cost = entropy_coeff_start entropy_loss = -jnp.mean(entropy) # Total loss: Minimize policy and value loss; maximize entropy From 8f421707949588bef8228e737fa3d31354254457 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Wed, 13 Jul 2022 13:10:44 +0100 Subject: [PATCH 10/29] add changes to buffer --- pax/conf/config.yaml | 4 +- pax/experiment.py | 11 ++- pax/lola/lola.py | 167 +++++++++++++++++++++++++++++++++++-------- pax/lola/network.py | 50 +++++++++++++ pax/runner.py | 2 +- 5 files changed, 201 insertions(+), 33 deletions(-) diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 7d204cce..f531779f 100644 --- 
a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -19,7 +19,7 @@ agent2: 'LOLA' # Environment env_id: ipd game: ipd -env_type: infinite +env_type: finite env_discount: 0.99 payoff: centralized: True @@ -76,4 +76,4 @@ wandb: project: ipd group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-v3' name: run-seed-${seed} - log: True + log: False diff --git a/pax/experiment.py b/pax/experiment.py index 77d9aa4a..4acf5872 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -188,7 +188,16 @@ def get_PPO_agent(seed, player_id): return ppo_agent def get_LOLA_agent(seed, player_id): - lola_agent = make_lola(seed, player_id) + dummy_env = SequentialMatrixGame( + args.num_envs, args.payoff, args.num_steps + ) + lola_agent = make_lola( + args, + obs_spec=(dummy_env.observation_spec().num_values,), + action_spec=dummy_env.action_spec().num_values, + seed=seed, + player_id=player_id, + ) return lola_agent def get_hyper_agent(seed, player_id): diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 10efb5f2..f622ca9b 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -1,37 +1,82 @@ # Learning with Opponent-Learning Awareness (LOLA) implementation in JAX # https://arxiv.org/pdf/1709.04326.pdf -from typing import NamedTuple, Any -from pax.lola.buffer import TrajectoryBuffer +from typing import NamedTuple, Any, Mapping +from pax.lola.buffer import Replay +from pax.lola.network import make_network from dm_env import TimeStep import haiku as hk import jax import jax.numpy as jnp +import numpy as np class TrainingState(NamedTuple): params: hk.Params random_key: jnp.ndarray + extras: Mapping[str, jnp.ndarray] class LOLA: - """Implements LOLA with exact value functions""" + """Implements LOLA with policy gradient""" + + # TODO + # 1. implement policy gradient + # 2. Implement loss function + # 3. Implement storing parameters in the buffer so that both agents can access + # 4. Second order LOLA term def __init__( - self, network: hk.Params, random_key: jnp.ndarray, player_id: int + self, + network: hk.Params, + random_key: jnp.ndarray, + player_id: int, + replay_capacity: int = 100000, + min_replay_size: int = 1000, + sgd_period: int = 1, + batch_size: int = 256, + obs_spec: tuple = (5,), + num_envs: int = 4, + num_steps: int = 500, ): def policy( params: hk.Params, observation: jnp.ndarray, state: TrainingState ): """Determines how to choose an action""" + # key, subkey = jax.random.split(state.random_key) + # logits = network.apply(params, observation) + # actions = jax.random.categorical(subkey, logits) + # state = state._replace(random_key=key) + # return actions, state key, subkey = jax.random.split(state.random_key) - logits = network.apply(params, observation) - actions = jax.random.categorical(subkey, logits) - state = state._replace(random_key=key) + dist, values = network.apply(params, observation) + actions = dist.sample(seed=subkey) + state.extras["values"] = values + state.extras["log_probs"] = dist.log_prob(actions) + state = TrainingState( + params=params, + random_key=key, + extras=state.extras, + ) return actions, state + def rollouts( + buffer: Replay, + t: TimeStep, + actions: np.array, + t_prime: TimeStep, + state: TrainingState, + ) -> None: + """Stores rollout in buffer""" + log_probs, values = ( + state.extras["log_probs"], + state.extras["values"], + ) + buffer.add(t, actions, log_probs, values, t_prime) + def loss(): """Loss function""" + # see pg. 
4 of the LOLA paper pass def sgd(): @@ -43,55 +88,119 @@ def make_initial_state(key: jnp.ndarray) -> TrainingState: key, subkey = jax.random.split(key) dummy_obs = jnp.zeros(shape=(1, 5)) params = network.init(subkey, dummy_obs) - return TrainingState(params=params, random_key=key) + return TrainingState( + params=params, + random_key=key, + extras={"values": None, "log_probs": None}, + ) + + # init buffer + self._buffer = Replay(num_envs, num_steps, replay_capacity, obs_spec) - self.state = make_initial_state(random_key) + # init functions + self._state = make_initial_state(random_key) self._policy = policy - self.player_id = player_id # for logging + self._rollouts = rollouts + self._sgd_step = sgd + + # init constants + self.player_id = player_id + self._min_replay_size = min_replay_size + self._sgd_period = sgd_period + self._batch_size = batch_size + + # init variables + self._total_steps = 0 + # self._trajectory_buffer = TrajectoryBuffer( # num_envs, num_steps, obs_spec # ) + # TODO: get the correct parameters for this. + # self._trajectory_buffer = Replay( + # num_envs, num_steps, obs_spec + # ) def select_action(self, t: TimeStep): """Select action based on observation""" # Unpack - params = self.state.params - state = self.state - action, self.state = self._policy(params, t.observation, state) + params = self._state.params + state = self._state + action, self._state = self._policy(params, t.observation, state) return action def update( self, t: TimeStep, - actions: jnp.ndarray, + action: jnp.ndarray, t_prime: TimeStep, other_agents: list = None, ): + + self._rollouts( + buffer=self._buffer, + t=t, + actions=action, + t_prime=t_prime, + state=self._state, + ) + + self._total_steps += 1 + if self._total_steps % self._sgd_period != 0: + return + + print(self._buffer.size()) + if self._buffer.size() < self._min_replay_size: + return + + # Do a batch of SGD. 
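
As a pointer for the loss() stub earlier in this file ("see pg. 4 of the LOLA paper"): the estimator this series eventually ports (it appears, commented out, in a later commit) is the DiCE objective. A direct JAX rendering of that commented-out code, for reference only and not yet wired into this class; arrays are assumed to have shape [batch, time] and the argument names are illustrative:

import jax
import jax.numpy as jnp

def magic_box(x):
    return jnp.exp(x - jax.lax.stop_gradient(x))

def dice_objective(self_logprobs, other_logprobs, values, rewards,
                   gamma=0.96, use_baseline=True):
    # per-timestep discount gamma^t along the time axis
    cum_discount = jnp.cumprod(gamma * jnp.ones_like(rewards), axis=1) / gamma
    discounted_rewards = rewards * cum_discount
    discounted_values = values * cum_discount

    # stochastic nodes (both agents' log-probs) each reward depends on
    dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1)
    stochastic_nodes = self_logprobs + other_logprobs

    # DiCE objective
    objective = jnp.mean(
        jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1)
    )
    if use_baseline:
        # baseline term for variance reduction
        objective += jnp.mean(
            jnp.sum((1 - magic_box(stochastic_nodes)) * discounted_values, axis=1)
        )
    return -objective  # minimise the negative objective

The baseline term mirrors the use_baseline flag in the ported code and only reduces variance; it does not change the expected gradient.
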
+ sample, key = self._buffer.sample( + self._batch_size, self._state.random_key + ) + print(sample) + self._state = self._state._replace(random_key=key) + # self._state = self._sgd_step(self._state, transitions) + """Update agent""" - for agent in other_agents: - other_agent_parameters = agent._trajectory_buffer.parameters - # print(f"other_agent_parameters: {other_agent_parameters}") - print( - f"other_agent_parameters shape: {other_agent_parameters.shape}" - ) + # for agent in other_agents: + # pass + # other_agent_parameters = agent._trajectory_buffer.parameters + # # print(f"other_agent_parameters: {other_agent_parameters}") + # print( + # f"other_agent_parameters shape: {other_agent_parameters.shape}" + # ) -def make_lola(seed: int, player_id: int) -> LOLA: +# TODO: Add args argument +def make_lola(args, obs_spec, action_spec, seed: int, player_id: int) -> LOLA: """Instantiate LOLA""" random_key = jax.random.PRNGKey(seed) - def forward(inputs): - """Forward pass for LOLA""" - values = hk.Linear(2, with_bias=False) - return values(inputs) - - network = hk.without_apply_rng(hk.transform(forward)) - return LOLA(network=network, random_key=random_key, player_id=player_id) + # def forward(inputs): + # """Forward pass for LOLA""" + # values = hk.Linear(2, with_bias=False) + # return values(inputs) + + # network = hk.without_apply_rng(hk.transform(forward)) + + network = make_network(action_spec) + + return LOLA( + network=network, + random_key=random_key, + player_id=player_id, + replay_capacity=100000, # args.lola.replay_capacity, + min_replay_size=50, # args.lola.min_replay_size, + sgd_period=1, # args.dqn.sgd_period, + batch_size=2, # args.dqn.batch_size + obs_spec=(5,), # obs_spec + num_envs=2, # num_envs=args.num_envs, + num_steps=4, # num_steps=args.num_steps, + ) if __name__ == "__main__": lola = make_lola(seed=0) - print(f"LOLA state: {lola.state}") + print(f"LOLA state: {lola._state}") timestep = TimeStep( step_type=0, reward=1, diff --git a/pax/lola/network.py b/pax/lola/network.py index e69de29b..04abacd1 100644 --- a/pax/lola/network.py +++ b/pax/lola/network.py @@ -0,0 +1,50 @@ +from typing import Optional + +from pax import utils + +import distrax +import haiku as hk +import jax.numpy as jnp + + +class CategoricalValueHead(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Constant(0.5), + with_bias=False, + ) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Constant(0.5), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + logits = self._logit_layer(inputs) + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return (distrax.Categorical(logits=logits), value) + + +def make_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + CategoricalValueHead(num_values=num_actions), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network diff --git a/pax/runner.py b/pax/runner.py index fcab1184..7bbf9880 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -64,7 +64,7 @@ def train_loop(self, env, agents, num_episodes, watchers): ) # end of episode stats - self.train_episodes += 1 + self.train_episodes += 
env.num_envs rewards_0 = jnp.array(rewards_0) rewards_1 = jnp.array(rewards_1) From fce83a48e7a965b6b6f60c9cc0b0e2ad728759dc Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Fri, 15 Jul 2022 15:41:17 +0100 Subject: [PATCH 11/29] update confs --- pax/conf/experiment/lola.yaml | 38 +++++++++++++++++++++++++++++ pax/conf/experiment/ppo_memory.yaml | 10 ++++---- 2 files changed, 43 insertions(+), 5 deletions(-) create mode 100644 pax/conf/experiment/lola.yaml diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml new file mode 100644 index 00000000..9ab3cb61 --- /dev/null +++ b/pax/conf/experiment/lola.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +# Agents +agent1: 'LOLA' +agent2: 'LOLA' + +# Environment +env_id: ipd +game: ipd +env_type: finite +env_discount: 0.99 +payoff: + +# Training hyperparameters +num_envs: 2 +num_steps: 10 # number of steps per episode +total_timesteps: 10000 +eval_every: 50 # timesteps + +# Useful information +# num_episodes = total_timesteps / num_steps +# num_updates = num_episodes / eval_every +# batch_size = num_envs * num_steps + +# LOLA agent parameters +lola: + replay_capacity: 100000 #args.lola.replay_capacity, + min_replay_size: 50 #args.lola.min_replay_size, + sgd_period: 1 #args.dqn.sgd_period, + batch_size: 2 #args.dqn.batch_size + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipd + group: 'LOLA-vs-${agent2}-${game}' + name: run-seed-${seed} + log: False diff --git a/pax/conf/experiment/ppo_memory.yaml b/pax/conf/experiment/ppo_memory.yaml index 82bfe765..9686e5e0 100644 --- a/pax/conf/experiment/ppo_memory.yaml +++ b/pax/conf/experiment/ppo_memory.yaml @@ -13,8 +13,8 @@ payoff: # Training hyperparameters num_envs: 100 -num_steps: 100 # number of steps per episode -total_timesteps: 1_000_000 +num_steps: 25 # number of steps per episode +total_timesteps: 2_000_000 eval_every: 50_000 # timesteps # Useful information @@ -34,10 +34,10 @@ ppo: max_gradient_norm: 0.5 anneal_entropy: True entropy_coeff_start: 0.2 - entropy_coeff_horizon: 500_000 + entropy_coeff_horizon: 1_000_000 entropy_coeff_end: 0.001 lr_scheduling: True - learning_rate: 2.5e-3 + learning_rate: 2.5e-2 adam_epsilon: 1e-5 with_memory: True @@ -45,6 +45,6 @@ ppo: wandb: entity: "ucl-dark" project: ipd - group: 'PPO_memory-vs-${agent2}-${game}' + group: 'PPO_memory-vs-${agent2}-${game}-${entropy_coeff_horizon}' name: run-seed-${seed} log: True From d1be0c5d5a11e39127c3e7135a9133813257c05d Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Tue, 19 Jul 2022 14:50:55 +0100 Subject: [PATCH 12/29] add naive learner --- pax/experiment.py | 27 +++ pax/naive/__init__.py | 0 pax/naive/buffer.py | 160 ++++++++++++++ pax/naive/naive.py | 483 ++++++++++++++++++++++++++++++++++++++++++ pax/naive/network.py | 50 +++++ pax/watchers.py | 16 +- 6 files changed, 735 insertions(+), 1 deletion(-) create mode 100644 pax/naive/__init__.py create mode 100644 pax/naive/buffer.py create mode 100644 pax/naive/naive.py create mode 100644 pax/naive/network.py diff --git a/pax/experiment.py b/pax/experiment.py index e236c172..521e54ee 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -14,6 +14,7 @@ from pax.ppo.ppo import make_agent from pax.ppo.ppo_gru import make_gru_agent from pax.lola.lola import make_lola +from pax.naive.naive import make_naive from pax.runner import Runner from pax.sac.agent import SAC from pax.strategies import ( @@ -39,6 +40,7 @@ policy_logger_ppo, value_logger_ppo, policy_logger_ppo_with_memory, + naive_losses, ) import hydra @@ -187,6 +189,19 @@ def 
get_PPO_agent(seed, player_id): ) return ppo_agent + def get_naive_agent(seed, player_id): + dummy_env = SequentialMatrixGame( + args.num_envs, args.payoff, args.num_steps + ) + naive_agent = make_naive( + args, + obs_spec=(dummy_env.observation_spec().num_values,), + action_spec=dummy_env.action_spec().num_values, + seed=seed, + player_id=player_id, + ) + return naive_agent + def get_LOLA_agent(seed, player_id): dummy_env = SequentialMatrixGame( args.num_envs, args.payoff, args.num_steps @@ -236,6 +251,7 @@ def get_hyper_agent(seed, player_id): "SAC": get_SAC_agent, "DQN": get_DQN_agent, "PPO": get_PPO_agent, + "Naive": get_naive_agent, "LOLA": get_LOLA_agent, "PPO_memory": get_PPO_memory_agent, # HyperNetworks @@ -300,6 +316,16 @@ def ppo_log(agent): wandb.log(losses) return + def naive_log(agent): + losses = naive_losses(agent) + policy = policy_logger_ppo(agent) + value = value_logger_ppo(agent) + losses.update(value) + losses.update(policy) + if args.wandb.log: + wandb.log(losses) + return + def dumb_log(agent, *args): return @@ -324,6 +350,7 @@ def hyper_log(agent): "SAC": sac_log, "DQN": dqn_log, "PPO": ppo_log, + "Naive": naive_log, "LOLA": dumb_log, "PPO_memory": ppo_log, "Hyper": hyper_log, diff --git a/pax/naive/__init__.py b/pax/naive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pax/naive/buffer.py b/pax/naive/buffer.py new file mode 100644 index 00000000..10b6bd0d --- /dev/null +++ b/pax/naive/buffer.py @@ -0,0 +1,160 @@ +from typing import NamedTuple, Tuple + +from dm_env import TimeStep +import jax +import jax.numpy as jnp +import numpy as np + + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + +class TrajectoryBuffer: + """ + A buffer for storing trajectories experienced by a PPO agent + interacting with the environment. The buffer's capacity should equal + the number of steps * number of environments + 1 + """ + + def __init__( + self, + num_envs: int, + num_steps: int, + obs_space: Tuple[int], # (envs.observation_spec().num_values, ) + gru_dim: int = 1, + ): + + # Buffer information + self._num_envs = num_envs + self._num_steps = ( + num_steps + 1 + ) # Take an additional rollout step for boostrapping value + self._rollout_length = num_steps * num_envs + + # Environment specs + self.obs_space = obs_space + self.gru_dim = gru_dim + + # Initialise pointers + self.ptr = 0 + + # extra info + self._num_added = 0 + self.full = False + self.reset() + + def add( + self, + timestep: TimeStep, + action: jnp.ndarray, + log_prob: jnp.ndarray, + value: jnp.ndarray, + new_timestep: TimeStep, + hidden: jnp.ndarray = None, + ): + """Append a batched time step to the buffer. 
+ Resets buffer and ptr if the buffer is full.""" + + if self.full: + self.reset() + + self.observations = self.observations.at[:, self.ptr].set( + timestep.observation + ) + self.actions = self.actions.at[:, self.ptr].set(action) + self.behavior_log_probs = self.behavior_log_probs.at[:, self.ptr].set( + log_prob + ) + self.behavior_values = jax.lax.stop_gradient( + self.behavior_values.at[:, self.ptr].set(value.flatten()) + ) + self.dones = self.dones.at[:, self.ptr].set(new_timestep.step_type) + self.rewards = self.rewards.at[:, self.ptr].set(new_timestep.reward) + + if hidden is not None: + self.hiddens = self.hiddens.at[:, self.ptr].set(hidden) + + self.ptr += 1 + self._num_added += self._num_envs + + if self.ptr == self._num_steps: + self.full = True + + def sample(self): + """Returns current data""" + return Sample( + observations=self.observations, + actions=self.actions, + rewards=self.rewards, + behavior_log_probs=self.behavior_log_probs, + behavior_values=self.behavior_values, + dones=self.dones, + hiddens=self.hiddens, + ) + + def size(self) -> int: + return min(self._rollout_length, self._num_added) + + def fraction_filled(self) -> float: + return self.size / self._rollout_length + + def reset(self): + """Resets the replay buffer. Called upon __init__ and when buffer is full""" + self.ptr = 0 + self.full = False + + self.observations = jnp.zeros( + (self._num_envs, self._num_steps, *self.obs_space) + ) + + self.actions = jnp.zeros( + ( + self._num_envs, + self._num_steps, + ), + dtype="int32", + ) + + self.behavior_log_probs = jnp.zeros( + ( + self._num_envs, + self._num_steps, + ) + ) + self.rewards = jnp.zeros( + ( + self._num_envs, + self._num_steps, + ) + ) + + self.dones = jnp.zeros( + ( + self._num_envs, + self._num_steps, + ) + ) + self.behavior_values = jnp.zeros( + ( + self._num_envs, + self._num_steps, + ) + ) + self.hiddens = jnp.zeros( + (self._num_envs, self._num_steps, self.gru_dim) + ) + + self.parameters = jnp.zeros((self._num_envs, self._num_steps)) + + +if __name__ == "__main__": + pass diff --git a/pax/naive/naive.py b/pax/naive/naive.py new file mode 100644 index 00000000..82a1a3ea --- /dev/null +++ b/pax/naive/naive.py @@ -0,0 +1,483 @@ +# Adapted from https://github.com/deepmind/acme/blob/master/acme/agents/jax/ppo/learning.py + +from typing import Any, Mapping, NamedTuple, Tuple, Dict + +from pax import utils +from pax.naive.buffer import TrajectoryBuffer +from pax.naive.network import make_network + +from dm_env import TimeStep +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import optax + + +class Batch(NamedTuple): + """A batch of data; all shapes are expected to be [B, ...].""" + + observations: jnp.ndarray + actions: jnp.ndarray + advantages: jnp.ndarray + + # Target value estimate used to bootstrap the value function. + target_values: jnp.ndarray + + # Value estimate and action log-prob at behavior time. 
+ behavior_values: jnp.ndarray + behavior_log_probs: jnp.ndarray + + +class TrainingState(NamedTuple): + """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" + + params: hk.Params + opt_state: optax.GradientTransformation + random_key: jnp.ndarray + timesteps: int + extras: Mapping[str, jnp.ndarray] + + +class Logger: + metrics: dict + + +class NaiveLearner: + """A simple naive learner agent using JAX + This agent has a few variations on the original naive learner from LOLA (Foerster 2017, et al) + Notably: + - LOLA uses a baseline for variance reduction; ours uses generalized advantages estimation + """ + + def __init__( + self, + network: NamedTuple, + optimizer: optax.GradientTransformation, + random_key: jnp.ndarray, + obs_spec: Tuple, + num_envs: int = 4, + num_steps: int = 500, + num_minibatches: int = 16, + num_epochs: int = 4, + gamma: float = 0.99, + gae_lambda: float = 0.95, + ): + @jax.jit + def policy( + params: hk.Params, observation: TimeStep, state: TrainingState + ): + """Agent policy to select actions and calculate agent specific information""" + key, subkey = jax.random.split(state.random_key) + dist, values = network.apply(params, observation) + actions = dist.sample(seed=subkey) + state.extras["values"] = values + state.extras["log_probs"] = dist.log_prob(actions) + state = TrainingState( + params=params, + opt_state=state.opt_state, + random_key=key, + timesteps=state.timesteps, + extras=state.extras, + ) + return actions, state + + def rollouts( + buffer: TrajectoryBuffer, + t: TimeStep, + actions: np.array, + t_prime: TimeStep, + state: TrainingState, + ) -> None: + """Stores rollout in buffer""" + log_probs, values = ( + state.extras["log_probs"], + state.extras["values"], + ) + buffer.add(t, actions, log_probs, values, t_prime) + + def gae_advantages( + rewards: jnp.ndarray, values: jnp.ndarray, dones: jnp.ndarray + ) -> jnp.ndarray: + """Calculates the gae advantages from a sequence. Note that the + arguments are of length = rollout length + 1""" + # Only need up to the rollout length + rewards = rewards[:-1] + dones = dones[:-1] + + # 'Zero out' the terminated states + discounts = gamma * jnp.where(dones < 2, 1, 0) + + delta = rewards + discounts * values[1:] - values[:-1] + advantage_t = [0.0] + for t in reversed(range(delta.shape[0])): + advantage_t.insert( + 0, delta[t] + gae_lambda * discounts[t] * advantage_t[0] + ) + advantages = jax.lax.stop_gradient(jnp.array(advantage_t[:-1])) + + # this is where the gae function will end + target_values = values[:-1] + advantages # Q-value estimates + target_values = jax.lax.stop_gradient(target_values) + return advantages, target_values + + def loss( + params: hk.Params, + timesteps: int, + observations: jnp.ndarray, + actions: jnp.array, + behavior_log_probs: jnp.array, + target_values: jnp.array, + advantages: jnp.array, + behavior_values: jnp.array, + ): + """Surrogate loss using clipped probability ratios.""" + distribution, values = network.apply(params, observations) + log_prob = distribution.log_prob(actions) + + # ACTOR + # Importance sampling weights: current policy / behavior policy. 
+ rhos = jnp.exp(log_prob - behavior_log_probs) + policy_loss = -jnp.mean(rhos * advantages) + + # CRITIC + value_loss = jnp.mean((target_values - values) ** 2) + total_loss = policy_loss + value_loss + + return total_loss, { + "loss_total": total_loss, + "loss_policy": policy_loss, + "loss_value": value_loss, + } + + @jax.jit + def sgd_step( + state: TrainingState, sample: NamedTuple + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + """Performs a minibatch SGD step, returning new state and metrics.""" + + # Extract data + ( + observations, + actions, + rewards, + behavior_log_probs, + behavior_values, + dones, + ) = ( + sample.observations, + sample.actions, + sample.rewards, + sample.behavior_log_probs, + sample.behavior_values, + sample.dones, + ) + + # vmap + batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0) + advantages, target_values = batch_gae_advantages( + rewards=rewards, values=behavior_values, dones=dones + ) + + # Exclude the last step - it was only used for bootstrapping. + # The shape is [num_envs, num_steps, ..] + ( + observations, + actions, + behavior_log_probs, + behavior_values, + ) = jax.tree_map( + lambda x: x[:, :-1], + (observations, actions, behavior_log_probs, behavior_values), + ) + + trajectories = Batch( + observations=observations, + actions=actions, + advantages=advantages, + behavior_log_probs=behavior_log_probs, + target_values=target_values, + behavior_values=behavior_values, + ) + + # Concatenate all trajectories. Reshape from [num_envs, num_steps, ..] + # to [num_envs * num_steps,..] + assert len(target_values.shape) > 1 + num_envs = target_values.shape[0] + num_steps = target_values.shape[1] + batch_size = num_envs * num_steps + assert batch_size % num_minibatches == 0, ( + "Num minibatches must divide batch size. Got batch_size={}" + " num_minibatches={}." + ).format(batch_size, num_minibatches) + + batch = jax.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), trajectories + ) + + # Compute gradients. + grad_fn = jax.grad(loss, has_aux=True) + + def model_update_minibatch( + carry: Tuple[hk.Params, optax.OptState, int], + minibatch: Batch, + ) -> Tuple[ + Tuple[hk.Params, optax.OptState, int], Dict[str, jnp.ndarray] + ]: + """Performs model update for a single minibatch.""" + params, opt_state, timesteps = carry + # Normalize advantages at the minibatch level before using them. 
+ advantages = ( + minibatch.advantages + - jnp.mean(minibatch.advantages, axis=0) + ) / (jnp.std(minibatch.advantages, axis=0) + 1e-8) + gradients, metrics = grad_fn( + params, + timesteps, + minibatch.observations, + minibatch.actions, + minibatch.behavior_log_probs, + minibatch.target_values, + advantages, + minibatch.behavior_values, + ) + + # Apply updates + updates, opt_state = optimizer.update(gradients, opt_state) + params = optax.apply_updates(params, updates) + + metrics["norm_grad"] = optax.global_norm(gradients) + metrics["norm_updates"] = optax.global_norm(updates) + return (params, opt_state, timesteps), metrics + + def model_update_epoch( + carry: Tuple[ + jnp.ndarray, hk.Params, optax.OptState, int, Batch + ], + unused_t: Tuple[()], + ) -> Tuple[ + Tuple[jnp.ndarray, hk.Params, optax.OptState, Batch], + Dict[str, jnp.ndarray], + ]: + """Performs model updates based on one epoch of data.""" + key, params, opt_state, timesteps, batch = carry + key, subkey = jax.random.split(key) + permutation = jax.random.permutation(subkey, batch_size) + shuffled_batch = jax.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_map( + lambda x: jnp.reshape( + x, [num_minibatches, -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) + + (params, opt_state, timesteps), metrics = jax.lax.scan( + model_update_minibatch, + (params, opt_state, timesteps), + minibatches, + length=num_minibatches, + ) + return (key, params, opt_state, timesteps, batch), metrics + + params = state.params + opt_state = state.opt_state + timesteps = state.timesteps + + # Repeat training for the given number of epoch, taking a random + # permutation for every epoch. + # signature is scan(function, carry, tuple to iterate over, length) + (key, params, opt_state, timesteps, _), metrics = jax.lax.scan( + model_update_epoch, + (state.random_key, params, opt_state, timesteps, batch), + (), + length=num_epochs, + ) + + metrics = jax.tree_map(jnp.mean, metrics) + metrics["rewards_mean"] = jnp.mean( + jnp.abs(jnp.mean(rewards, axis=(0, 1))) + ) + metrics["rewards_std"] = jnp.std(rewards, axis=(0, 1)) + + new_state = TrainingState( + params=params, + opt_state=opt_state, + random_key=key, + timesteps=timesteps, + extras={"log_probs": None, "values": None}, + ) + + return new_state, metrics + + def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + key, subkey = jax.random.split(key) + dummy_obs = jnp.zeros(shape=obs_spec) + dummy_obs = utils.add_batch_dim(dummy_obs) + initial_params = network.init(subkey, dummy_obs) + initial_opt_state = optimizer.init(initial_params) + return TrainingState( + params=initial_params, + opt_state=initial_opt_state, + random_key=key, + timesteps=0, + extras={"values": None, "log_probs": None}, + ) + + # Initialise training state (parameters, optimiser state, extras). 
+ self._state = make_initial_state(random_key, obs_spec) + + # Initialize buffer and sgd + self._trajectory_buffer = TrajectoryBuffer( + num_envs, num_steps, obs_spec + ) + self._sgd_step = sgd_step + + # Set up counters and logger + self._logger = Logger() + self._total_steps = 0 + self._until_sgd = 0 + self._logger.metrics = { + "total_steps": 0, + "sgd_steps": 0, + "loss_total": 0, + "loss_policy": 0, + "loss_value": 0, + } + + # Initialize functions + self._policy = policy + self._rollouts = rollouts + + # Other useful hyperparameters + self._num_envs = num_envs # number of environments + self._num_steps = num_steps # number of steps per environment + self._batch_size = int(num_envs * num_steps) # number in one batch + self._num_minibatches = num_minibatches # number of minibatches + self._num_epochs = num_epochs # number of epochs to use sample + + def select_action(self, t: TimeStep): + """Selects action and updates info with PPO specific information""" + actions, self._state = self._policy( + self._state.params, t.observation, self._state + ) + return utils.to_numpy(actions) + + def update( + self, + t: TimeStep, + actions: np.array, + t_prime: TimeStep, + other_agents=None, + ): + # Adds agent and environment info to buffer + self._rollouts( + buffer=self._trajectory_buffer, + t=t, + actions=actions, + t_prime=t_prime, + state=self._state, + ) + + # Log metrics + self._total_steps += self._num_envs + self._logger.metrics["total_steps"] += self._num_envs + + # Update internal state with total_steps + self._state = TrainingState( + params=self._state.params, + opt_state=self._state.opt_state, + random_key=self._state.random_key, + timesteps=self._total_steps, + extras=self._state.extras, + ) + + # Update counter until doing SGD + self._until_sgd += 1 + + # Rollouts onging + if self._until_sgd % (self._num_steps) != 0: + return + + # Rollouts complete -> Training begins + # Add an additional rollout step for advantage calculation + _, self._state = self._policy( + self._state.params, t_prime.observation, self._state + ) + + self._trajectory_buffer.add( + timestep=t_prime, + action=0, + log_prob=0, + value=self._state.extras["values"] + if not t_prime.last() + else jnp.zeros_like(self._state.extras["values"]), + new_timestep=t_prime, + ) + + sample = self._trajectory_buffer.sample() + self._state, results = self._sgd_step(self._state, sample) + self._logger.metrics["sgd_steps"] += ( + self._num_minibatches * self._num_epochs + ) + self._logger.metrics["loss_total"] = results["loss_total"] + self._logger.metrics["loss_policy"] = results["loss_policy"] + self._logger.metrics["loss_value"] = results["loss_value"] + + +def make_naive(args, obs_spec, action_spec, seed: int, player_id: int): + """Make Naive Learner Policy Gradient agent""" + + print(f"Making network for {args.env_id}") + network = make_network(action_spec) + + # Optimizer + batch_size = int(args.num_envs * args.num_steps) + transition_steps = ( + args.total_timesteps + / batch_size + * args.ppo.num_epochs + * args.ppo.num_minibatches + ) + + if args.ppo.lr_scheduling: + scheduler = optax.linear_schedule( + init_value=args.ppo.learning_rate, + end_value=0, + transition_steps=transition_steps, + ) + optimizer = optax.chain( + optax.clip_by_global_norm(args.ppo.max_gradient_norm), + optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale_by_schedule(scheduler), + optax.scale(-1), + ) + + else: + optimizer = optax.chain( + optax.clip_by_global_norm(args.ppo.max_gradient_norm), + 
optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale(-args.ppo.learning_rate), + ) + + # Random key + random_key = jax.random.PRNGKey(seed=seed) + + return NaiveLearner( + network=network, + optimizer=optimizer, + random_key=random_key, + obs_spec=obs_spec, + num_envs=args.num_envs, + num_steps=args.num_steps, + num_minibatches=args.ppo.num_minibatches, + num_epochs=args.ppo.num_epochs, + gamma=args.ppo.gamma, + gae_lambda=args.ppo.gae_lambda, + ) + + +if __name__ == "__main__": + pass diff --git a/pax/naive/network.py b/pax/naive/network.py new file mode 100644 index 00000000..04abacd1 --- /dev/null +++ b/pax/naive/network.py @@ -0,0 +1,50 @@ +from typing import Optional + +from pax import utils + +import distrax +import haiku as hk +import jax.numpy as jnp + + +class CategoricalValueHead(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Constant(0.5), + with_bias=False, + ) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Constant(0.5), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + logits = self._logit_layer(inputs) + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return (distrax.Categorical(logits=logits), value) + + +def make_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + CategoricalValueHead(num_values=num_actions), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network diff --git a/pax/watchers.py b/pax/watchers.py index 55dae8cd..8bd15b69 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -155,6 +155,20 @@ def logger_hyper(agent) -> None: return cooperation_probs +def naive_losses(agent) -> None: + sgd_steps = agent._logger.metrics["sgd_steps"] + loss_total = agent._logger.metrics["loss_total"] + loss_policy = agent._logger.metrics["loss_policy"] + loss_value = agent._logger.metrics["loss_value"] + losses = { + "sgd_steps": sgd_steps, + "train/total": loss_total, + "train/policy": loss_policy, + "train/value": loss_value, + } + return losses + + def ppo_losses(agent) -> None: sgd_steps = agent._logger.metrics["sgd_steps"] loss_total = agent._logger.metrics["loss_total"] @@ -173,7 +187,7 @@ def ppo_losses(agent) -> None: return losses -def policy_logger_lola(agent) -> None: +def policy_logger_naive(agent) -> None: weights = agent._state.params["categorical_value_head/~/linear"]["w"] pi = nn.softmax(weights) sgd_steps = agent._total_steps / agent._num_steps From a855c4bdda790527c6d6c48a84037094d8a1a28c Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Fri, 22 Jul 2022 13:24:40 +0100 Subject: [PATCH 13/29] lazy commit. 
commiting to add naive learner PR --- pax/centralized_learners.py | 10 + pax/conf/config.yaml | 11 +- pax/conf/experiment/lola.yaml | 12 +- pax/experiment.py | 22 +- pax/lola/lola.py | 642 +++++++++++++++++++++++++++------- pax/runner.py | 8 +- pax/watchers.py | 58 ++- 7 files changed, 631 insertions(+), 132 deletions(-) diff --git a/pax/centralized_learners.py b/pax/centralized_learners.py index 67f5ed30..fa77d73d 100644 --- a/pax/centralized_learners.py +++ b/pax/centralized_learners.py @@ -18,6 +18,16 @@ def select_action(self, timesteps: List[TimeStep]) -> List[jnp.ndarray]: agent.select_action(t) for agent, t in zip(self.agents, timesteps) ] + def lookahead(self, env): + """Simulates a rollout and gradient update""" + counter = 0 + for agent in self.agents: + # All other agents in a list + # i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...] + other_agents = self.agents[:counter] + self.agents[counter + 1 :] + agent.lookhead(env, other_agents) + counter += 1 + def update( self, old_timesteps: List[TimeStep], diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 8c4e1c53..100819a6 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -22,7 +22,7 @@ env_id: ipd game: ipd env_type: finite env_discount: 0.99 -payoff: +payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] centralized: True # Training hyperparameters @@ -74,13 +74,16 @@ naive: lr: 1.0 # LOLA agent parameters -# lola: -# ... +lola: + replay_capacity: 100000 #args.lola.replay_capacity, + min_replay_size: 50 #args.lola.min_replay_size, + sgd_period: 1 #args.dqn.sgd_period, + batch_size: 2 #args.dqn.batch_size # Logging setup wandb: entity: "ucl-dark" project: ipd - group: '${agent1}-vs-${agent2}-${game}-with-memory=${ppo.with_memory}-v3' + group: 'LOLA-vs-${agent2}-${game}' name: run-seed-${seed} log: False diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 9ab3cb61..92dac307 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -3,19 +3,21 @@ # Agents agent1: 'LOLA' agent2: 'LOLA' +centralized: True # Environment env_id: ipd game: ipd env_type: finite env_discount: 0.99 -payoff: +payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] + # Training hyperparameters -num_envs: 2 -num_steps: 10 # number of steps per episode -total_timesteps: 10000 -eval_every: 50 # timesteps +num_envs: 10 +num_steps: 150 # number of steps per episode +total_timesteps: 100_000 +eval_every: 4000 # timesteps # Useful information # num_episodes = total_timesteps / num_steps diff --git a/pax/experiment.py b/pax/experiment.py index 7987122b..a2754b06 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -45,6 +45,11 @@ value_logger_ppo, policy_logger_ppo_with_memory, naive_losses, + policy_logger_naive, + policy_logger_lola, + losses_lola, + value_logger_lola, + value_logger_naive, ) import hydra @@ -321,6 +326,7 @@ def get_naive_learner(seed, player_id): logger.info(f"Agent seeds: {seeds[0]} | {seeds[1]}") if args.centralized: + logger.info("Training: Centralized") return CentralizedLearners([agent_0, agent_1]) return IndependentLearners([agent_0, agent_1]) @@ -359,8 +365,18 @@ def ppo_log(agent): def naive_log(agent): losses = naive_losses(agent) - policy = policy_logger_ppo(agent) - value = value_logger_ppo(agent) + policy = policy_logger_naive(agent) + value = value_logger_naive(agent) + losses.update(value) + losses.update(policy) + if args.wandb.log: + wandb.log(losses) + return + + def lola_log(agent): + losses = losses_lola(agent) + policy = policy_logger_lola(agent) + 
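# Note (not part of the patch): policy_logger_lola / policy_logger_naive in
# pax/watchers.py read the categorical head's weight matrix directly. Assuming
# observations are one-hot encodings of the five IPD states and the head is the
# bias-free hk.Linear defined in network.py, the logits at state s are just row
# s of W, so row s of nn.softmax(weights) is exactly the policy at that state
# and column 0 can be logged as the per-state cooperation probability.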
value = value_logger_lola(agent) losses.update(value) losses.update(policy) if args.wandb.log: @@ -397,7 +413,7 @@ def naive_logger(agent): "DQN": dqn_log, "PPO": ppo_log, "Naive": naive_log, - "LOLA": dumb_log, + "LOLA": lola_log, "PPO_memory": ppo_log, "Hyper": hyper_log, "NaiveLearner": naive_logger, diff --git a/pax/lola/lola.py b/pax/lola/lola.py index f622ca9b..49cc4f14 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -1,53 +1,180 @@ -# Learning with Opponent-Learning Awareness (LOLA) implementation in JAX -# https://arxiv.org/pdf/1709.04326.pdf +# Adapted from https://github.com/deepmind/acme/blob/master/acme/agents/jax/ppo/learning.py -from typing import NamedTuple, Any, Mapping -from pax.lola.buffer import Replay +from typing import Any, Mapping, NamedTuple, Tuple, Dict + +from pax import utils +from pax.lola.buffer import TrajectoryBuffer from pax.lola.network import make_network + from dm_env import TimeStep import haiku as hk import jax import jax.numpy as jnp import numpy as np +import optax + + +class Batch(NamedTuple): + """A batch of data; all shapes are expected to be [B, ...].""" + + observations: jnp.ndarray + actions: jnp.ndarray + advantages: jnp.ndarray + + # Target value estimate used to bootstrap the value function. + target_values: jnp.ndarray + + # Value estimate and action log-prob at behavior time. + behavior_values: jnp.ndarray + behavior_log_probs: jnp.ndarray class TrainingState(NamedTuple): + """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" + params: hk.Params + opt_state: optax.GradientTransformation random_key: jnp.ndarray + timesteps: int extras: Mapping[str, jnp.ndarray] -class LOLA: - """Implements LOLA with policy gradient""" +class Hp: + def __init__(self): + self.lr_out = 0.2 + self.lr_in = 0.3 + self.lr_v = 0.1 + self.gamma = 0.96 + self.n_update = 200 + self.len_rollout = 150 + self.batch_size = 128 + self.use_baseline = True + self.seed = 42 + + +hp = Hp() + + +def magic_box(x): + return jnp.exp(x - jax.lax.stop_gradient(x)) + + +class Logger: + metrics: dict + + +# class Memory: +# def __init__(self): +# self.self_logprobs = [] +# self.other_logprobs = [] +# self.values = [] +# self.rewards = [] + +# def add(self, lp, other_lp, v, r): +# self.self_logprobs.append(lp) +# self.other_logprobs.append(other_lp) +# self.values.append(v) +# self.rewards.append(r) + +# def dice_objective(self): +# self_logprobs = jnp.stack(self.self_logprobs, dim=1) +# other_logprobs = jnp.stack(self.other_logprobs, dim=1) +# values = jnp.stack(self.values, dim=1) +# rewards = jnp.stack(self.rewards, dim=1) + +# # apply discount: +# cum_discount = ( +# jnp.cumprod(hp.gamma * jnp.ones(*rewards.size()), dim=1) / hp.gamma +# ) +# discounted_rewards = rewards * cum_discount +# discounted_values = values * cum_discount + +# # stochastics nodes involved in rewards dependencies: +# dependencies = jnp.cumsum(self_logprobs + other_logprobs, dim=1) + +# # logprob of each stochastic nodes: +# stochastic_nodes = self_logprobs + other_logprobs + +# # dice objective: +# dice_objective = jnp.mean( +# jnp.sum(magic_box(dependencies) * discounted_rewards, dim=1) +# ) + +# if hp.use_baseline: +# # variance_reduction: +# baseline_term = jnp.mean( +# jnp.sum( +# (1 - magic_box(stochastic_nodes)) * discounted_values, +# dim=1, +# ) +# ) +# dice_objective = dice_objective + baseline_term + +# return -dice_objective # want to minimize -objective + +# def value_loss(self): +# values = jnp.stack(self.values, dim=1) +# rewards = 
jnp.stack(self.rewards, dim=1) +# return jnp.mean((rewards - values) ** 2) + - # TODO - # 1. implement policy gradient - # 2. Implement loss function - # 3. Implement storing parameters in the buffer so that both agents can access - # 4. Second order LOLA term +# ipd = IPD(hp.len_rollout, hp.batch_size) + + +# def act(batch_states, theta, values): +# batch_states = jnp.from_numpy(batch_states).long() +# probs = jnp.sigmoid(theta)[batch_states] +# m = Bernoulli(1 - probs) +# actions = m.sample() +# log_probs_actions = m.log_prob(actions) +# return actions.numpy().astype(int), log_probs_actions, values[batch_states] + + +# def get_gradient(objective, theta): +# # create differentiable gradient for 2nd orders: +# grad_objective = torch.autograd.grad( +# objective, (theta), create_graph=True +# )[0] +# return grad_objective + + +# def step(theta1, theta2, values1, values2): +# # just to evaluate progress: +# (s1, s2), _ = ipd.reset() +# score1 = 0 +# score2 = 0 +# for t in range(hp.len_rollout): +# a1, lp1, v1 = act(s1, theta1, values1) +# a2, lp2, v2 = act(s2, theta2, values2) +# (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) +# # cumulate scores +# score1 += np.mean(r1) / float(hp.len_rollout) +# score2 += np.mean(r2) / float(hp.len_rollout) +# return (score1, score2) + + +class LOLA: + """LOLA with the DiCE objective function.""" def __init__( self, - network: hk.Params, + network: NamedTuple, + optimizer: optax.GradientTransformation, random_key: jnp.ndarray, player_id: int, - replay_capacity: int = 100000, - min_replay_size: int = 1000, - sgd_period: int = 1, - batch_size: int = 256, - obs_spec: tuple = (5,), + obs_spec: Tuple, num_envs: int = 4, - num_steps: int = 500, + num_steps: int = 150, + num_minibatches: int = 1, + num_epochs: int = 1, + gamma: float = 0.96, ): + + # @jax.jit def policy( - params: hk.Params, observation: jnp.ndarray, state: TrainingState + params: hk.Params, observation: TimeStep, state: TrainingState ): - """Determines how to choose an action""" - # key, subkey = jax.random.split(state.random_key) - # logits = network.apply(params, observation) - # actions = jax.random.categorical(subkey, logits) - # state = state._replace(random_key=key) - # return actions, state + """Agent policy to select actions and calculate agent specific information""" key, subkey = jax.random.split(state.random_key) dist, values = network.apply(params, observation) actions = dist.sample(seed=subkey) @@ -55,13 +182,15 @@ def policy( state.extras["log_probs"] = dist.log_prob(actions) state = TrainingState( params=params, + opt_state=state.opt_state, random_key=key, + timesteps=state.timesteps, extras=state.extras, ) return actions, state def rollouts( - buffer: Replay, + buffer: TrajectoryBuffer, t: TimeStep, actions: np.array, t_prime: TimeStep, @@ -74,143 +203,420 @@ def rollouts( ) buffer.add(t, actions, log_probs, values, t_prime) - def loss(): - """Loss function""" - # see pg. 
4 of the LOLA paper - pass + def loss( + logprobs: jnp.ndarray, + other_logprobs: jnp.ndarray, + values: jnp.ndarray, + rewards: jnp.ndarray, + ): + logprobs = jnp.stack(logprobs, dim=1) + other_logprobs = jnp.stack(other_logprobs, dim=1) + values = jnp.stack(self.values, dim=1) + rewards = jnp.stack(self.rewards, dim=1) + + # apply discount: + cum_discount = ( + jnp.cumprod(gamma * jnp.ones(*rewards.size()), dim=1) / gamma + ) + discounted_rewards = rewards * cum_discount + discounted_values = values * cum_discount + + # stochastics nodes involved in rewards dependencies: + dependencies = jnp.cumsum(logprobs + other_logprobs, dim=1) + + # logprob of each stochastic nodes: + stochastic_nodes = logprobs + other_logprobs + + # dice objective: + dice_objective = jnp.mean( + jnp.sum(magic_box(dependencies) * discounted_rewards, dim=1) + ) + + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) * discounted_values, + dim=1, + ) + ) + dice_objective = dice_objective + baseline_term + + return -dice_objective, {} # want to minimize -objective + + @jax.jit + def sgd_step( + state: TrainingState, + other_agent_params: hk.Params, + sample: NamedTuple, + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + """Performs a minibatch SGD step, returning new state and metrics.""" + + # Extract data + ( + observations, + actions, + rewards, + behavior_log_probs, + behavior_values, + dones, + ) = ( + sample.observations, + sample.actions, + sample.rewards, + sample.behavior_log_probs, + sample.behavior_values, + sample.dones, + ) + + # vmap + # TODO: REMOVE THIS INEA IFNEF IAEIF AE + def gae_advantages(): + pass + + batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0) + advantages, target_values = batch_gae_advantages( + rewards=rewards, values=behavior_values, dones=dones + ) + + # Exclude the last step - it was only used for bootstrapping. + # The shape is [num_envs, num_steps, ..] + ( + observations, + actions, + behavior_log_probs, + behavior_values, + ) = jax.tree_map( + lambda x: x[:, :-1], + (observations, actions, behavior_log_probs, behavior_values), + ) + + trajectories = Batch( + observations=observations, + actions=actions, + advantages=advantages, + behavior_log_probs=behavior_log_probs, + target_values=target_values, + behavior_values=behavior_values, + ) + + # Concatenate all trajectories. Reshape from [num_envs, num_steps, ..] + # to [num_envs * num_steps,..] + assert len(target_values.shape) > 1 + num_envs = target_values.shape[0] + num_steps = target_values.shape[1] + batch_size = num_envs * num_steps + assert batch_size % num_minibatches == 0, ( + "Num minibatches must divide batch size. Got batch_size={}" + " num_minibatches={}." + ).format(batch_size, num_minibatches) + + batch = jax.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), trajectories + ) + + # Compute gradients. + grad_fn = jax.grad(loss, has_aux=True) + + def model_update_minibatch( + carry: Tuple[hk.Params, optax.OptState, int], + minibatch: Batch, + ) -> Tuple[ + Tuple[hk.Params, optax.OptState, int], Dict[str, jnp.ndarray] + ]: + """Performs model update for a single minibatch.""" + params, opt_state, timesteps = carry + # Normalize advantages at the minibatch level before using them. 
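# Illustrative sketch (not part of the patch; the advantage normalization the
# comment above refers to continues just below). The gae_advantages stub above
# is still a `pass` placeholder; one standard truncated-GAE version, assuming a
# single trajectory with rewards/dones of shape [num_steps], values of shape
# [num_steps + 1] (last entry is the bootstrap value), and a hypothetical
# gae_lambda argument, could look like:
#
#     def gae_advantages(rewards, values, dones, gae_lambda=0.95):
#         deltas = rewards + gamma * (1.0 - dones) * values[1:] - values[:-1]
#
#         def accumulate(gae, xs):
#             delta, done = xs
#             gae = delta + gamma * gae_lambda * (1.0 - done) * gae
#             return gae, gae
#
#         _, advantages = jax.lax.scan(
#             accumulate, jnp.zeros(()), (deltas, dones), reverse=True
#         )
#         return advantages, advantages + values[:-1]
#
# jax.vmap over the environment axis would then give the batch_gae_advantages
# used above.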
+ advantages = ( + minibatch.advantages + - jnp.mean(minibatch.advantages, axis=0) + ) / (jnp.std(minibatch.advantages, axis=0) + 1e-8) + gradients, metrics = grad_fn( + params, + timesteps, + minibatch.observations, + minibatch.actions, + minibatch.behavior_log_probs, + minibatch.target_values, + advantages, + minibatch.behavior_values, + ) + + # Apply updates + updates, opt_state = optimizer.update(gradients, opt_state) + params = optax.apply_updates(params, updates) + + metrics["norm_grad"] = optax.global_norm(gradients) + metrics["norm_updates"] = optax.global_norm(updates) + return (params, opt_state, timesteps), metrics - def sgd(): - """Stochastic gradient descent""" - pass + def model_update_epoch( + carry: Tuple[ + jnp.ndarray, hk.Params, optax.OptState, int, Batch + ], + unused_t: Tuple[()], + ) -> Tuple[ + Tuple[jnp.ndarray, hk.Params, optax.OptState, Batch], + Dict[str, jnp.ndarray], + ]: + """Performs model updates based on one epoch of data.""" + key, params, opt_state, timesteps, batch = carry + key, subkey = jax.random.split(key) + permutation = jax.random.permutation(subkey, batch_size) + shuffled_batch = jax.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_map( + lambda x: jnp.reshape( + x, [num_minibatches, -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) - def make_initial_state(key: jnp.ndarray) -> TrainingState: - """Make initial training state for LOLA""" + (params, opt_state, timesteps), metrics = jax.lax.scan( + model_update_minibatch, + (params, opt_state, timesteps), + minibatches, + length=num_minibatches, + ) + return (key, params, opt_state, timesteps, batch), metrics + + params = state.params + opt_state = state.opt_state + timesteps = state.timesteps + + # Repeat training for the given number of epoch, taking a random + # permutation for every epoch. 
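# Aside (illustrative only): jax.lax.scan threads a carry through the loop and
# stacks the per-step outputs, so scanning model_update_epoch with
# length=num_epochs repeats the shuffled-minibatch pass once per epoch. A
# minimal example of the same mechanism:
#     carry, ys = jax.lax.scan(lambda c, x: (c + x, c + x), 0.0, jnp.arange(3.0))
#     # carry == 3.0, ys == [0., 1., 3.]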
+ # signature is scan(function, carry, tuple to iterate over, length) + (key, params, opt_state, timesteps, _), metrics = jax.lax.scan( + model_update_epoch, + (state.random_key, params, opt_state, timesteps, batch), + (), + length=num_epochs, + ) + + metrics = jax.tree_map(jnp.mean, metrics) + metrics["rewards_mean"] = jnp.mean( + jnp.abs(jnp.mean(rewards, axis=(0, 1))) + ) + metrics["rewards_std"] = jnp.std(rewards, axis=(0, 1)) + + new_state = TrainingState( + params=params, + opt_state=opt_state, + random_key=key, + timesteps=timesteps, + extras={"log_probs": None, "values": None}, + ) + + return new_state, metrics + + def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" key, subkey = jax.random.split(key) - dummy_obs = jnp.zeros(shape=(1, 5)) - params = network.init(subkey, dummy_obs) + dummy_obs = jnp.zeros(shape=obs_spec) + dummy_obs = utils.add_batch_dim(dummy_obs) + initial_params = network.init(subkey, dummy_obs) + initial_opt_state = optimizer.init(initial_params) return TrainingState( - params=params, + params=initial_params, + opt_state=initial_opt_state, random_key=key, + timesteps=0, extras={"values": None, "log_probs": None}, ) - # init buffer - self._buffer = Replay(num_envs, num_steps, replay_capacity, obs_spec) + # def in_lookahead(self, other_theta, other_values): + # (s1, s2), _ = ipd.reset() + # other_memory = Memory() + # for t in range(hp.len_rollout): + # a1, lp1, v1 = act(s1, self.theta, self.values) + # a2, lp2, v2 = act(s2, other_theta, other_values) + # (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) + # other_memory.add(lp2, lp1, v2, torch.from_numpy(r2).float()) - # init functions - self._state = make_initial_state(random_key) - self._policy = policy - self._rollouts = rollouts - self._sgd_step = sgd + # other_objective = other_memory.dice_objective() + # grad = get_gradient(other_objective, other_theta) + # return grad + + # Initialise training state (parameters, optimiser state, extras). + self._state = make_initial_state(random_key, obs_spec) - # init constants + # Setup player id self.player_id = player_id - self._min_replay_size = min_replay_size - self._sgd_period = sgd_period - self._batch_size = batch_size - # init variables + # Initialize buffer and sgd + self._trajectory_buffer = TrajectoryBuffer( + num_envs, num_steps, obs_spec + ) + self._sgd_step = sgd_step + + # Set up counters and logger + self._logger = Logger() self._total_steps = 0 + self._until_sgd = 0 + self._logger.metrics = { + "total_steps": 0, + "sgd_steps": 0, + "loss_total": 0, + "loss_policy": 0, + "loss_value": 0, + } - # self._trajectory_buffer = TrajectoryBuffer( - # num_envs, num_steps, obs_spec - # ) - # TODO: get the correct parameters for this. 
- # self._trajectory_buffer = Replay( - # num_envs, num_steps, obs_spec - # ) + # Initialize functions + self._policy = policy + self._rollouts = rollouts + + # Other useful hyperparameters + self._num_envs = num_envs # number of environments + self._num_steps = num_steps # number of steps per environment + self._batch_size = int(num_envs * num_steps) # number in one batch + self._num_minibatches = num_minibatches # number of minibatches + self._num_epochs = num_epochs # number of epochs to use sample def select_action(self, t: TimeStep): - """Select action based on observation""" - # Unpack - params = self._state.params - state = self._state - action, self._state = self._policy(params, t.observation, state) - return action + """Selects action and updates info with PPO specific information""" + actions, self._state = self._policy( + self._state.params, t.observation, self._state + ) + return utils.to_numpy(actions) + + # def lookahead(self, env, other_agents): + # """ + # Performs a rollout using the current parameters of both agents + # and simulates a naive learning update step for the other agent + + # INPUT: + # env: SequentialMatrixGame, an environment object of the game being played + # other_agents: list, a list of objects of the other agents + # """ + # t = env.reset() + # other_memory = Memory() + # for t in range(hp.len_rollout): + # actions, self.w = self._policy( + # self._state.params, t.observation, self._state + # ) + # a1, lp1, v1 = act(s1, self.theta, self.values) + # a2, lp2, v2 = act(s2, other_theta, other_values) + # (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) + # other_memory.add(lp2, lp1, v2, torch.from_numpy(r2).float()) + + # other_objective = other_memory.dice_objective() + # grad = get_gradient(other_objective, other_theta) + # return grad + + # pass def update( self, t: TimeStep, - action: jnp.ndarray, + actions: np.array, t_prime: TimeStep, other_agents: list = None, ): - + # Adds agent and environment info to buffer self._rollouts( - buffer=self._buffer, + buffer=self._trajectory_buffer, t=t, - actions=action, + actions=actions, t_prime=t_prime, state=self._state, ) - self._total_steps += 1 - if self._total_steps % self._sgd_period != 0: - return + # Log metrics + self._total_steps += self._num_envs + self._logger.metrics["total_steps"] += self._num_envs - print(self._buffer.size()) - if self._buffer.size() < self._min_replay_size: + # Update internal state with total_steps + self._state = TrainingState( + params=self._state.params, + opt_state=self._state.opt_state, + random_key=self._state.random_key, + timesteps=self._total_steps, + extras=self._state.extras, + ) + + # Update counter until doing SGD + self._until_sgd += 1 + + # Add params to buffer + # doesn't change throughout the rollout + # but it can't be before the return + self._trajectory_buffer.params = self._state.params + + # Rollouts onging + if self._until_sgd % (self._num_steps) != 0: return - # Do a batch of SGD. 
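# Note on the new update() above: `self._trajectory_buffer.params =
# self._state.params` is the centralized-training hook. Each agent publishes
# its live parameters on its buffer so that, during its own step, the opponent
# can read them (for example `other_agents[0]._trajectory_buffer.params`, as
# the commented-out lines further below do) and differentiate through a
# simulated update of those parameters.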
- sample, key = self._buffer.sample( - self._batch_size, self._state.random_key + # Rollouts complete -> Training begins + # Add an additional rollout step for advantage calculation + _, self._state = self._policy( + self._state.params, t_prime.observation, self._state ) - print(sample) - self._state = self._state._replace(random_key=key) - # self._state = self._sgd_step(self._state, transitions) - - """Update agent""" - # for agent in other_agents: - # pass - # other_agent_parameters = agent._trajectory_buffer.parameters - # # print(f"other_agent_parameters: {other_agent_parameters}") - # print( - # f"other_agent_parameters shape: {other_agent_parameters.shape}" - # ) + # print("Other agents params", other_agents[0]._trajectory_buffer.params) + self._trajectory_buffer.add( + timestep=t_prime, + action=0, + log_prob=0, + value=self._state.extras["values"] + if not t_prime.last() + else jnp.zeros_like(self._state.extras["values"]), + new_timestep=t_prime, + ) -# TODO: Add args argument -def make_lola(args, obs_spec, action_spec, seed: int, player_id: int) -> LOLA: - """Instantiate LOLA""" - random_key = jax.random.PRNGKey(seed) + # other_agent = other_agents[0] + # other_agent_params = other_agents[0]._trajectory_buffer.params + # It needs to be able to take in the opponents parameters + # and then do a rollout under those parameters + # could do sgd here for other agent? + # sample = self._trajectory_buffer.sample() + # self._state, results = self._sgd_step( + # self._state, other_agent_params, sample + # ) + # self._logger.metrics["sgd_steps"] += ( + # self._num_minibatches * self._num_epochs + # ) + # self._logger.metrics["loss_total"] = results["loss_total"] + # self._logger.metrics["loss_policy"] = results["loss_policy"] + # self._logger.metrics["loss_value"] = results["loss_value"] - # def forward(inputs): - # """Forward pass for LOLA""" - # values = hk.Linear(2, with_bias=False) - # return values(inputs) - # network = hk.without_apply_rng(hk.transform(forward)) +def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): + """Make Naive Learner Policy Gradient agent""" + # print(f"Making network for {args.env_id}") network = make_network(action_spec) + # Optimizer + # batch_size = int(args.num_envs * args.num_steps) + # transition_steps = ( + # args.total_timesteps + # / batch_size + # * args.ppo.num_epochs + # * args.ppo.num_minibatches + # ) + + optimizer = optax.chain( + optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale(-args.ppo.learning_rate), + ) + + # Random key + random_key = jax.random.PRNGKey(seed=seed) + return LOLA( network=network, + optimizer=optimizer, random_key=random_key, + obs_spec=obs_spec, player_id=player_id, - replay_capacity=100000, # args.lola.replay_capacity, - min_replay_size=50, # args.lola.min_replay_size, - sgd_period=1, # args.dqn.sgd_period, - batch_size=2, # args.dqn.batch_size - obs_spec=(5,), # obs_spec - num_envs=2, # num_envs=args.num_envs, - num_steps=4, # num_steps=args.num_steps, + num_envs=args.num_envs, + num_steps=args.num_steps, + num_minibatches=args.ppo.num_minibatches, + num_epochs=args.ppo.num_epochs, + gamma=args.ppo.gamma, + gae_lambda=args.ppo.gae_lambda, ) if __name__ == "__main__": - lola = make_lola(seed=0) - print(f"LOLA state: {lola._state}") - timestep = TimeStep( - step_type=0, - reward=1, - discount=1, - observation=jnp.array([[1, 0, 0, 0, 0]]), - ) - action = lola.select_action(timestep) - print("Action", action) - timestep = TimeStep( - step_type=0, reward=1, discount=1, 
observation=jnp.zeros(shape=(1, 5)) - ) - action = lola.select_action(timestep) - print("Action", action) + pass diff --git a/pax/runner.py b/pax/runner.py index eea713fc..0000f43a 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -26,9 +26,15 @@ def train_loop(self, env, agents, num_episodes, watchers): print("Training ") print("-----------------------") for _ in range(0, max(int(num_episodes / env.num_envs), 1)): + # TODO: Inner rollout + # 1. Get other agents' parameters + # 2. Do a rollout + # 3. Simulate gradient update + agents.lookahead(env) + + # NOTE: Outer for loop begins rewards_0, rewards_1 = [], [] t = env.reset() - while not (t[0].last()): actions = agents.select_action(t) t_prime = env.step(actions) diff --git a/pax/watchers.py b/pax/watchers.py index 66780aca..ad0b1876 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -1,6 +1,8 @@ from flax import linen as nn from pax.naive_learners import NaiveLearnerEx +from pax.naive.naive import NaiveLearner +from pax.lola.lola import LOLA from .env import State import jax.numpy as jnp import pax.hyper.ppo as HyperPPO @@ -98,6 +100,30 @@ def value_logger_ppo(agent: PPO) -> dict: return probs +def value_logger_naive(agent: NaiveLearner) -> dict: + weights = agent._state.params["categorical_value_head/~/linear_1"][ + "w" + ] # 5 x 1 matrix + sgd_steps = agent._total_steps / agent._num_steps + probs = { + f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + } + probs.update({"value/total_steps": sgd_steps}) + return probs + + +def value_logger_lola(agent: LOLA) -> dict: + weights = agent._state.params["categorical_value_head/~/linear_1"][ + "w" + ] # 5 x 1 matrix + sgd_steps = agent._total_steps / agent._num_steps + probs = { + f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + } + probs.update({"value/total_steps": sgd_steps}) + return probs + + def policy_logger_ppo_with_memory(agent) -> dict: """Calculate probability of coopreation""" # n = 5 @@ -167,6 +193,20 @@ def naive_losses(agent) -> None: return losses +def losses_lola(agent) -> None: + sgd_steps = agent._logger.metrics["sgd_steps"] + loss_total = agent._logger.metrics["loss_total"] + loss_policy = agent._logger.metrics["loss_policy"] + loss_value = agent._logger.metrics["loss_value"] + losses = { + "sgd_steps": sgd_steps, + "train/total": loss_total, + "train/policy": loss_policy, + "train/value": loss_value, + } + return losses + + def losses_ppo(agent: PPO) -> dict: pid = agent.player_id sgd_steps = agent._logger.metrics["sgd_steps"] @@ -191,7 +231,23 @@ def policy_logger_naive(agent) -> None: pi = nn.softmax(weights) sgd_steps = agent._total_steps / agent._num_steps probs = { - f"policy/{str(s)}/{agent.player_id}.cooperate": p[0] + f"policy/{str(s)}/{agent.player_id}/player_{agent.player_id}.cooperate": p[ + 0 + ] + for (s, p) in zip(State, pi) + } + probs.update({"policy/total_steps": sgd_steps}) + return probs + + +def policy_logger_lola(agent) -> None: + weights = agent._state.params["categorical_value_head/~/linear"]["w"] + pi = nn.softmax(weights) + sgd_steps = agent._total_steps / agent._num_steps + probs = { + f"policy/{str(s)}/{agent.player_id}/player_{agent.player_id}.cooperate": p[ + 0 + ] for (s, p) in zip(State, pi) } probs.update({"policy/total_steps": sgd_steps}) From bb7b03b1c33f700ea6cf3860bf6537e2041d7f90 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Wed, 27 Jul 2022 16:28:22 +0100 Subject: [PATCH 14/29] add logic for lola (still debugging) --- pax/centralized_learners.py | 7 +- pax/conf/config.yaml | 4 +- 
pax/conf/experiment/lola.yaml | 8 +- pax/conf/experiment/naive.yaml | 2 +- pax/lola/lola.py | 456 ++++++++++++++++++++++----------- pax/runner.py | 84 +++--- 6 files changed, 357 insertions(+), 204 deletions(-) diff --git a/pax/centralized_learners.py b/pax/centralized_learners.py index fa77d73d..fcfc2f33 100644 --- a/pax/centralized_learners.py +++ b/pax/centralized_learners.py @@ -25,9 +25,14 @@ def lookahead(self, env): # All other agents in a list # i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...] other_agents = self.agents[:counter] + self.agents[counter + 1 :] - agent.lookhead(env, other_agents) + agent.lookahead(env, other_agents) counter += 1 + def out_lookahead(self, env): + """Performs a real rollout and update""" + for agent in self.agents: + agent.out_lookahead(env) + def update( self, old_timesteps: List[TimeStep], diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 100819a6..bf8532a4 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -21,7 +21,7 @@ agent2: 'LOLA' env_id: ipd game: ipd env_type: finite -env_discount: 0.99 +env_discount: 0.96 payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] centralized: True @@ -75,7 +75,7 @@ naive: # LOLA agent parameters lola: - replay_capacity: 100000 #args.lola.replay_capacity, + replay_capacity: 1000 #args.lola.replay_capacity, min_replay_size: 50 #args.lola.min_replay_size, sgd_period: 1 #args.dqn.sgd_period, batch_size: 2 #args.dqn.batch_size diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 92dac307..77c2a78c 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -26,10 +26,10 @@ eval_every: 4000 # timesteps # LOLA agent parameters lola: - replay_capacity: 100000 #args.lola.replay_capacity, - min_replay_size: 50 #args.lola.min_replay_size, - sgd_period: 1 #args.dqn.sgd_period, - batch_size: 2 #args.dqn.batch_size + replay_capacity: 100000 + min_replay_size: 50 + sgd_period: 1 + batch_size: 2 # Logging setup wandb: diff --git a/pax/conf/experiment/naive.yaml b/pax/conf/experiment/naive.yaml index 3345f514..2d94d79d 100644 --- a/pax/conf/experiment/naive.yaml +++ b/pax/conf/experiment/naive.yaml @@ -2,7 +2,7 @@ # Agents agent1: 'Naive' -agent2: 'Naive' +agent2: 'TitForTat' # Environment env_id: ipd diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 49cc4f14..1ac1fe16 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -1,5 +1,3 @@ -# Adapted from https://github.com/deepmind/acme/blob/master/acme/agents/jax/ppo/learning.py - from typing import Any, Mapping, NamedTuple, Tuple, Dict from pax import utils @@ -46,7 +44,7 @@ def __init__(self): self.lr_v = 0.1 self.gamma = 0.96 self.n_update = 200 - self.len_rollout = 150 + self.len_rollout = 100 self.batch_size = 128 self.use_baseline = True self.seed = 42 @@ -63,94 +61,61 @@ class Logger: metrics: dict -# class Memory: -# def __init__(self): -# self.self_logprobs = [] -# self.other_logprobs = [] -# self.values = [] -# self.rewards = [] - -# def add(self, lp, other_lp, v, r): -# self.self_logprobs.append(lp) -# self.other_logprobs.append(other_lp) -# self.values.append(v) -# self.rewards.append(r) - -# def dice_objective(self): -# self_logprobs = jnp.stack(self.self_logprobs, dim=1) -# other_logprobs = jnp.stack(self.other_logprobs, dim=1) -# values = jnp.stack(self.values, dim=1) -# rewards = jnp.stack(self.rewards, dim=1) - -# # apply discount: -# cum_discount = ( -# jnp.cumprod(hp.gamma * jnp.ones(*rewards.size()), dim=1) / hp.gamma -# ) -# discounted_rewards = rewards * cum_discount -# 
discounted_values = values * cum_discount - -# # stochastics nodes involved in rewards dependencies: -# dependencies = jnp.cumsum(self_logprobs + other_logprobs, dim=1) - -# # logprob of each stochastic nodes: -# stochastic_nodes = self_logprobs + other_logprobs - -# # dice objective: -# dice_objective = jnp.mean( -# jnp.sum(magic_box(dependencies) * discounted_rewards, dim=1) -# ) - -# if hp.use_baseline: -# # variance_reduction: -# baseline_term = jnp.mean( -# jnp.sum( -# (1 - magic_box(stochastic_nodes)) * discounted_values, -# dim=1, -# ) -# ) -# dice_objective = dice_objective + baseline_term - -# return -dice_objective # want to minimize -objective - -# def value_loss(self): -# values = jnp.stack(self.values, dim=1) -# rewards = jnp.stack(self.rewards, dim=1) -# return jnp.mean((rewards - values) ** 2) - - -# ipd = IPD(hp.len_rollout, hp.batch_size) - - -# def act(batch_states, theta, values): -# batch_states = jnp.from_numpy(batch_states).long() -# probs = jnp.sigmoid(theta)[batch_states] -# m = Bernoulli(1 - probs) -# actions = m.sample() -# log_probs_actions = m.log_prob(actions) -# return actions.numpy().astype(int), log_probs_actions, values[batch_states] - - -# def get_gradient(objective, theta): -# # create differentiable gradient for 2nd orders: -# grad_objective = torch.autograd.grad( -# objective, (theta), create_graph=True -# )[0] -# return grad_objective - - -# def step(theta1, theta2, values1, values2): -# # just to evaluate progress: -# (s1, s2), _ = ipd.reset() -# score1 = 0 -# score2 = 0 -# for t in range(hp.len_rollout): -# a1, lp1, v1 = act(s1, theta1, values1) -# a2, lp2, v2 = act(s2, theta2, values2) -# (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) -# # cumulate scores -# score1 += np.mean(r1) / float(hp.len_rollout) -# score2 += np.mean(r2) / float(hp.len_rollout) -# return (score1, score2) +class Memory: + def __init__(self): + self.self_logprobs = [] + self.other_logprobs = [] + self.values = [] + self.rewards = [] + + def add(self, lp, other_lp, v, r): + self.self_logprobs.append(lp) + self.other_logprobs.append(other_lp) + self.values.append(v) + self.rewards.append(r) + + def dice_objective(self): + # Stacks so that the dimension is now (num_envs, num_steps) + self_logprobs = jnp.stack(self.self_logprobs, axis=1) + other_logprobs = jnp.stack(self.other_logprobs, axis=1) + values = jnp.stack(self.values, axis=1) + rewards = jnp.stack(self.rewards, axis=1) + + # apply discount: + cum_discount = ( + jnp.cumprod(hp.gamma * jnp.ones(rewards.shape), axis=1) / hp.gamma + ) + discounted_rewards = rewards * cum_discount + discounted_values = values * cum_discount + + # stochastics nodes involved in rewards dependencies: + dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) + + # logprob of each stochastic nodes: + stochastic_nodes = self_logprobs + other_logprobs + + # dice objective: + dice_objective = jnp.mean( + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) + ) + + if hp.use_baseline: + # variance_reduction: + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) * discounted_values, + axis=1, + ) + ) + dice_objective = dice_objective + baseline_term + + # TODO: Combine the value loss with the dice objective loss? 
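# Aside (illustrative, not part of the patch): magic_box(x) =
# exp(x - stop_gradient(x)) evaluates to 1 in the forward pass but has
# derivative 1 with respect to x, so the objective returned just below keeps
# the plain Monte-Carlo value while its gradient becomes the score-function
# (REINFORCE) estimator. A quick check using this module's magic_box:
#     jnp.exp(0.7 - jax.lax.stop_gradient(0.7))        # == 1.0
#     jax.grad(lambda lp: magic_box(lp) * 2.0)(0.7)    # == 2.0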
+ return -dice_objective # want to minimize -objective + + def value_loss(self): + values = jnp.stack(self.values, axis=1) + rewards = jnp.stack(self.rewards, axis=1) + return jnp.mean((rewards - values) ** 2) class LOLA: @@ -203,42 +168,41 @@ def rollouts( ) buffer.add(t, actions, log_probs, values, t_prime) - def loss( - logprobs: jnp.ndarray, - other_logprobs: jnp.ndarray, - values: jnp.ndarray, - rewards: jnp.ndarray, - ): - logprobs = jnp.stack(logprobs, dim=1) - other_logprobs = jnp.stack(other_logprobs, dim=1) - values = jnp.stack(self.values, dim=1) - rewards = jnp.stack(self.rewards, dim=1) + def loss(log_probs, other_log_probs, values, rewards): + # Stacks so that the dimension is now (num_envs, num_steps) + self_logprobs = jnp.stack(log_probs, axis=1) + other_logprobs = jnp.stack(other_log_probs, axis=1) + values = jnp.stack(values, axis=1) + rewards = jnp.stack(rewards, axis=1) # apply discount: cum_discount = ( - jnp.cumprod(gamma * jnp.ones(*rewards.size()), dim=1) / gamma + jnp.cumprod(hp.gamma * jnp.ones(rewards.shape), axis=1) + / hp.gamma ) discounted_rewards = rewards * cum_discount discounted_values = values * cum_discount # stochastics nodes involved in rewards dependencies: - dependencies = jnp.cumsum(logprobs + other_logprobs, dim=1) + dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) # logprob of each stochastic nodes: - stochastic_nodes = logprobs + other_logprobs + stochastic_nodes = self_logprobs + other_logprobs # dice objective: dice_objective = jnp.mean( - jnp.sum(magic_box(dependencies) * discounted_rewards, dim=1) + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) ) - baseline_term = jnp.mean( - jnp.sum( - (1 - magic_box(stochastic_nodes)) * discounted_values, - dim=1, + if hp.use_baseline: + # variance_reduction: + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) * discounted_values, + axis=1, + ) ) - ) - dice_objective = dice_objective + baseline_term + dice_objective = dice_objective + baseline_term return -dice_objective, {} # want to minimize -objective @@ -424,19 +388,6 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: extras={"values": None, "log_probs": None}, ) - # def in_lookahead(self, other_theta, other_values): - # (s1, s2), _ = ipd.reset() - # other_memory = Memory() - # for t in range(hp.len_rollout): - # a1, lp1, v1 = act(s1, self.theta, self.values) - # a2, lp2, v2 = act(s2, other_theta, other_values) - # (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) - # other_memory.add(lp2, lp1, v2, torch.from_numpy(r2).float()) - - # other_objective = other_memory.dice_objective() - # grad = get_gradient(other_objective, other_theta) - # return grad - # Initialise training state (parameters, optimiser state, extras). 
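# Note (assumption): obs_spec is expected to be (5,), one slot per IPD
# observation (the start state plus CC, CD, DC, DD), so the dummy observation
# built inside make_initial_state is a zero vector of that size with a batch
# dimension added, used only so haiku can trace the network's parameter shapes.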
self._state = make_initial_state(random_key, obs_spec) @@ -465,6 +416,10 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self._policy = policy self._rollouts = rollouts + # initialize some variables + self._optimizer = optimizer + self.gamma = gamma + # Other useful hyperparameters self._num_envs = num_envs # number of environments self._num_steps = num_steps # number of steps per environment @@ -479,31 +434,225 @@ def select_action(self, t: TimeStep): ) return utils.to_numpy(actions) - # def lookahead(self, env, other_agents): - # """ - # Performs a rollout using the current parameters of both agents - # and simulates a naive learning update step for the other agent - - # INPUT: - # env: SequentialMatrixGame, an environment object of the game being played - # other_agents: list, a list of objects of the other agents - # """ - # t = env.reset() - # other_memory = Memory() - # for t in range(hp.len_rollout): - # actions, self.w = self._policy( - # self._state.params, t.observation, self._state - # ) - # a1, lp1, v1 = act(s1, self.theta, self.values) - # a2, lp2, v2 = act(s2, other_theta, other_values) - # (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2)) - # other_memory.add(lp2, lp1, v2, torch.from_numpy(r2).float()) - - # other_objective = other_memory.dice_objective() - # grad = get_gradient(other_objective, other_theta) - # return grad - - # pass + def lookahead(self, env, other_agents): + """ + Performs a rollout using the current parameters of both agents + and simulates a naive learning update step for the other agent + + INPUT: + env: SequentialMatrixGame, an environment object of the game being played + other_agents: list, a list of objects of the other agents + """ + + def loss(log_probs, other_log_probs, values, rewards): + # Stacks so that the dimension is now (num_envs, num_steps) + self_logprobs = jnp.stack(log_probs, axis=1) + other_logprobs = jnp.stack(other_log_probs, axis=1) + values = jnp.stack(values, axis=1) + rewards = jnp.stack(rewards, axis=1) + + # apply discount: + cum_discount = ( + jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) + / self.gamma + ) + discounted_rewards = rewards * cum_discount + discounted_values = values * cum_discount + + # stochastics nodes involved in rewards dependencies: + dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) + + # logprob of each stochastic nodes: + stochastic_nodes = self_logprobs + other_logprobs + + # dice objective: + dice_objective = jnp.mean( + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) + ) + + if hp.use_baseline: + # variance_reduction: + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) * discounted_values, + axis=1, + ) + ) + dice_objective = dice_objective + baseline_term + + return -dice_objective # want to minimize -objective + + # Reset environment and initialize buffer + t = env.reset() + # initialize buffer + other_memory = Memory() + # get the other agent + other_agent = other_agents[0] + # get a copy of the agent's state + my_state = self._state + # create a temporary training state for the other agent for the simulated rollout + other_state = TrainingState( + params=other_agent._state.params, + opt_state=other_agent._state.opt_state, + random_key=other_agent._state.random_key, + timesteps=other_agent._state.timesteps, + extras=other_agent._state.extras, + ) + + # TODO: Replace with jax.lax.scan + for _ in range(self._num_steps): + # I take an action using my parameters + a1, my_state = self._policy( + my_state.params, 
t[0].observation, my_state + ) + # Opponent takes an action using their parameters + a2, other_state = self._policy( + other_state.params, t[1].observation, other_state + ) + # The environment steps as usual + t_prime = env.step((a1, a2)) + # We add to the buffer as if we are the other player + _, r_2 = t_prime[0].reward, t_prime[1].reward + other_memory.add( + other_state.extras["log_probs"], + my_state.extras["log_probs"], + other_state.extras["values"], + r_2, + ) + + # unpack the values from the buffer + my_logprobs = other_memory.self_logprobs + other_logprobs = other_memory.other_logprobs + values = other_memory.other_logprobs + rewards = other_memory.rewards + + # initialize the grad function + grad_fn = jax.grad(loss) + + # calculate the gradients + gradients = grad_fn(my_logprobs, other_logprobs, values, rewards) + + # TODO: BREAKS HERE BECAUSE IT EXPECTS A LIST? + # update the optimizer + updates, opt_state = self._optimizer.update( + gradients, other_state.opt_state + ) + # apply the optimizer updates + params = optax.apply_updates(other_state.params, updates) + # replace the other player's current parameters with a simulated update + other_state._replace(params=params) # might be redundant + other_state._replace(opt_state=opt_state) # might be redundant + self._other_state = other_state + + def outer_rollout(self, env): + """ + Performs a real rollout using the current parameters of both agents + and a naive learning update step for the other agent + + INPUT: + env: SequentialMatrixGame, an environment object of the game being played + other_agents: list, a list of objects of the other agents + """ + + def loss(log_probs, other_log_probs, values, rewards): + # Stacks so that the dimension is now (num_envs, num_steps) + self_logprobs = jnp.stack(log_probs, axis=1) + other_logprobs = jnp.stack(other_log_probs, axis=1) + values = jnp.stack(values, axis=1) + rewards = jnp.stack(rewards, axis=1) + + # apply discount: + cum_discount = ( + jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) + / self.gamma + ) + discounted_rewards = rewards * cum_discount + discounted_values = values * cum_discount + + # stochastics nodes involved in rewards dependencies: + dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) + + # logprob of each stochastic nodes: + stochastic_nodes = self_logprobs + other_logprobs + + # dice objective: + dice_objective = jnp.mean( + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) + ) + + if hp.use_baseline: + # variance_reduction: + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) * discounted_values, + axis=1, + ) + ) + dice_objective = dice_objective + baseline_term + + return -dice_objective # want to minimize -objective + + # Reset environment and initialize buffer + t = env.reset() + # initialize buffer + memory = Memory() + my_state = self._state + # create a temporary training state for the other agent for the simulated rollout + other_state = TrainingState( + params=self._other_state.params, + opt_state=self._other_state.opt_state, + random_key=self._other_state.random_key, + timesteps=self._other_state.timesteps, + extras=self._other_state.extras, + ) + + reward_list = [] + # TODO: Replace with jax.lax.scan + for _ in range(self._num_steps): + # I take an action using my parameters + a1, my_state = self._policy( + my_state.params, t[0].observation, my_state + ) + # Opponent takes an action using their parameters + a2, other_state = self._policy( + other_state.params, t[1].observation, other_state + ) + 
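# Note on the simulated update at the end of lookahead above: TrainingState is
# a NamedTuple, so `other_state._replace(...)` returns a new instance rather
# than mutating `other_state` in place. Carrying the simulated opponent step
# into self._other_state presumably needs
#     other_state = other_state._replace(params=params, opt_state=opt_state)
# before the assignment; the same applies to the `self._state._replace(...)`
# calls at the end of this outer rollout.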
# The environment steps as usual + t_prime = env.step((a1, a2)) + # We add to the buffer as if we are the other player + r_1, _ = t_prime[0].reward, t_prime[1].reward + reward_list.append(r_1) + memory.add( + my_state.extras["log_probs"], + other_state.extras["log_probs"], + my_state.extras["values"], + r_1, + ) + + # unpack the values from the buffer + my_logprobs = memory.self_logprobs + other_logprobs = memory.other_logprobs + values = memory.other_logprobs + rewards = memory.rewards + + # initialize the grad function + grad_fn = jax.grad(loss) + + # calculate the gradients + gradients = grad_fn(my_logprobs, other_logprobs, values, rewards) + # TODO: Need to include the value function somehow? + # TODO: BREAKS HERE BECAUSE IT EXPECTS A LIST? + # update the optimizer + updates, opt_state = self._optimizer.update( + gradients, other_state.opt_state + ) + # apply the optimizer updates + params = optax.apply_updates(other_state.params, updates) + self._state._replace(params=params) + self._state._replace(opt_state=opt_state) + + rewards_array = jnp.array(reward_list) + self.rewards = rewards_array def update( self, @@ -568,10 +717,10 @@ def update( # It needs to be able to take in the opponents parameters # and then do a rollout under those parameters # could do sgd here for other agent? - # sample = self._trajectory_buffer.sample() - # self._state, results = self._sgd_step( - # self._state, other_agent_params, sample - # ) + sample = self._trajectory_buffer.sample() + self._state, results = self._sgd_step( + self._state, self.other_params, sample + ) # self._logger.metrics["sgd_steps"] += ( # self._num_minibatches * self._num_epochs # ) @@ -614,7 +763,6 @@ def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): num_minibatches=args.ppo.num_minibatches, num_epochs=args.ppo.num_epochs, gamma=args.ppo.gamma, - gae_lambda=args.ppo.gae_lambda, ) diff --git a/pax/runner.py b/pax/runner.py index 0000f43a..97c3ecaf 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -8,9 +8,6 @@ import wandb -# TODO: make these a copy of acme - - class Runner: """Holds the runner's state.""" @@ -33,48 +30,51 @@ def train_loop(self, env, agents, num_episodes, watchers): agents.lookahead(env) # NOTE: Outer for loop begins - rewards_0, rewards_1 = [], [] - t = env.reset() - while not (t[0].last()): - actions = agents.select_action(t) - t_prime = env.step(actions) - r_0, r_1 = t_prime[0].reward, t_prime[1].reward - - # append step rewards to episode rewards - rewards_0.append(r_0) - rewards_1.append(r_1) - - agents.update(t, actions, t_prime) - self.train_steps += 1 - - # book keeping - t = t_prime - - # logging - # if watchers: - # agents.log(watchers, None) - # wandb.log( - # { - # "train/training_steps": self.train_steps, - # "train/step_reward/player_1": float( - # jnp.array(r_0).mean() - # ), - # "train/step_reward/player_2": float( - # jnp.array(r_1).mean() - # ), - # "time_elapsed_minutes": ( - # time.time() - self.start_time - # ) - # / 60, - # "time_elapsed_seconds": time.time() - # - self.start_time, - # } - # ) + # rewards_0, rewards_1 = [], [] + # t = env.reset() + # while not (t[0].last()): + # actions = agents.select_action(t) + # t_prime = env.step(actions) + # r_0, r_1 = t_prime[0].reward, t_prime[1].reward + + # # append step rewards to episode rewards + # rewards_0.append(r_0) + # rewards_1.append(r_1) + + # agents.update(t, actions, t_prime) + # self.train_steps += 1 + + # # book keeping + # t = t_prime + + # # logging + # # if watchers: + # # agents.log(watchers, None) + # # wandb.log( 
+ # # { + # # "train/training_steps": self.train_steps, + # # "train/step_reward/player_1": float( + # # jnp.array(r_0).mean() + # # ), + # # "train/step_reward/player_2": float( + # # jnp.array(r_1).mean() + # # ), + # # "time_elapsed_minutes": ( + # # time.time() - self.start_time + # # ) + # # / 60, + # # "time_elapsed_seconds": time.time() + # # - self.start_time, + # # } + # # ) + agents.out_lookahead(env) # end of episode stats self.train_episodes += env.num_envs - rewards_0 = jnp.array(rewards_0) - rewards_1 = jnp.array(rewards_1) + # rewards_0 = jnp.array(rewards_0) + # rewards_1 = jnp.array(rewards_1) + rewards_0 = jnp.array(agents.agents[0].rewards) + rewards_1 = jnp.array(agents.agents[1].rewards) print( f"Total Episode Reward: {float(rewards_0.mean()), float(rewards_1.mean())}" From c169c757ba77a6b667bea5fdeb4703b1b0654972 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Wed, 27 Jul 2022 20:47:33 +0100 Subject: [PATCH 15/29] add lola (doesn't quite work yet) --- pax/centralized_learners.py | 11 +- pax/conf/config.yaml | 4 +- pax/conf/experiment/lola.yaml | 20 +- pax/lola/lola.py | 526 ++++++++-------------------------- pax/runner.py | 88 +++--- pax/watchers.py | 5 +- 6 files changed, 183 insertions(+), 471 deletions(-) diff --git a/pax/centralized_learners.py b/pax/centralized_learners.py index fcfc2f33..658375d5 100644 --- a/pax/centralized_learners.py +++ b/pax/centralized_learners.py @@ -18,20 +18,25 @@ def select_action(self, timesteps: List[TimeStep]) -> List[jnp.ndarray]: agent.select_action(t) for agent, t in zip(self.agents, timesteps) ] - def lookahead(self, env): + def in_lookahead(self, env): """Simulates a rollout and gradient update""" counter = 0 for agent in self.agents: # All other agents in a list # i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...] other_agents = self.agents[:counter] + self.agents[counter + 1 :] - agent.lookahead(env, other_agents) + agent.in_lookahead(env, other_agents) counter += 1 def out_lookahead(self, env): """Performs a real rollout and update""" + counter = 0 for agent in self.agents: - agent.out_lookahead(env) + # All other agents in a list + # i.e. if i am agent2, then other_agents=[agent1, agent3, agent4 ...] 
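# Aside (equivalent alternative, not part of the patch): the manual counter in
# in_lookahead / out_lookahead can be avoided with enumerate, e.g.
#     for i, agent in enumerate(self.agents):
#         other_agents = self.agents[:i] + self.agents[i + 1 :]
#         agent.out_lookahead(env, other_agents)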
+ other_agents = self.agents[:counter] + self.agents[counter + 1 :] + agent.out_lookahead(env, other_agents) + counter += 1 def update( self, diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index bf8532a4..e6f4b77e 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -26,10 +26,10 @@ payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] centralized: True # Training hyperparameters -num_envs: 100 +num_envs: 10 num_steps: 100 # number of steps per episode total_timesteps: 100_000_000 -eval_every: 500 # eval every n episodes, not timesteps +eval_every: 50_000 # eval every n episodes, not timesteps # Useful information diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 77c2a78c..ed9201f6 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -9,15 +9,15 @@ centralized: True env_id: ipd game: ipd env_type: finite -env_discount: 0.99 +env_discount: 0.96 payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] # Training hyperparameters -num_envs: 10 +num_envs: 128 num_steps: 150 # number of steps per episode -total_timesteps: 100_000 -eval_every: 4000 # timesteps +total_timesteps: 10_000_000 +eval_every: 50_000 # timesteps # Useful information # num_episodes = total_timesteps / num_steps @@ -26,10 +26,12 @@ eval_every: 4000 # timesteps # LOLA agent parameters lola: - replay_capacity: 100000 - min_replay_size: 50 - sgd_period: 1 - batch_size: 2 + use_baseline: True + adam_epsilon: 1e-5 + learning_rate: 0.2 + lr_in: 0.3 + lr_out: 0.2 + gamma: 0.96 # Logging setup wandb: @@ -37,4 +39,4 @@ wandb: project: ipd group: 'LOLA-vs-${agent2}-${game}' name: run-seed-${seed} - log: False + log: True diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 1ac1fe16..4b967e58 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -1,3 +1,4 @@ +from binascii import a2b_base64 from typing import Any, Mapping, NamedTuple, Tuple, Dict from pax import utils @@ -37,22 +38,6 @@ class TrainingState(NamedTuple): extras: Mapping[str, jnp.ndarray] -class Hp: - def __init__(self): - self.lr_out = 0.2 - self.lr_in = 0.3 - self.lr_v = 0.1 - self.gamma = 0.96 - self.n_update = 200 - self.len_rollout = 100 - self.batch_size = 128 - self.use_baseline = True - self.seed = 42 - - -hp = Hp() - - def magic_box(x): return jnp.exp(x - jax.lax.stop_gradient(x)) @@ -61,63 +46,6 @@ class Logger: metrics: dict -class Memory: - def __init__(self): - self.self_logprobs = [] - self.other_logprobs = [] - self.values = [] - self.rewards = [] - - def add(self, lp, other_lp, v, r): - self.self_logprobs.append(lp) - self.other_logprobs.append(other_lp) - self.values.append(v) - self.rewards.append(r) - - def dice_objective(self): - # Stacks so that the dimension is now (num_envs, num_steps) - self_logprobs = jnp.stack(self.self_logprobs, axis=1) - other_logprobs = jnp.stack(self.other_logprobs, axis=1) - values = jnp.stack(self.values, axis=1) - rewards = jnp.stack(self.rewards, axis=1) - - # apply discount: - cum_discount = ( - jnp.cumprod(hp.gamma * jnp.ones(rewards.shape), axis=1) / hp.gamma - ) - discounted_rewards = rewards * cum_discount - discounted_values = values * cum_discount - - # stochastics nodes involved in rewards dependencies: - dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) - - # logprob of each stochastic nodes: - stochastic_nodes = self_logprobs + other_logprobs - - # dice objective: - dice_objective = jnp.mean( - jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) - ) - - if hp.use_baseline: - # variance_reduction: - baseline_term = jnp.mean( - 
jnp.sum( - (1 - magic_box(stochastic_nodes)) * discounted_values, - axis=1, - ) - ) - dice_objective = dice_objective + baseline_term - - # TODO: Combine the value loss with the dice objective loss? - return -dice_objective # want to minimize -objective - - def value_loss(self): - values = jnp.stack(self.values, axis=1) - rewards = jnp.stack(self.rewards, axis=1) - return jnp.mean((rewards - values) ** 2) - - class LOLA: """LOLA with the DiCE objective function.""" @@ -132,10 +60,10 @@ def __init__( num_steps: int = 150, num_minibatches: int = 1, num_epochs: int = 1, + use_baseline: bool = True, gamma: float = 0.96, ): - - # @jax.jit + @jax.jit def policy( params: hk.Params, observation: TimeStep, state: TrainingState ): @@ -154,47 +82,48 @@ def policy( ) return actions, state - def rollouts( - buffer: TrajectoryBuffer, - t: TimeStep, - actions: np.array, - t_prime: TimeStep, - state: TrainingState, - ) -> None: - """Stores rollout in buffer""" - log_probs, values = ( - state.extras["log_probs"], - state.extras["values"], - ) - buffer.add(t, actions, log_probs, values, t_prime) - - def loss(log_probs, other_log_probs, values, rewards): + def loss(params, other_params, samples): # Stacks so that the dimension is now (num_envs, num_steps) - self_logprobs = jnp.stack(log_probs, axis=1) - other_logprobs = jnp.stack(other_log_probs, axis=1) - values = jnp.stack(values, axis=1) - rewards = jnp.stack(rewards, axis=1) + + obs_1 = samples.obs_self + obs_2 = samples.obs_other + + rewards = samples.rewards_self + # r_1 = samples.rewards_self + # r_2 = samples.rewards_other + + actions_1 = samples.actions_self + actions_2 = samples.actions_other + + # distribution, values_self = self.network.apply(params, obs_1) + distribution, values = self.network.apply(params, obs_1) + self_log_prob = distribution.log_prob(actions_1) + + distribution, values_others = self.network.apply( + other_params, obs_2 + ) + other_log_prob = distribution.log_prob(actions_2) # apply discount: cum_discount = ( - jnp.cumprod(hp.gamma * jnp.ones(rewards.shape), axis=1) - / hp.gamma + jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) + / self.gamma ) discounted_rewards = rewards * cum_discount discounted_values = values * cum_discount # stochastics nodes involved in rewards dependencies: - dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) + dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=1) # logprob of each stochastic nodes: - stochastic_nodes = self_logprobs + other_logprobs + stochastic_nodes = self_log_prob + other_log_prob # dice objective: dice_objective = jnp.mean( jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) ) - if hp.use_baseline: + if use_baseline: # variance_reduction: baseline_term = jnp.mean( jnp.sum( @@ -204,174 +133,15 @@ def loss(log_probs, other_log_probs, values, rewards): ) dice_objective = dice_objective + baseline_term - return -dice_objective, {} # want to minimize -objective + loss_value = jnp.mean((rewards - values) ** 2) + loss_total = -dice_objective + loss_value - @jax.jit - def sgd_step( - state: TrainingState, - other_agent_params: hk.Params, - sample: NamedTuple, - ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: - """Performs a minibatch SGD step, returning new state and metrics.""" - - # Extract data - ( - observations, - actions, - rewards, - behavior_log_probs, - behavior_values, - dones, - ) = ( - sample.observations, - sample.actions, - sample.rewards, - sample.behavior_log_probs, - sample.behavior_values, - sample.dones, - ) - - # vmap - 
# TODO: REMOVE THIS INEA IFNEF IAEIF AE - def gae_advantages(): - pass - - batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0) - advantages, target_values = batch_gae_advantages( - rewards=rewards, values=behavior_values, dones=dones - ) - - # Exclude the last step - it was only used for bootstrapping. - # The shape is [num_envs, num_steps, ..] - ( - observations, - actions, - behavior_log_probs, - behavior_values, - ) = jax.tree_map( - lambda x: x[:, :-1], - (observations, actions, behavior_log_probs, behavior_values), - ) - - trajectories = Batch( - observations=observations, - actions=actions, - advantages=advantages, - behavior_log_probs=behavior_log_probs, - target_values=target_values, - behavior_values=behavior_values, - ) - - # Concatenate all trajectories. Reshape from [num_envs, num_steps, ..] - # to [num_envs * num_steps,..] - assert len(target_values.shape) > 1 - num_envs = target_values.shape[0] - num_steps = target_values.shape[1] - batch_size = num_envs * num_steps - assert batch_size % num_minibatches == 0, ( - "Num minibatches must divide batch size. Got batch_size={}" - " num_minibatches={}." - ).format(batch_size, num_minibatches) - - batch = jax.tree_map( - lambda x: x.reshape((batch_size,) + x.shape[2:]), trajectories - ) - - # Compute gradients. - grad_fn = jax.grad(loss, has_aux=True) - - def model_update_minibatch( - carry: Tuple[hk.Params, optax.OptState, int], - minibatch: Batch, - ) -> Tuple[ - Tuple[hk.Params, optax.OptState, int], Dict[str, jnp.ndarray] - ]: - """Performs model update for a single minibatch.""" - params, opt_state, timesteps = carry - # Normalize advantages at the minibatch level before using them. - advantages = ( - minibatch.advantages - - jnp.mean(minibatch.advantages, axis=0) - ) / (jnp.std(minibatch.advantages, axis=0) + 1e-8) - gradients, metrics = grad_fn( - params, - timesteps, - minibatch.observations, - minibatch.actions, - minibatch.behavior_log_probs, - minibatch.target_values, - advantages, - minibatch.behavior_values, - ) - - # Apply updates - updates, opt_state = optimizer.update(gradients, opt_state) - params = optax.apply_updates(params, updates) - - metrics["norm_grad"] = optax.global_norm(gradients) - metrics["norm_updates"] = optax.global_norm(updates) - return (params, opt_state, timesteps), metrics - - def model_update_epoch( - carry: Tuple[ - jnp.ndarray, hk.Params, optax.OptState, int, Batch - ], - unused_t: Tuple[()], - ) -> Tuple[ - Tuple[jnp.ndarray, hk.Params, optax.OptState, Batch], - Dict[str, jnp.ndarray], - ]: - """Performs model updates based on one epoch of data.""" - key, params, opt_state, timesteps, batch = carry - key, subkey = jax.random.split(key) - permutation = jax.random.permutation(subkey, batch_size) - shuffled_batch = jax.tree_map( - lambda x: jnp.take(x, permutation, axis=0), batch - ) - minibatches = jax.tree_map( - lambda x: jnp.reshape( - x, [num_minibatches, -1] + list(x.shape[1:]) - ), - shuffled_batch, - ) - - (params, opt_state, timesteps), metrics = jax.lax.scan( - model_update_minibatch, - (params, opt_state, timesteps), - minibatches, - length=num_minibatches, - ) - return (key, params, opt_state, timesteps, batch), metrics - - params = state.params - opt_state = state.opt_state - timesteps = state.timesteps - - # Repeat training for the given number of epoch, taking a random - # permutation for every epoch. 
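The sgd_step being removed here follows the usual JAX epoch pattern: flatten the rollout, reshuffle it once per epoch, split it into minibatches, and scan a single-minibatch update over the result. A self-contained sketch of that pattern follows; the quadratic objective and the 0.1 step size are stand-ins for illustration, not anything from pax.

import jax
import jax.numpy as jnp

def run_epoch(key, params, batch, num_minibatches):
    batch_size = batch["x"].shape[0]
    key, subkey = jax.random.split(key)
    permutation = jax.random.permutation(subkey, batch_size)
    # shuffle every leaf the same way, then split into [num_minibatches, -1, ...]
    shuffled = jax.tree_map(lambda v: jnp.take(v, permutation, axis=0), batch)
    minibatches = jax.tree_map(
        lambda v: v.reshape((num_minibatches, -1) + v.shape[1:]), shuffled
    )

    def update_minibatch(p, mb):
        # one gradient step on a toy quadratic objective
        loss_fn = lambda p_: jnp.mean((p_ * mb["x"] - mb["y"]) ** 2)
        grad = jax.grad(loss_fn)(p)
        return p - 0.1 * grad, loss_fn(p)

    params, losses = jax.lax.scan(update_minibatch, params, minibatches)
    return key, params, losses

key = jax.random.PRNGKey(0)
batch = {"x": jnp.ones(32), "y": 2.0 * jnp.ones(32)}
key, params, losses = run_epoch(key, jnp.array(0.0), batch, num_minibatches=4)
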
- # signature is scan(function, carry, tuple to iterate over, length) - (key, params, opt_state, timesteps, _), metrics = jax.lax.scan( - model_update_epoch, - (state.random_key, params, opt_state, timesteps, batch), - (), - length=num_epochs, - ) - - metrics = jax.tree_map(jnp.mean, metrics) - metrics["rewards_mean"] = jnp.mean( - jnp.abs(jnp.mean(rewards, axis=(0, 1))) - ) - metrics["rewards_std"] = jnp.std(rewards, axis=(0, 1)) - - new_state = TrainingState( - params=params, - opt_state=opt_state, - random_key=key, - timesteps=timesteps, - extras={"log_probs": None, "values": None}, - ) - - return new_state, metrics + # want to minimize -objective + return loss_total, { + "loss_total": -dice_objective + loss_value, + "loss_policy": -dice_objective, + "loss_value": loss_value, + } def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: """Initialises the training state (parameters and optimiser state).""" @@ -398,7 +168,8 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self._trajectory_buffer = TrajectoryBuffer( num_envs, num_steps, obs_spec ) - self._sgd_step = sgd_step + + self.grad_fn = jax.jit(jax.grad(loss, has_aux=True)) # Set up counters and logger self._logger = Logger() @@ -414,7 +185,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: # Initialize functions self._policy = policy - self._rollouts = rollouts + self.network = network # initialize some variables self._optimizer = optimizer @@ -426,6 +197,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self._batch_size = int(num_envs * num_steps) # number in one batch self._num_minibatches = num_minibatches # number of minibatches self._num_epochs = num_epochs # number of epochs to use sample + self._obs_spec = obs_spec def select_action(self, t: TimeStep): """Selects action and updates info with PPO specific information""" @@ -434,7 +206,7 @@ def select_action(self, t: TimeStep): ) return utils.to_numpy(actions) - def lookahead(self, env, other_agents): + def in_lookahead(self, env, other_agents): """ Performs a rollout using the current parameters of both agents and simulates a naive learning update step for the other agent @@ -444,53 +216,17 @@ def lookahead(self, env, other_agents): other_agents: list, a list of objects of the other agents """ - def loss(log_probs, other_log_probs, values, rewards): - # Stacks so that the dimension is now (num_envs, num_steps) - self_logprobs = jnp.stack(log_probs, axis=1) - other_logprobs = jnp.stack(other_log_probs, axis=1) - values = jnp.stack(values, axis=1) - rewards = jnp.stack(rewards, axis=1) - - # apply discount: - cum_discount = ( - jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) - / self.gamma - ) - discounted_rewards = rewards * cum_discount - discounted_values = values * cum_discount - - # stochastics nodes involved in rewards dependencies: - dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) - - # logprob of each stochastic nodes: - stochastic_nodes = self_logprobs + other_logprobs - - # dice objective: - dice_objective = jnp.mean( - jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) - ) - - if hp.use_baseline: - # variance_reduction: - baseline_term = jnp.mean( - jnp.sum( - (1 - magic_box(stochastic_nodes)) * discounted_values, - axis=1, - ) - ) - dice_objective = dice_objective + baseline_term - - return -dice_objective # want to minimize -objective - - # Reset environment and initialize buffer - t = env.reset() # initialize buffer - other_memory = Memory() + 
other_memory = TrajectoryBuffer( + self._num_envs, self._num_steps, self._obs_spec + ) + # get the other agent other_agent = other_agents[0] # get a copy of the agent's state my_state = self._state # create a temporary training state for the other agent for the simulated rollout + # TODO: make a separate optimizer for the other agent other_state = TrainingState( params=other_agent._state.params, opt_state=other_agent._state.opt_state, @@ -499,7 +235,9 @@ def loss(log_probs, other_log_probs, values, rewards): extras=other_agent._state.extras, ) + # Perform a rollout and store S,A,R,S tuples # TODO: Replace with jax.lax.scan + t = env.reset() for _ in range(self._num_steps): # I take an action using my parameters a1, my_state = self._policy( @@ -512,39 +250,42 @@ def loss(log_probs, other_log_probs, values, rewards): # The environment steps as usual t_prime = env.step((a1, a2)) # We add to the buffer as if we are the other player - _, r_2 = t_prime[0].reward, t_prime[1].reward - other_memory.add( - other_state.extras["log_probs"], - my_state.extras["log_probs"], - other_state.extras["values"], - r_2, - ) + # TODO: IS THIS THE RIGHT ORDER??????? + other_memory.add(t[1], t[0], a2, a1, t_prime[1], t_prime[0]) + t = t_prime # unpack the values from the buffer - my_logprobs = other_memory.self_logprobs - other_logprobs = other_memory.other_logprobs - values = other_memory.other_logprobs - rewards = other_memory.rewards - - # initialize the grad function - grad_fn = jax.grad(loss) + sample = other_memory.sample() # calculate the gradients - gradients = grad_fn(my_logprobs, other_logprobs, values, rewards) + gradients, _ = self.grad_fn( + other_state.params, my_state.params, sample + ) - # TODO: BREAKS HERE BECAUSE IT EXPECTS A LIST? - # update the optimizer + # Update the optimizer updates, opt_state = self._optimizer.update( gradients, other_state.opt_state ) + # apply the optimizer updates + # params = params = optax.apply_updates(other_state.params, updates) # replace the other player's current parameters with a simulated update - other_state._replace(params=params) # might be redundant - other_state._replace(opt_state=opt_state) # might be redundant - self._other_state = other_state + # print("Params before replacement", params) + self._other_state = TrainingState( + params=params, + opt_state=opt_state, + random_key=other_state.random_key, + timesteps=other_state.timesteps, + extras=other_state.extras, + ) + # other_state._replace(params=params) + # print("Params after replacement", other_state.params) + # other_state._replace(opt_state=opt_state) + # self._other_state = other_state + # print("Params after self.", self._other_state.params) - def outer_rollout(self, env): + def out_lookahead(self, env, other_agents): """ Performs a real rollout using the current parameters of both agents and a naive learning update step for the other agent @@ -554,59 +295,16 @@ def outer_rollout(self, env): other_agents: list, a list of objects of the other agents """ - def loss(log_probs, other_log_probs, values, rewards): - # Stacks so that the dimension is now (num_envs, num_steps) - self_logprobs = jnp.stack(log_probs, axis=1) - other_logprobs = jnp.stack(other_log_probs, axis=1) - values = jnp.stack(values, axis=1) - rewards = jnp.stack(rewards, axis=1) - - # apply discount: - cum_discount = ( - jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) - / self.gamma - ) - discounted_rewards = rewards * cum_discount - discounted_values = values * cum_discount - - # stochastics nodes involved in rewards 
dependencies: - dependencies = jnp.cumsum(self_logprobs + other_logprobs, axis=1) - - # logprob of each stochastic nodes: - stochastic_nodes = self_logprobs + other_logprobs - - # dice objective: - dice_objective = jnp.mean( - jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) - ) - - if hp.use_baseline: - # variance_reduction: - baseline_term = jnp.mean( - jnp.sum( - (1 - magic_box(stochastic_nodes)) * discounted_values, - axis=1, - ) - ) - dice_objective = dice_objective + baseline_term - - return -dice_objective # want to minimize -objective - - # Reset environment and initialize buffer - t = env.reset() # initialize buffer - memory = Memory() - my_state = self._state - # create a temporary training state for the other agent for the simulated rollout - other_state = TrainingState( - params=self._other_state.params, - opt_state=self._other_state.opt_state, - random_key=self._other_state.random_key, - timesteps=self._other_state.timesteps, - extras=self._other_state.extras, + memory = TrajectoryBuffer( + self._num_envs, self._num_steps, self._obs_spec ) - reward_list = [] + # get a copy of the agent's state + my_state = self._state + # Perform a rollout and store S,A,R,S tuples + t = env.reset() + other_state = self._other_state # TODO: Replace with jax.lax.scan for _ in range(self._num_steps): # I take an action using my parameters @@ -619,40 +317,47 @@ def loss(log_probs, other_log_probs, values, rewards): ) # The environment steps as usual t_prime = env.step((a1, a2)) - # We add to the buffer as if we are the other player - r_1, _ = t_prime[0].reward, t_prime[1].reward - reward_list.append(r_1) - memory.add( - my_state.extras["log_probs"], - other_state.extras["log_probs"], - my_state.extras["values"], - r_1, - ) + # We add to the buffer + # TODO: IS THIS THE RIGHT ORDER??????? + memory.add(t[0], t[1], a1, a2, t_prime[0], t_prime[1]) + t = t_prime - # unpack the values from the buffer - my_logprobs = memory.self_logprobs - other_logprobs = memory.other_logprobs - values = memory.other_logprobs - rewards = memory.rewards + # Update internal agent's timesteps + self._total_steps += self._num_envs + self._logger.metrics["total_steps"] += self._num_envs + self._state._replace(timesteps=self._total_steps) - # initialize the grad function - grad_fn = jax.grad(loss) + # unpack the values from the buffer + sample = memory.sample() # calculate the gradients - gradients = grad_fn(my_logprobs, other_logprobs, values, rewards) - # TODO: Need to include the value function somehow? - # TODO: BREAKS HERE BECAUSE IT EXPECTS A LIST? 
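The objective differentiated in both lookaheads is the DiCE estimator. Stripped of the buffer plumbing, the core trick is that magic_box(x) = exp(x - stop_gradient(x)) evaluates to 1 in the forward pass while keeping a gradient of 1 with respect to x, so the cumulative log-probabilities only enter through the backward pass. A runnable single-agent sketch with made-up log-probs and rewards; the real loss also sums in the opponent's log-probs, the baseline term, and the value loss.

import jax
import jax.numpy as jnp

def magic_box(x):
    return jnp.exp(x - jax.lax.stop_gradient(x))

def dice_objective(log_probs, rewards, gamma=0.96):
    # log_probs, rewards: [num_envs, num_steps]
    cum_discount = jnp.cumprod(gamma * jnp.ones_like(rewards), axis=1) / gamma
    discounted_rewards = rewards * cum_discount
    dependencies = jnp.cumsum(log_probs, axis=1)
    return jnp.mean(jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1))

log_probs = jnp.log(0.5) * jnp.ones((2, 4))
rewards = jnp.ones((2, 4))
value, grads = jax.value_and_grad(dice_objective)(log_probs, rewards)
# value equals the plain discounted return; grads carries the REINFORCE-style term
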
- # update the optimizer + gradients, results = self.grad_fn( + my_state.params, other_state.params, sample + ) + + # Update the optimizer updates, opt_state = self._optimizer.update( - gradients, other_state.opt_state + gradients, my_state.opt_state ) + # apply the optimizer updates - params = optax.apply_updates(other_state.params, updates) - self._state._replace(params=params) - self._state._replace(opt_state=opt_state) + params = optax.apply_updates(my_state.params, updates) + + self._logger.metrics["sgd_steps"] += ( + self._num_minibatches * self._num_epochs + ) + self._logger.metrics["loss_total"] = results["loss_total"] + self._logger.metrics["loss_policy"] = results["loss_policy"] + self._logger.metrics["loss_value"] = results["loss_value"] - rewards_array = jnp.array(reward_list) - self.rewards = rewards_array + # replace the other player's current parameters with a simulated update + self._state = TrainingState( + params=params, + opt_state=opt_state, + random_key=self._state.random_key, + timesteps=self._state.timesteps, + extras={"log_probs": None, "values": None}, + ) def update( self, @@ -745,8 +450,8 @@ def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): # ) optimizer = optax.chain( - optax.scale_by_adam(eps=args.ppo.adam_epsilon), - optax.scale(-args.ppo.learning_rate), + optax.scale_by_adam(eps=args.lola.adam_epsilon), + optax.scale(-args.lola.learning_rate), ) # Random key @@ -762,7 +467,8 @@ def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): num_steps=args.num_steps, num_minibatches=args.ppo.num_minibatches, num_epochs=args.ppo.num_epochs, - gamma=args.ppo.gamma, + use_baseline=args.lola.use_baseline, + gamma=args.lola.gamma, ) diff --git a/pax/runner.py b/pax/runner.py index 97c3ecaf..703e3352 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -23,58 +23,60 @@ def train_loop(self, env, agents, num_episodes, watchers): print("Training ") print("-----------------------") for _ in range(0, max(int(num_episodes / env.num_envs), 1)): + # TODO: Inner rollout # 1. Get other agents' parameters # 2. Do a rollout # 3. 
Simulate gradient update - agents.lookahead(env) # NOTE: Outer for loop begins - # rewards_0, rewards_1 = [], [] - # t = env.reset() - # while not (t[0].last()): - # actions = agents.select_action(t) - # t_prime = env.step(actions) - # r_0, r_1 = t_prime[0].reward, t_prime[1].reward - - # # append step rewards to episode rewards - # rewards_0.append(r_0) - # rewards_1.append(r_1) - - # agents.update(t, actions, t_prime) - # self.train_steps += 1 - - # # book keeping - # t = t_prime - - # # logging - # # if watchers: - # # agents.log(watchers, None) - # # wandb.log( - # # { - # # "train/training_steps": self.train_steps, - # # "train/step_reward/player_1": float( - # # jnp.array(r_0).mean() - # # ), - # # "train/step_reward/player_2": float( - # # jnp.array(r_1).mean() - # # ), - # # "time_elapsed_minutes": ( - # # time.time() - self.start_time - # # ) - # # / 60, - # # "time_elapsed_seconds": time.time() - # # - self.start_time, - # # } - # # ) - agents.out_lookahead(env) + rewards_0, rewards_1 = [], [] + t = env.reset() + while not (t[0].last()): + actions = agents.select_action(t) + t_prime = env.step(actions) + r_0, r_1 = t_prime[0].reward, t_prime[1].reward + + # append step rewards to episode rewards + rewards_0.append(r_0) + rewards_1.append(r_1) + + # agents.update(t, actions, t_prime) + self.train_steps += 1 + + # book keeping + t = t_prime + # logging + if watchers: + agents.log(watchers) + wandb.log( + { + "train/training_steps": self.train_steps, + "train/step_reward/player_1": float( + jnp.array(r_0).mean() + ), + "train/step_reward/player_2": float( + jnp.array(r_1).mean() + ), + "time_elapsed_minutes": ( + time.time() - self.start_time + ) + / 60, + "time_elapsed_seconds": time.time() + - self.start_time, + } + ) + agents.in_lookahead(env) + agents.out_lookahead(env) + # print("Agent 1 params", agents.agents[0]._state.params["categorical_value_head/~/linear"]["w"]) + # print("Agent 2 params", agents.agents[1]._state.params["categorical_value_head/~/linear"]["w"]) # end of episode stats self.train_episodes += env.num_envs - # rewards_0 = jnp.array(rewards_0) - # rewards_1 = jnp.array(rewards_1) - rewards_0 = jnp.array(agents.agents[0].rewards) - rewards_1 = jnp.array(agents.agents[1].rewards) + rewards_0 = jnp.array(rewards_0) + rewards_1 = jnp.array(rewards_1) + # rewards_0 = jnp.array(agents.agents[0].rewards) + # rewards_1 = jnp.array(agents.agents[1].rewards) print( f"Total Episode Reward: {float(rewards_0.mean()), float(rewards_1.mean())}" diff --git a/pax/watchers.py b/pax/watchers.py index 674de7f5..3891c028 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -259,10 +259,7 @@ def policy_logger_lola(agent) -> None: pi = nn.softmax(weights) sgd_steps = agent._total_steps / agent._num_steps probs = { - f"policy/{str(s)}/{agent.player_id}/player_{agent.player_id}.cooperate": p[ - 0 - ] - for (s, p) in zip(State, pi) + f"policy/{agent.player_id}/{str(s)}": p[0] for (s, p) in zip(State, pi) } probs.update({"policy/total_steps": sgd_steps}) return probs From 1a00280b2fb3f3bd53c49667b0957bc7d67d9e9d Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Thu, 28 Jul 2022 13:31:25 +0100 Subject: [PATCH 16/29] compiling lola... 
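The change below splits the single optimiser into an inner one (for the simulated opponent step in in_lookahead) and an outer one (for the agent's own update in out_lookahead), each with its own learning rate. A minimal sketch of that construction, using the lr_in/lr_out values this patch writes into the config; optax.scale(-lr) is what makes apply_updates take a descent step on the loss.

import jax.numpy as jnp
import optax

adam_epsilon, lr_in, lr_out = 1e-5, 1.0, 0.3
inner_optimizer = optax.chain(
    optax.scale_by_adam(eps=adam_epsilon), optax.scale(-lr_in)
)
outer_optimizer = optax.chain(
    optax.scale_by_adam(eps=adam_epsilon), optax.scale(-lr_out)
)

# usage is identical for both chains
params = {"w": jnp.zeros(3)}
opt_state = outer_optimizer.init(params)
grads = {"w": jnp.ones(3)}
updates, opt_state = outer_optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
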
--- pax/conf/experiment/lola.yaml | 9 ++++----- pax/lola/lola.py | 24 ++++++++++++++++-------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index ed9201f6..ac9aa0e2 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -14,10 +14,10 @@ payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] # Training hyperparameters -num_envs: 128 +num_envs: 64 num_steps: 150 # number of steps per episode total_timesteps: 10_000_000 -eval_every: 50_000 # timesteps +eval_every: 100_000 # timesteps # Useful information # num_episodes = total_timesteps / num_steps @@ -28,9 +28,8 @@ eval_every: 50_000 # timesteps lola: use_baseline: True adam_epsilon: 1e-5 - learning_rate: 0.2 - lr_in: 0.3 - lr_out: 0.2 + lr_in: 1 + lr_out: 0.3 gamma: 0.96 # Logging setup diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 4b967e58..6df44088 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -52,7 +52,8 @@ class LOLA: def __init__( self, network: NamedTuple, - optimizer: optax.GradientTransformation, + inner_optimizer: optax.GradientTransformation, + outer_optimizer: optax.GradientTransformation, random_key: jnp.ndarray, player_id: int, obs_spec: Tuple, @@ -149,7 +150,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: dummy_obs = jnp.zeros(shape=obs_spec) dummy_obs = utils.add_batch_dim(dummy_obs) initial_params = network.init(subkey, dummy_obs) - initial_opt_state = optimizer.init(initial_params) + initial_opt_state = outer_optimizer.init(initial_params) return TrainingState( params=initial_params, opt_state=initial_opt_state, @@ -169,6 +170,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: num_envs, num_steps, obs_spec ) + # self.grad_fn = jax.grad(loss, has_aux=True) self.grad_fn = jax.jit(jax.grad(loss, has_aux=True)) # Set up counters and logger @@ -188,7 +190,8 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self.network = network # initialize some variables - self._optimizer = optimizer + self._inner_optimizer = inner_optimizer + self._outer_optimizer = outer_optimizer self.gamma = gamma # Other useful hyperparameters @@ -263,7 +266,7 @@ def in_lookahead(self, env, other_agents): ) # Update the optimizer - updates, opt_state = self._optimizer.update( + updates, opt_state = self._inner_optimizer.update( gradients, other_state.opt_state ) @@ -336,7 +339,7 @@ def out_lookahead(self, env, other_agents): ) # Update the optimizer - updates, opt_state = self._optimizer.update( + updates, opt_state = self._outer_optimizer.update( gradients, my_state.opt_state ) @@ -449,9 +452,13 @@ def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): # * args.ppo.num_minibatches # ) - optimizer = optax.chain( + inner_optimizer = optax.chain( optax.scale_by_adam(eps=args.lola.adam_epsilon), - optax.scale(-args.lola.learning_rate), + optax.scale(-args.lola.lr_in), + ) + outer_optimizer = optax.chain( + optax.scale_by_adam(eps=args.lola.adam_epsilon), + optax.scale(-args.lola.lr_out), ) # Random key @@ -459,7 +466,8 @@ def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): return LOLA( network=network, - optimizer=optimizer, + inner_optimizer=inner_optimizer, + outer_optimizer=outer_optimizer, random_key=random_key, obs_spec=obs_spec, player_id=player_id, From be33bc88d83fa338a0c0ba1fa5f5ecd709483299 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Thu, 28 Jul 2022 14:14:59 +0100 Subject: [PATCH 17/29] working lola --- 
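Among the changes below, pax/lola/network.py switches both heads from Constant(0.5) to RandomNormal() initialisation. With a constant weight matrix every observation produces identical cooperate/defect logits, so each agent starts from a uniform 0.5/0.5 policy and both players start out the same; a random init breaks that symmetry. A self-contained check of the head construction; the 5-dimensional dummy input mirrors the IPD observation, the rest is illustrative.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    logits = hk.Linear(2, with_bias=False, w_init=hk.initializers.RandomNormal())(x)
    values = hk.Linear(1, with_bias=False, w_init=hk.initializers.RandomNormal())(x)
    return logits, values

network = hk.without_apply_rng(hk.transform(forward))
params = network.init(jax.random.PRNGKey(0), jnp.zeros((1, 5)))
logits, values = network.apply(params, jnp.eye(5))  # one row per IPD state
probs = jax.nn.softmax(logits)  # no longer 0.5/0.5 everywhere
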
pax/conf/experiment/lola.yaml | 8 ++++---- pax/lola/lola.py | 23 ++++++++++++----------- pax/lola/network.py | 6 ++++-- pax/runner.py | 13 ++++++------- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index ac9aa0e2..943448fc 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -14,10 +14,10 @@ payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] # Training hyperparameters -num_envs: 64 +num_envs: 128 num_steps: 150 # number of steps per episode total_timesteps: 10_000_000 -eval_every: 100_000 # timesteps +eval_every: 10_000_000 # timesteps # Useful information # num_episodes = total_timesteps / num_steps @@ -28,8 +28,8 @@ eval_every: 100_000 # timesteps lola: use_baseline: True adam_epsilon: 1e-5 - lr_in: 1 - lr_out: 0.3 + lr_in: 0.3 + lr_out: 0.2 gamma: 0.96 # Logging setup diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 6df44088..35e060cc 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -135,12 +135,18 @@ def loss(params, other_params, samples): dice_objective = dice_objective + baseline_term loss_value = jnp.mean((rewards - values) ** 2) - loss_total = -dice_objective + loss_value + # loss_total = -dice_objective + loss_value + loss_total = dice_objective + loss_value # want to minimize -objective + # return loss_total, { + # "loss_total": -dice_objective + loss_value, + # "loss_policy": -dice_objective, + # "loss_value": loss_value, + # } return loss_total, { - "loss_total": -dice_objective + loss_value, - "loss_policy": -dice_objective, + "loss_total": dice_objective + loss_value, + "loss_policy": dice_objective, "loss_value": loss_value, } @@ -271,10 +277,9 @@ def in_lookahead(self, env, other_agents): ) # apply the optimizer updates - # params = params = optax.apply_updates(other_state.params, updates) + # replace the other player's current parameters with a simulated update - # print("Params before replacement", params) self._other_state = TrainingState( params=params, opt_state=opt_state, @@ -282,11 +287,6 @@ def in_lookahead(self, env, other_agents): timesteps=other_state.timesteps, extras=other_state.extras, ) - # other_state._replace(params=params) - # print("Params after replacement", other_state.params) - # other_state._replace(opt_state=opt_state) - # self._other_state = other_state - # print("Params after self.", self._other_state.params) def out_lookahead(self, env, other_agents): """ @@ -305,9 +305,10 @@ def out_lookahead(self, env, other_agents): # get a copy of the agent's state my_state = self._state + # get a copy of the other opponent's state + other_state = self._other_state # Perform a rollout and store S,A,R,S tuples t = env.reset() - other_state = self._other_state # TODO: Replace with jax.lax.scan for _ in range(self._num_steps): # I take an action using my parameters diff --git a/pax/lola/network.py b/pax/lola/network.py index 04abacd1..5ecdc3be 100644 --- a/pax/lola/network.py +++ b/pax/lola/network.py @@ -18,12 +18,14 @@ def __init__( super().__init__(name=name) self._logit_layer = hk.Linear( num_values, - w_init=hk.initializers.Constant(0.5), + # w_init=hk.initializers.Constant(0.5), + w_init=hk.initializers.RandomNormal(), with_bias=False, ) self._value_layer = hk.Linear( 1, - w_init=hk.initializers.Constant(0.5), + # w_init=hk.initializers.Constant(0.5), + w_init=hk.initializers.RandomNormal(), with_bias=False, ) diff --git a/pax/runner.py b/pax/runner.py index 703e3352..229111be 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -77,20 +77,19 @@ 
def train_loop(self, env, agents, num_episodes, watchers): rewards_1 = jnp.array(rewards_1) # rewards_0 = jnp.array(agents.agents[0].rewards) # rewards_1 = jnp.array(agents.agents[1].rewards) + mean_r_0 = float(rewards_1.mean()) + mean_r_1 = float(rewards_1.mean()) print( - f"Total Episode Reward: {float(rewards_0.mean()), float(rewards_1.mean())}" + f"Total Episode Reward: {mean_r_0, mean_r_1} | Joint reward: {(mean_r_0 + mean_r_1)*0.5}" ) if watchers: wandb.log( { "episodes": self.train_episodes, - "train/episode_reward/player_1": float( - rewards_0.mean() - ), - "train/episode_reward/player_2": float( - rewards_1.mean() - ), + "train/episode_reward/player_1": mean_r_0, + "train/episode_reward/player_2": mean_r_1, + "train/joint_reward": (mean_r_0 + mean_r_1) * 0.5, } ) print() From 3cedf4e061bf4d3ae792800049a3014b5a0c5f8a Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Thu, 28 Jul 2022 14:50:21 +0100 Subject: [PATCH 18/29] update configs --- pax/conf/experiment/lola.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 943448fc..7c7360cc 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -16,8 +16,8 @@ payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] # Training hyperparameters num_envs: 128 num_steps: 150 # number of steps per episode -total_timesteps: 10_000_000 -eval_every: 10_000_000 # timesteps +total_timesteps: 1_000_000 +eval_every: 1_000_000 # timesteps # Useful information # num_episodes = total_timesteps / num_steps From b37aa91054d1b527356230956c6d7c25da398bb2 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Thu, 28 Jul 2022 15:24:40 +0100 Subject: [PATCH 19/29] tidy up --- pax/conf/experiment/naive.yaml | 2 +- pax/conf/experiment/ppo_memory.yaml | 6 +++--- pax/dqn/agent.py | 1 - pax/runner.py | 4 ---- 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pax/conf/experiment/naive.yaml b/pax/conf/experiment/naive.yaml index 2d94d79d..3345f514 100644 --- a/pax/conf/experiment/naive.yaml +++ b/pax/conf/experiment/naive.yaml @@ -2,7 +2,7 @@ # Agents agent1: 'Naive' -agent2: 'TitForTat' +agent2: 'Naive' # Environment env_id: ipd diff --git a/pax/conf/experiment/ppo_memory.yaml b/pax/conf/experiment/ppo_memory.yaml index 9686e5e0..18fe7750 100644 --- a/pax/conf/experiment/ppo_memory.yaml +++ b/pax/conf/experiment/ppo_memory.yaml @@ -14,7 +14,7 @@ payoff: # Training hyperparameters num_envs: 100 num_steps: 25 # number of steps per episode -total_timesteps: 2_000_000 +total_timesteps: 1_000_000 eval_every: 50_000 # timesteps # Useful information @@ -34,7 +34,7 @@ ppo: max_gradient_norm: 0.5 anneal_entropy: True entropy_coeff_start: 0.2 - entropy_coeff_horizon: 1_000_000 + entropy_coeff_horizon: 500_000 entropy_coeff_end: 0.001 lr_scheduling: True learning_rate: 2.5e-2 @@ -45,6 +45,6 @@ ppo: wandb: entity: "ucl-dark" project: ipd - group: 'PPO_memory-vs-${agent2}-${game}-${entropy_coeff_horizon}' + group: 'PPO_memory-vs-${agent2}-${game}' name: run-seed-${seed} log: True diff --git a/pax/dqn/agent.py b/pax/dqn/agent.py index 6fef9d63..a11e27e7 100644 --- a/pax/dqn/agent.py +++ b/pax/dqn/agent.py @@ -183,7 +183,6 @@ def update( timestep: dm_env.TimeStep, action: jnp.array, new_timestep: dm_env.TimeStep, - other_agents=None, ): self._replay.add_batch( diff --git a/pax/runner.py b/pax/runner.py index 229111be..eee25dae 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -69,14 +69,10 @@ def train_loop(self, env, agents, num_episodes, watchers): ) 
agents.in_lookahead(env) agents.out_lookahead(env) - # print("Agent 1 params", agents.agents[0]._state.params["categorical_value_head/~/linear"]["w"]) - # print("Agent 2 params", agents.agents[1]._state.params["categorical_value_head/~/linear"]["w"]) # end of episode stats self.train_episodes += env.num_envs rewards_0 = jnp.array(rewards_0) rewards_1 = jnp.array(rewards_1) - # rewards_0 = jnp.array(agents.agents[0].rewards) - # rewards_1 = jnp.array(agents.agents[1].rewards) mean_r_0 = float(rewards_1.mean()) mean_r_1 = float(rewards_1.mean()) From 99c7906d924d13e0f6ea1de41ea66767a2445eb3 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Fri, 29 Jul 2022 20:06:29 +0100 Subject: [PATCH 20/29] add working lola with new runner using lax.scan --- pax/centralized_learners.py | 1 + pax/conf/config.yaml | 47 ++--- pax/conf/experiment/debug.yaml | 2 +- pax/conf/experiment/lola.yaml | 2 +- pax/experiment.py | 4 +- pax/lola/lola.py | 339 +++++++++++++++++++-------------- pax/runner.py | 45 ++++- pax/watchers.py | 21 +- 8 files changed, 266 insertions(+), 195 deletions(-) diff --git a/pax/centralized_learners.py b/pax/centralized_learners.py index 658375d5..b8341688 100644 --- a/pax/centralized_learners.py +++ b/pax/centralized_learners.py @@ -38,6 +38,7 @@ def out_lookahead(self, env): agent.out_lookahead(env, other_agents) counter += 1 + # TODO: Obselete at the moment. This can be put into the LOLA. def update( self, old_timesteps: List[TimeStep], diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 52978665..d11fed27 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -14,13 +14,8 @@ save_dir: "./exp/${wandb.group}/${wandb.name}" debug: False # Agents -<<<<<<< HEAD agent1: 'LOLA' agent2: 'LOLA' -======= -agent1: 'Hyper' -agent2: 'NaiveLearnerEx' ->>>>>>> main # Environment env_id: ipd @@ -31,10 +26,10 @@ payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] centralized: True # Training hyperparameters -num_envs: 10 -num_steps: 100 # number of steps per episode -total_timesteps: 100_000_000 -eval_every: 50_000 # eval every n episodes, not timesteps +num_envs: 128 +num_steps: 150 # number of steps per episode +total_timesteps: 1_000_000 +eval_every: 100_000 # eval every n episodes, not timesteps # Useful information @@ -42,6 +37,16 @@ eval_every: 50_000 # eval every n episodes, not timesteps # num_updates = num_episodes / eval_every # batch_size = num_envs * num_steps +# DQN agent parameters +dqn: + batch_size: 256 + discount: 0.99 + learning_rate: 1e-2 + epsilon: 0.5 + replay_capacity: 100000 + min_replay_size: 1000 + sgd_period: 1 + target_update_period: 4 # PPO agent parameters ppo: @@ -70,10 +75,12 @@ naive: # LOLA agent parameters lola: - replay_capacity: 1000 #args.lola.replay_capacity, - min_replay_size: 50 #args.lola.min_replay_size, - sgd_period: 1 #args.dqn.sgd_period, - batch_size: 2 #args.dqn.batch_size + use_baseline: True + adam_epsilon: 1e-5 + lr_in: 0.3 + lr_out: 0.2 + gamma: 0.96 + num_lookaheads: 1 # Logging setup wandb: @@ -81,20 +88,6 @@ wandb: project: ipd group: 'LOLA-vs-${agent2}-${game}' name: run-seed-${seed} -<<<<<<< HEAD log: False -======= - log: True -# DQN agent parameters -dqn: - batch_size: 256 - discount: 0.99 - learning_rate: 1e-2 - epsilon: 0.5 - replay_capacity: 100000 - min_replay_size: 1000 - sgd_period: 1 - target_update_period: 4 ->>>>>>> main diff --git a/pax/conf/experiment/debug.yaml b/pax/conf/experiment/debug.yaml index 245d34d1..d9c8171c 100644 --- a/pax/conf/experiment/debug.yaml +++ b/pax/conf/experiment/debug.yaml @@ -4,5 +4,5 @@ debug: true 
wandb: group: debug - log: true + log: False diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 7c7360cc..85c46995 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -16,7 +16,7 @@ payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] # Training hyperparameters num_envs: 128 num_steps: 150 # number of steps per episode -total_timesteps: 1_000_000 +total_timesteps: 10_000_000 eval_every: 1_000_000 # timesteps # Useful information diff --git a/pax/experiment.py b/pax/experiment.py index 1228c40a..5189498a 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -5,17 +5,17 @@ import omegaconf import wandb +from pax.centralized_learners import CentralizedLearners from pax.dqn.agent import default_agent from pax.env import SequentialMatrixGame from pax.hyper.ppo import make_hyper from pax.independent_learners import IndependentLearners -from pax.centralized_learners import CentralizedLearners +from pax.lola.lola import make_lola from pax.meta_env import InfiniteMatrixGame from pax.naive_exact import NaiveLearnerEx from pax.naive.naive import make_naive_pg from pax.ppo.ppo import make_agent from pax.ppo.ppo_gru import make_gru_agent -from pax.lola.lola import make_lola from pax.runner import Runner from pax.strategies import ( Altruistic, diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 35e060cc..25eb7fa7 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -1,6 +1,8 @@ from binascii import a2b_base64 from typing import Any, Mapping, NamedTuple, Tuple, Dict +from regex import R + from pax import utils from pax.lola.buffer import TrajectoryBuffer from pax.lola.network import make_network @@ -13,6 +15,18 @@ import optax +class Sample(NamedTuple): + """Object containing a batch of data""" + + obs_self: jnp.ndarray + obs_other: jnp.ndarray + actions_self: jnp.ndarray + actions_other: jnp.ndarray + dones: jnp.ndarray + rewards_self: jnp.ndarray + rewards_other: jnp.ndarray + + class Batch(NamedTuple): """A batch of data; all shapes are expected to be [B, ...].""" @@ -36,6 +50,7 @@ class TrainingState(NamedTuple): random_key: jnp.ndarray timesteps: int extras: Mapping[str, jnp.ndarray] + hidden: None def magic_box(x): @@ -80,9 +95,45 @@ def policy( random_key=key, timesteps=state.timesteps, extras=state.extras, + hidden=None, ) return actions, state + @jax.jit + def prepare_batch( + traj_batch: NamedTuple, t_prime: TimeStep, action_extras: dict + ): + # Rollouts complete -> Training begins + # Add an additional rollout step for advantage calculation + + _value = jax.lax.select( + t_prime.last(), + action_extras["values"], + jnp.zeros_like(action_extras["values"]), + ) + + _value = jax.lax.expand_dims(_value, [0]) + _reward = jax.lax.expand_dims(t_prime.reward, [0]) + _done = jax.lax.select( + t_prime.last(), + 2 * jnp.ones_like(_value), + jnp.zeros_like(_value), + ) + + # need to add final value here + traj_batch = traj_batch._replace( + behavior_values=jnp.concatenate( + [traj_batch.behavior_values, _value], axis=0 + ) + ) + traj_batch = traj_batch._replace( + rewards=jnp.concatenate([traj_batch.rewards, _reward], axis=0) + ) + traj_batch = traj_batch._replace( + dones=jnp.concatenate([traj_batch.dones, _done], axis=0) + ) + return traj_batch + def loss(params, other_params, samples): # Stacks so that the dimension is now (num_envs, num_steps) @@ -150,6 +201,20 @@ def loss(params, other_params, samples): "loss_value": loss_value, } + def sgd_step( + state: TrainingState, sample: NamedTuple + ) -> Tuple[TrainingState, Dict[str, 
jnp.ndarray]]: + state = state + # placeholders + results = { + "loss_total": 0, + "loss_policy": 0, + "loss_value": 0, + "loss_entropy": 0, + "entropy_cost": 0, + } + return state, results + def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: """Initialises the training state (parameters and optimiser state).""" key, subkey = jax.random.split(key) @@ -162,7 +227,11 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: opt_state=initial_opt_state, random_key=key, timesteps=0, - extras={"values": None, "log_probs": None}, + extras={ + "values": jnp.zeros(num_envs), + "log_probs": jnp.zeros(num_envs), + }, + hidden=None, ) # Initialise training state (parameters, optimiser state, extras). @@ -194,6 +263,8 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: # Initialize functions self._policy = policy self.network = network + self._prepare_batch = jax.jit(prepare_batch) + self._sgd_step = sgd_step # initialize some variables self._inner_optimizer = inner_optimizer @@ -215,7 +286,7 @@ def select_action(self, t: TimeStep): ) return utils.to_numpy(actions) - def in_lookahead(self, env, other_agents): + def in_lookahead(self, env, other_agents, env_rollout): """ Performs a rollout using the current parameters of both agents and simulates a naive learning update step for the other agent @@ -225,48 +296,65 @@ def in_lookahead(self, env, other_agents): other_agents: list, a list of objects of the other agents """ - # initialize buffer - other_memory = TrajectoryBuffer( - self._num_envs, self._num_steps, self._obs_spec - ) - - # get the other agent + # get other agent other_agent = other_agents[0] - # get a copy of the agent's state - my_state = self._state - # create a temporary training state for the other agent for the simulated rollout - # TODO: make a separate optimizer for the other agent - other_state = TrainingState( - params=other_agent._state.params, - opt_state=other_agent._state.opt_state, - random_key=other_agent._state.random_key, - timesteps=other_agent._state.timesteps, - extras=other_agent._state.extras, - ) - # Perform a rollout and store S,A,R,S tuples - # TODO: Replace with jax.lax.scan - t = env.reset() - for _ in range(self._num_steps): - # I take an action using my parameters - a1, my_state = self._policy( - my_state.params, t[0].observation, my_state - ) - # Opponent takes an action using their parameters - a2, other_state = self._policy( - other_state.params, t[1].observation, other_state - ) - # The environment steps as usual - t_prime = env.step((a1, a2)) - # We add to the buffer as if we are the other player - # TODO: IS THIS THE RIGHT ORDER??????? 
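The Python for-loop rollout deleted here is replaced by a single jax.lax.scan over an env_rollout step function: the carry holds the two timesteps and both agents' TrainingStates, and the per-step outputs are stacked into whole trajectories. A self-contained toy showing just the scan mechanics; the random-walk "environment" below is illustrative, not the pax matrix game.

import jax
import jax.numpy as jnp

def rollout_step(carry, _):
    obs, key = carry
    key, subkey = jax.random.split(key)
    action = jax.random.bernoulli(subkey, 0.5, shape=obs.shape).astype(jnp.float32)
    reward = jnp.where(action == 1.0, 1.0, -1.0)
    next_obs = obs + action
    # the carry moves forward one step; (action, reward) gets stacked along axis 0
    return (next_obs, key), (action, reward)

init_carry = (jnp.zeros(4), jax.random.PRNGKey(0))
(final_obs, _), (actions, rewards) = jax.lax.scan(
    rollout_step, init_carry, None, length=150
)
# actions.shape == rewards.shape == (150, 4): num_steps first, then num_envs
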
- other_memory.add(t[1], t[0], a2, a1, t_prime[1], t_prime[0]) - t = t_prime + # my state + my_state = TrainingState( + params=self._state.params, + opt_state=self._state.opt_state, + random_key=self._state.random_key, + timesteps=self._state.timesteps, + extras={ + "values": jnp.zeros(self._num_envs), + "log_probs": jnp.zeros(self._num_envs), + }, + hidden=None, + ) + # other player's state + other_state = other_agent.reset_memory() + + # do a full rollout + t_init = env.reset() + vals, trajectories = jax.lax.scan( + env_rollout, + (t_init[0], t_init[1], my_state, other_state), + None, + length=env.episode_length, + ) + # update agent / add final trajectory + # do an agent update based on trajectories + # only need the other agent's state # unpack the values from the buffer - sample = other_memory.sample() - # calculate the gradients + # update agent / add final trajectory + # TODO: Unclear if this is still needed when not calculating advantage? + # final_t1 = vals[0]._replace(step_type=2) + # a1_state = vals[2] + # _, state = self._policy(my_state.params, final_t1.observation, a1_state) + # traj_batch_0 = self._prepare_batch(trajectories[0], final_t1, state.extras) + + # final_t2 = vals[1]._replace(step_type=2) + # a2_state = vals[3] + # _, state = self._policy(other_state.params, final_t2.observation, a2_state) + # traj_batch_1 = self._prepare_batch(trajectories[1], final_t2, other_state.extras) + + traj_batch_0 = trajectories[0] + traj_batch_1 = trajectories[1] + # flip the order of the trajectories + # assuming we're the other player + sample = Sample( + obs_self=traj_batch_1.observations, + obs_other=traj_batch_0.observations, + actions_self=traj_batch_1.actions, + actions_other=traj_batch_0.actions, + dones=traj_batch_0.dones, + rewards_self=traj_batch_1.rewards, + rewards_other=traj_batch_0.rewards, + ) + + # get gradients of opponent gradients, _ = self.grad_fn( other_state.params, my_state.params, sample ) @@ -286,9 +374,10 @@ def in_lookahead(self, env, other_agents): random_key=other_state.random_key, timesteps=other_state.timesteps, extras=other_state.extras, + hidden=None, ) - def out_lookahead(self, env, other_agents): + def out_lookahead(self, env, env_rollout): """ Performs a real rollout using the current parameters of both agents and a naive learning update step for the other agent @@ -297,42 +386,44 @@ def out_lookahead(self, env, other_agents): env: SequentialMatrixGame, an environment object of the game being played other_agents: list, a list of objects of the other agents """ - - # initialize buffer - memory = TrajectoryBuffer( - self._num_envs, self._num_steps, self._obs_spec - ) - # get a copy of the agent's state - my_state = self._state + my_state = TrainingState( + params=self._state.params, + opt_state=self._state.opt_state, + random_key=self._state.random_key, + timesteps=self._state.timesteps, + extras={ + "values": jnp.zeros(self._num_envs), + "log_probs": jnp.zeros(self._num_envs), + }, + hidden=None, + ) # get a copy of the other opponent's state + # TODO: Do I need to reset this? Maybe... 
other_state = self._other_state - # Perform a rollout and store S,A,R,S tuples - t = env.reset() - # TODO: Replace with jax.lax.scan - for _ in range(self._num_steps): - # I take an action using my parameters - a1, my_state = self._policy( - my_state.params, t[0].observation, my_state - ) - # Opponent takes an action using their parameters - a2, other_state = self._policy( - other_state.params, t[1].observation, other_state - ) - # The environment steps as usual - t_prime = env.step((a1, a2)) - # We add to the buffer - # TODO: IS THIS THE RIGHT ORDER??????? - memory.add(t[0], t[1], a1, a2, t_prime[0], t_prime[1]) - t = t_prime - # Update internal agent's timesteps - self._total_steps += self._num_envs - self._logger.metrics["total_steps"] += self._num_envs - self._state._replace(timesteps=self._total_steps) + # do a full rollout + t_init = env.reset() + vals, trajectories = jax.lax.scan( + env_rollout, + (t_init[0], t_init[1], my_state, other_state), + None, + length=env.episode_length, + ) - # unpack the values from the buffer - sample = memory.sample() + traj_batch_0 = trajectories[0] + traj_batch_1 = trajectories[1] + + # Now keep the same order. + sample = Sample( + obs_self=traj_batch_0.observations, + obs_other=traj_batch_1.observations, + actions_self=traj_batch_0.actions, + actions_other=traj_batch_1.actions, + dones=traj_batch_0.dones, + rewards_self=traj_batch_0.rewards, + rewards_other=traj_batch_1.rewards, + ) # calculate the gradients gradients, results = self.grad_fn( @@ -347,6 +438,11 @@ def out_lookahead(self, env, other_agents): # apply the optimizer updates params = optax.apply_updates(my_state.params, updates) + # Update internal agent's timesteps + self._total_steps += self._num_envs + self._logger.metrics["total_steps"] += self._num_envs + self._state._replace(timesteps=self._total_steps) + self._logger.metrics["sgd_steps"] += ( self._num_minibatches * self._num_epochs ) @@ -361,98 +457,47 @@ def out_lookahead(self, env, other_agents): random_key=self._state.random_key, timesteps=self._state.timesteps, extras={"log_probs": None, "values": None}, + hidden=None, ) + def reset_memory(self) -> TrainingState: + self._state = self._state._replace( + extras={ + "values": jnp.zeros(self._num_envs), + "log_probs": jnp.zeros(self._num_envs), + } + ) + return self._state + def update( self, - t: TimeStep, - actions: np.array, + traj_batch, t_prime: TimeStep, - other_agents: list = None, + state, ): - # Adds agent and environment info to buffer - self._rollouts( - buffer=self._trajectory_buffer, - t=t, - actions=actions, - t_prime=t_prime, - state=self._state, - ) + """Update the agent -> only called at the end of a trajectory""" + _, state = self._policy(state.params, t_prime.observation, state) - # Log metrics - self._total_steps += self._num_envs - self._logger.metrics["total_steps"] += self._num_envs - - # Update internal state with total_steps - self._state = TrainingState( - params=self._state.params, - opt_state=self._state.opt_state, - random_key=self._state.random_key, - timesteps=self._total_steps, - extras=self._state.extras, - ) - - # Update counter until doing SGD - self._until_sgd += 1 - - # Add params to buffer - # doesn't change throughout the rollout - # but it can't be before the return - self._trajectory_buffer.params = self._state.params - - # Rollouts onging - if self._until_sgd % (self._num_steps) != 0: - return - - # Rollouts complete -> Training begins - # Add an additional rollout step for advantage calculation - _, self._state = self._policy( - 
self._state.params, t_prime.observation, self._state - ) - # print("Other agents params", other_agents[0]._trajectory_buffer.params) - - self._trajectory_buffer.add( - timestep=t_prime, - action=0, - log_prob=0, - value=self._state.extras["values"] - if not t_prime.last() - else jnp.zeros_like(self._state.extras["values"]), - new_timestep=t_prime, + traj_batch = self._prepare_batch(traj_batch, t_prime, state.extras) + state, results = self._sgd_step(state, traj_batch) + self._logger.metrics["sgd_steps"] += ( + self._num_minibatches * self._num_epochs ) + self._logger.metrics["loss_total"] = results["loss_total"] + self._logger.metrics["loss_policy"] = results["loss_policy"] + self._logger.metrics["loss_value"] = results["loss_value"] + self._logger.metrics["loss_entropy"] = results["loss_entropy"] + self._logger.metrics["entropy_cost"] = results["entropy_cost"] - # other_agent = other_agents[0] - # other_agent_params = other_agents[0]._trajectory_buffer.params - # It needs to be able to take in the opponents parameters - # and then do a rollout under those parameters - # could do sgd here for other agent? - sample = self._trajectory_buffer.sample() - self._state, results = self._sgd_step( - self._state, self.other_params, sample - ) - # self._logger.metrics["sgd_steps"] += ( - # self._num_minibatches * self._num_epochs - # ) - # self._logger.metrics["loss_total"] = results["loss_total"] - # self._logger.metrics["loss_policy"] = results["loss_policy"] - # self._logger.metrics["loss_value"] = results["loss_value"] + self._state = state + return state def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): """Make Naive Learner Policy Gradient agent""" - # print(f"Making network for {args.env_id}") network = make_network(action_spec) - # Optimizer - # batch_size = int(args.num_envs * args.num_steps) - # transition_steps = ( - # args.total_timesteps - # / batch_size - # * args.ppo.num_epochs - # * args.ppo.num_minibatches - # ) - inner_optimizer = optax.chain( optax.scale_by_adam(eps=args.lola.adam_epsilon), optax.scale(-args.lola.lr_in), diff --git a/pax/runner.py b/pax/runner.py index 25d1ac22..f9169dc9 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -69,14 +69,40 @@ def _env_rollout(carry, unused): t2.observation, a2, tprime_2.reward, - None, - None, + a2_state.extras["log_probs"], + a2_state.extras["values"], tprime_2.last() * jnp.zeros(env.num_envs), - None, + a2_state.hidden, ) return (tprime_1, tprime_2, a1_state, a2_state), (traj1, traj2) for _ in range(0, max(int(num_episodes / env.num_envs), 1)): + # start of unique lola code + # LOLA updates occur here + if self.args.agent1 == "LOLA" and self.args.agent2 == "LOLA": + # inner rollout + for _ in range(self.args.lola.num_lookaheads): + agent1.in_lookahead(env, [agent2], _env_rollout) + agent2.in_lookahead(env, [agent1], _env_rollout) + # outer rollout + agent1.out_lookahead(env, _env_rollout) + agent2.out_lookahead(env, _env_rollout) + + elif self.args.agent1 == "LOLA" and self.args.agent2 != "LOLA": + # inner rollout + for _ in range(self.args.lola.num_lookaheads): + agent1.in_lookahead(env, [agent2], _env_rollout) + # outer rollout + agent1.out_lookahead(env, _env_rollout) + + elif self.args.agent1 != "LOLA" and self.args.agent2 == "LOLA": + # inner rollout + for _ in range(self.args.lola.num_lookaheads): + agent2.in_lookahead(env, [agent1], _env_rollout) + # outer rollout + agent2.out_lookahead(env, _env_rollout) + # end of unique lola code + t_init = env.reset() a1_state = agent1.reset_memory() a2_state = 
agent2.reset_memory() @@ -97,10 +123,6 @@ def _env_rollout(carry, unused): rewards_0 = trajectories[0].rewards.mean() rewards_1 = trajectories[1].rewards.mean() - # print( - # f"Total Episode Reward: {mean_r_0, mean_r_1} | Joint reward: {(mean_r_0 + mean_r_1)*0.5}" - # ) - # update agent / add final trajectory final_t1 = vals[0]._replace(step_type=2) a1_state = vals[2] @@ -110,6 +132,14 @@ def _env_rollout(carry, unused): a2_state = vals[3] a2_state = agent2.update(trajectories[1], final_t2, a2_state) + print( + f"Total Episode Reward: {float(rewards_0.mean()), float(rewards_1.mean())}" + ) + + # print( + # f"Total Episode Reward: {mean_r_0, mean_r_1} | Joint reward: {(mean_r_0 + mean_r_1)*0.5}" + # ) + # logging if watchers: agents.log(watchers) @@ -126,6 +156,7 @@ def _env_rollout(carry, unused): ) print() + # TODO: Why do we need this if we already update the state in agent.update? # update agents agents.agents[0]._state = a1_state agents.agents[1]._state = a2_state diff --git a/pax/watchers.py b/pax/watchers.py index 2e663020..b12e1a4f 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -4,6 +4,7 @@ from pax.lola.lola import LOLA import pax.hyper.ppo as HyperPPO import pax.ppo.ppo as PPO +from pax.naive.naive import NaiveLearner from pax.naive_exact import NaiveLearnerEx from .env import State @@ -99,16 +100,16 @@ def value_logger_ppo(agent: PPO) -> dict: return probs -# def value_logger_naive(agent: NaiveLearner) -> dict: -# weights = agent._state.params["categorical_value_head/~/linear_1"][ -# "w" -# ] # 5 x 1 matrix -# sgd_steps = agent._total_steps / agent._num_steps -# probs = { -# f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) -# } -# probs.update({"value/total_steps": sgd_steps}) -# return probs +def value_logger_naive(agent: NaiveLearner) -> dict: + weights = agent._state.params["categorical_value_head/~/linear_1"][ + "w" + ] # 5 x 1 matrix + sgd_steps = agent._total_steps / agent._num_steps + probs = { + f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + } + probs.update({"value/total_steps": sgd_steps}) + return probs def value_logger_lola(agent: LOLA) -> dict: From dbb3937c1bfe9b4401fc29ecec3f4fe522e7cc7e Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Sat, 30 Jul 2022 02:07:19 +0100 Subject: [PATCH 21/29] tidy up watchers, fix naive learner, LOLA getting exploited hard .... 
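The policy loggers in pax/watchers.py (touched below) recover the tabular IPD policy directly from the logit head's weights: push the 5x2 weight matrix through a softmax and log the first column as the per-state cooperation probability, as policy_logger_lola does earlier in this series. In isolation, with stand-in weights in place of agent._state.params:

import jax.numpy as jnp
from jax import nn

# rows: the five IPD states; columns: (cooperate, defect) logits
weights = jnp.array(
    [[0.3, -0.1], [1.2, 0.0], [-0.5, 0.4], [0.2, 0.2], [-1.0, 1.0]]
)
pi = nn.softmax(weights)        # softmax over the action axis, row by row
cooperate_probs = pi[:, 0]      # what ends up in the policy/... metrics
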
--- pax/conf/config.yaml | 18 ++- pax/conf/experiment/lola.yaml | 2 +- pax/experiment.py | 90 ++++++------- pax/lola/lola.py | 99 +++------------ pax/naive/naive.py | 145 ++++++++++----------- pax/ppo/ppo.py | 12 +- pax/runner.py | 1 - pax/utils.py | 15 ++- pax/watchers.py | 231 ++++++++++------------------------ 9 files changed, 232 insertions(+), 381 deletions(-) diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index d11fed27..2196f0c7 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -71,7 +71,23 @@ ppo: # Naive Learner parameters naive: - lr: 1.0 + num_minibatches: 1 + num_epochs: 1 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.2 + entropy_coeff_horizon: 500_000 + # ^this should 1/2 of the total timesteps + entropy_coeff_end: 0.01 + lr_scheduling: True + learning_rate: 1 + adam_epsilon: 1e-5 + with_memory: False # LOLA agent parameters lola: diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 85c46995..f4256db9 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -28,7 +28,7 @@ eval_every: 1_000_000 # timesteps lola: use_baseline: True adam_epsilon: 1e-5 - lr_in: 0.3 + lr_in: 1 lr_out: 0.2 gamma: 0.96 diff --git a/pax/experiment.py b/pax/experiment.py index 5189498a..a29295f8 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -32,19 +32,14 @@ from pax.watchers import ( logger_hyper, logger_naive, + losses_logger, losses_naive, losses_ppo, + policy_logger, policy_logger_dqn, - policy_logger_ppo, policy_logger_ppo_with_memory, - policy_logger_naive, - policy_logger_lola, - losses_lola, - value_logger_lola, - value_logger_naive, + value_logger, value_logger_dqn, - value_logger_ppo, - naive_pg_losses, ) @@ -268,12 +263,12 @@ def get_naive_learner(seed, player_id): "Grim": GrimTrigger, "DQN": get_DQN_agent, "PPO": get_PPO_agent, - "Naive": get_naive_pg_agent, + "NaivePG": get_naive_pg_agent, "LOLA": get_LOLA_agent, "PPO_memory": get_PPO_memory_agent, # HyperNetworks "Hyper": get_hyper_agent, - "NaiveLearnerEx": get_naive_learner, + "NaiveEx": get_naive_learner, "HyperAltruistic": HyperAltruistic, "HyperDefect": HyperDefect, "HyperTFT": HyperTFT, @@ -315,62 +310,59 @@ def dqn_log(agent): wandb.log(policy_dict) return - def ppo_log(agent): + def dumb_log(agent, *args): + return + + def hyper_log(agent): losses = losses_ppo(agent) - if args.ppo.with_memory: - policy = policy_logger_ppo_with_memory(agent) - else: - policy = policy_logger_ppo(agent) - value = value_logger_ppo(agent) - losses.update(value) + policy = logger_hyper(agent) losses.update(policy) if args.wandb.log: wandb.log(losses) return - def naive_log(agent): - losses = naive_pg_losses(agent) - policy = policy_logger_naive(agent) - value = value_logger_naive(agent) + def lola_log(agent): + losses = losses_logger(agent) + policy = policy_logger(agent) + value = value_logger(agent) losses.update(value) losses.update(policy) if args.wandb.log: wandb.log(losses) return - def lola_log(agent): - losses = losses_lola(agent) - policy = policy_logger_lola(agent) - value = value_logger_lola(agent) - losses.update(value) + def naive_ex_log(agent): + losses = losses_naive(agent) + policy = logger_naive(agent) losses.update(policy) if args.wandb.log: wandb.log(losses) return - def dumb_log(agent, *args): - return - - def hyper_log(agent): - losses = losses_ppo(agent) - policy = logger_hyper(agent) + def naive_pg_log(agent): + losses = 
losses_logger(agent) + policy = policy_logger(agent) + value = value_logger(agent) + losses.update(value) losses.update(policy) if args.wandb.log: wandb.log(losses) return - def naive_logger(agent): - losses = losses_naive(agent) - policy = logger_naive(agent) + def ppo_log(agent): + losses = losses_ppo(agent) + policy = policy_logger(agent) + value = value_logger(agent) + losses.update(value) losses.update(policy) if args.wandb.log: wandb.log(losses) return - def naive_pg_log(agent): - losses = naive_pg_losses(agent) - policy = policy_logger_ppo(agent) - value = value_logger_ppo(agent) + def ppo_memory_log(agent): + losses = losses_ppo(agent) + policy = policy_logger_ppo_with_memory(agent) + value = value_logger(agent) losses.update(value) losses.update(policy) if args.wandb.log: @@ -378,22 +370,22 @@ def naive_pg_log(agent): return strategies = { - "TitForTat": dumb_log, - "Defect": dumb_log, "Altruistic": dumb_log, - "Human": dumb_log, - "Random": dumb_log, - "Grim": dumb_log, + "Defect": dumb_log, "DQN": dqn_log, - "PPO": ppo_log, - "LOLA": lola_log, - "PPO_memory": ppo_log, - "Naive": naive_pg_log, + "Grim": dumb_log, + "Human": dumb_log, "Hyper": hyper_log, - "NaiveLearnerEx": naive_logger, "HyperAltruistic": dumb_log, "HyperDefect": dumb_log, "HyperTFT": dumb_log, + "LOLA": lola_log, + "NaivePG": naive_pg_log, + "NaiveEx": naive_ex_log, + "PPO": ppo_log, + "PPO_memory": ppo_memory_log, + "Random": dumb_log, + "TitForTat": dumb_log, } assert args.agent1 in strategies diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 25eb7fa7..6d20798e 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -1,11 +1,9 @@ -from binascii import a2b_base64 from typing import Any, Mapping, NamedTuple, Tuple, Dict -from regex import R - from pax import utils from pax.lola.buffer import TrajectoryBuffer from pax.lola.network import make_network +from pax.utils import TrainingState from dm_env import TimeStep import haiku as hk @@ -42,17 +40,6 @@ class Batch(NamedTuple): behavior_log_probs: jnp.ndarray -class TrainingState(NamedTuple): - """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" - - params: hk.Params - opt_state: optax.GradientTransformation - random_key: jnp.ndarray - timesteps: int - extras: Mapping[str, jnp.ndarray] - hidden: None - - def magic_box(x): return jnp.exp(x - jax.lax.stop_gradient(x)) @@ -99,41 +86,6 @@ def policy( ) return actions, state - @jax.jit - def prepare_batch( - traj_batch: NamedTuple, t_prime: TimeStep, action_extras: dict - ): - # Rollouts complete -> Training begins - # Add an additional rollout step for advantage calculation - - _value = jax.lax.select( - t_prime.last(), - action_extras["values"], - jnp.zeros_like(action_extras["values"]), - ) - - _value = jax.lax.expand_dims(_value, [0]) - _reward = jax.lax.expand_dims(t_prime.reward, [0]) - _done = jax.lax.select( - t_prime.last(), - 2 * jnp.ones_like(_value), - jnp.zeros_like(_value), - ) - - # need to add final value here - traj_batch = traj_batch._replace( - behavior_values=jnp.concatenate( - [traj_batch.behavior_values, _value], axis=0 - ) - ) - traj_batch = traj_batch._replace( - rewards=jnp.concatenate([traj_batch.rewards, _reward], axis=0) - ) - traj_batch = traj_batch._replace( - dones=jnp.concatenate([traj_batch.dones, _done], axis=0) - ) - return traj_batch - def loss(params, other_params, samples): # Stacks so that the dimension is now (num_envs, num_steps) @@ -263,7 +215,6 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: # 
Initialize functions self._policy = policy self.network = network - self._prepare_batch = jax.jit(prepare_batch) self._sgd_step = sgd_step # initialize some variables @@ -312,7 +263,21 @@ def in_lookahead(self, env, other_agents, env_rollout): hidden=None, ) # other player's state - other_state = other_agent.reset_memory() + init_params = other_agent._state.params + other_opt_state = self._inner_optimizer.init(other_agent._state.params) + # other_state = other_agent.state.copy() + other_state = TrainingState( + params=init_params, + opt_state=other_opt_state, + random_key=other_agent._state.random_key, + timesteps=other_agent._state.timesteps, + extras={ + "values": jnp.zeros(self._num_envs), + "log_probs": jnp.zeros(self._num_envs), + }, + hidden=None, + ) + # other_state = other_agent.reset_memory() # do a full rollout t_init = env.reset() @@ -323,23 +288,6 @@ def in_lookahead(self, env, other_agents, env_rollout): length=env.episode_length, ) - # update agent / add final trajectory - # do an agent update based on trajectories - # only need the other agent's state - # unpack the values from the buffer - - # update agent / add final trajectory - # TODO: Unclear if this is still needed when not calculating advantage? - # final_t1 = vals[0]._replace(step_type=2) - # a1_state = vals[2] - # _, state = self._policy(my_state.params, final_t1.observation, a1_state) - # traj_batch_0 = self._prepare_batch(trajectories[0], final_t1, state.extras) - - # final_t2 = vals[1]._replace(step_type=2) - # a2_state = vals[3] - # _, state = self._policy(other_state.params, final_t2.observation, a2_state) - # traj_batch_1 = self._prepare_batch(trajectories[1], final_t2, other_state.extras) - traj_batch_0 = trajectories[0] traj_batch_1 = trajectories[1] # flip the order of the trajectories @@ -360,6 +308,7 @@ def in_lookahead(self, env, other_agents, env_rollout): ) # Update the optimizer + updates, opt_state = self._inner_optimizer.update( gradients, other_state.opt_state ) @@ -476,20 +425,6 @@ def update( state, ): """Update the agent -> only called at the end of a trajectory""" - _, state = self._policy(state.params, t_prime.observation, state) - - traj_batch = self._prepare_batch(traj_batch, t_prime, state.extras) - state, results = self._sgd_step(state, traj_batch) - self._logger.metrics["sgd_steps"] += ( - self._num_minibatches * self._num_epochs - ) - self._logger.metrics["loss_total"] = results["loss_total"] - self._logger.metrics["loss_policy"] = results["loss_policy"] - self._logger.metrics["loss_value"] = results["loss_value"] - self._logger.metrics["loss_entropy"] = results["loss_entropy"] - self._logger.metrics["entropy_cost"] = results["entropy_cost"] - - self._state = state return state diff --git a/pax/naive/naive.py b/pax/naive/naive.py index 1ba1f997..6a7306d4 100644 --- a/pax/naive/naive.py +++ b/pax/naive/naive.py @@ -5,6 +5,7 @@ from pax import utils from pax.naive.buffer import TrajectoryBuffer from pax.naive.network import make_network +from pax.utils import TrainingState from dm_env import TimeStep import haiku as hk @@ -29,14 +30,14 @@ class Batch(NamedTuple): behavior_log_probs: jnp.ndarray -class TrainingState(NamedTuple): - """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" +# class TrainingState(NamedTuple): +# """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" - params: hk.Params - opt_state: optax.GradientTransformation - random_key: jnp.ndarray - timesteps: int - extras: 
Mapping[str, jnp.ndarray] +# params: hk.Params +# opt_state: optax.GradientTransformation +# random_key: jnp.ndarray +# timesteps: int +# extras: Mapping[str, jnp.ndarray] class Logger: @@ -79,9 +80,45 @@ def policy( random_key=key, timesteps=state.timesteps, extras=state.extras, + hidden=None, ) return actions, state + @jax.jit + def prepare_batch( + traj_batch: NamedTuple, t_prime: TimeStep, action_extras: dict + ): + # Rollouts complete -> Training begins + # Add an additional rollout step for advantage calculation + + _value = jax.lax.select( + t_prime.last(), + action_extras["values"], + jnp.zeros_like(action_extras["values"]), + ) + + _value = jax.lax.expand_dims(_value, [0]) + _reward = jax.lax.expand_dims(t_prime.reward, [0]) + _done = jax.lax.select( + t_prime.last(), + 2 * jnp.ones_like(_value), + jnp.zeros_like(_value), + ) + + # need to add final value here + traj_batch = traj_batch._replace( + behavior_values=jnp.concatenate( + [traj_batch.behavior_values, _value], axis=0 + ) + ) + traj_batch = traj_batch._replace( + rewards=jnp.concatenate([traj_batch.rewards, _reward], axis=0) + ) + traj_batch = traj_batch._replace( + dones=jnp.concatenate([traj_batch.dones, _done], axis=0) + ) + return traj_batch + def rollouts( buffer: TrajectoryBuffer, t: TimeStep, @@ -174,23 +211,13 @@ def sgd_step( ) # vmap - batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0) - advantages, target_values = batch_gae_advantages( + advantages, target_values = gae_advantages( rewards=rewards, values=behavior_values, dones=dones ) # Exclude the last step - it was only used for bootstrapping. # The shape is [num_envs, num_steps, ..] - ( - observations, - actions, - behavior_log_probs, - behavior_values, - ) = jax.tree_map( - lambda x: x[:, :-1], - (observations, actions, behavior_log_probs, behavior_values), - ) - + behavior_values = behavior_values[:-1, :] trajectories = Batch( observations=observations, actions=actions, @@ -306,7 +333,11 @@ def model_update_epoch( opt_state=opt_state, random_key=key, timesteps=timesteps, - extras={"log_probs": None, "values": None}, + extras={ + "log_probs": jnp.zeros(num_envs), + "values": jnp.zeros(num_envs), + }, + hidden=None, ) return new_state, metrics @@ -324,6 +355,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: random_key=key, timesteps=0, extras={"values": None, "log_probs": None}, + hidden=None, ) # Initialise training state (parameters, optimiser state, extras). 
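A note on the advantage computation in the hunk above: prepare_batch appends a final value (with a matching reward and done entry) so the advantage calculation has a bootstrap step, and gae_advantages is now called directly on time-major arrays instead of being vmapped over environments. A minimal sketch of a standard GAE computation in that layout, assuming for this sketch values of shape [T + 1, B], rewards/dones of shape [T, B], and binary done flags (the pax buffer stores step-type values, 2 for a final step, rather than 0/1); the function name and conventions here are illustrative, not pax's actual gae_advantages.

    import jax
    import jax.numpy as jnp

    def gae_advantages_sketch(rewards, values, dones, gamma=0.96, gae_lambda=0.95):
        # rewards, dones: [T, B]; values: [T + 1, B] with a bootstrap entry at the end.
        def backward_step(gae, xs):
            reward, value, next_value, done = xs
            delta = reward + gamma * next_value * (1.0 - done) - value
            gae = delta + gamma * gae_lambda * (1.0 - done) * gae
            return gae, gae

        # Scan backwards over time; ys come out aligned with the original time axis.
        _, advantages = jax.lax.scan(
            backward_step,
            jnp.zeros_like(rewards[0]),
            (rewards, values[:-1], values[1:], dones),
            reverse=True,
        )
        # Regression targets for the value head.
        target_values = advantages + values[:-1]
        return advantages, target_values
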
@@ -333,7 +365,6 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self._trajectory_buffer = TrajectoryBuffer( num_envs, num_steps, obs_spec ) - self._sgd_step = sgd_step # Set up counters and logger self._logger = Logger() @@ -350,6 +381,8 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: # Initialize functions self._policy = policy self._rollouts = rollouts + self._prepare_batch = jax.jit(prepare_batch) + self._sgd_step = sgd_step # Other useful hyperparameters self._num_envs = num_envs # number of environments @@ -365,66 +398,34 @@ def select_action(self, t: TimeStep): ) return utils.to_numpy(actions) + def reset_memory(self) -> TrainingState: + self._state = self._state._replace( + extras={ + "values": jnp.zeros(self._num_envs), + "log_probs": jnp.zeros(self._num_envs), + } + ) + return self._state + def update( self, - t: TimeStep, - actions: np.array, + traj_batch, t_prime: TimeStep, - other_agents=None, + state, ): - # Adds agent and environment info to buffer - self._rollouts( - buffer=self._trajectory_buffer, - t=t, - actions=actions, - t_prime=t_prime, - state=self._state, - ) - - # Log metrics - self._total_steps += self._num_envs - self._logger.metrics["total_steps"] += self._num_envs - - # Update internal state with total_steps - self._state = TrainingState( - params=self._state.params, - opt_state=self._state.opt_state, - random_key=self._state.random_key, - timesteps=self._total_steps, - extras=self._state.extras, - ) - - # Update counter until doing SGD - self._until_sgd += 1 - - # Rollouts onging - if self._until_sgd % (self._num_steps) != 0: - return - - # Rollouts complete -> Training begins - # Add an additional rollout step for advantage calculation - _, self._state = self._policy( - self._state.params, t_prime.observation, self._state - ) - - self._trajectory_buffer.add( - timestep=t_prime, - action=0, - log_prob=0, - value=self._state.extras["values"] - if not t_prime.last() - else jnp.zeros_like(self._state.extras["values"]), - new_timestep=t_prime, - ) + """Update the agent -> only called at the end of a trajectory""" + _, state = self._policy(state.params, t_prime.observation, state) - sample = self._trajectory_buffer.sample() - self._state, results = self._sgd_step(self._state, sample) + traj_batch = self._prepare_batch(traj_batch, t_prime, state.extras) + state, results = self._sgd_step(state, traj_batch) self._logger.metrics["sgd_steps"] += ( self._num_minibatches * self._num_epochs ) self._logger.metrics["loss_total"] = results["loss_total"] self._logger.metrics["loss_policy"] = results["loss_policy"] self._logger.metrics["loss_value"] = results["loss_value"] + self._state = state + return state def make_naive_pg(args, obs_spec, action_spec, seed: int, player_id: int): @@ -465,7 +466,7 @@ def make_naive_pg(args, obs_spec, action_spec, seed: int, player_id: int): # Random key random_key = jax.random.PRNGKey(seed=seed) - return NaiveLearner( + agent = NaiveLearner( network=network, optimizer=optimizer, random_key=random_key, @@ -477,6 +478,8 @@ def make_naive_pg(args, obs_spec, action_spec, seed: int, player_id: int): gamma=args.naive.gamma, gae_lambda=args.naive.gae_lambda, ) + agent.player_id = player_id + return agent if __name__ == "__main__": diff --git a/pax/ppo/ppo.py b/pax/ppo/ppo.py index cdccfeb6..68d14b56 100644 --- a/pax/ppo/ppo.py +++ b/pax/ppo/ppo.py @@ -10,6 +10,7 @@ from pax import utils from pax.ppo.networks import make_cartpole_network, make_network +from pax.utils import TrainingState class 
Batch(NamedTuple): @@ -27,17 +28,6 @@ class Batch(NamedTuple): behavior_log_probs: jnp.ndarray -class TrainingState(NamedTuple): - """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" - - params: hk.Params - opt_state: optax.GradientTransformation - random_key: jnp.ndarray - timesteps: int - extras: Mapping[str, jnp.ndarray] - hidden: None - - class Logger: metrics: dict diff --git a/pax/runner.py b/pax/runner.py index f9169dc9..7ffaf579 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -3,7 +3,6 @@ import jax import jax.numpy as jnp -from dm_env import transition import wandb from pax.env import IteratedPrisonersDilemma diff --git a/pax/utils.py b/pax/utils.py index 4f3430ac..67de3097 100644 --- a/pax/utils.py +++ b/pax/utils.py @@ -1,10 +1,23 @@ from time import time as tic - +from typing import Mapping, NamedTuple +import optax +import haiku as hk import jax import jax.numpy as jnp import numpy as np +class TrainingState(NamedTuple): + """Training state consists of network parameters, optimiser state, random key, timesteps, and extras.""" + + params: hk.Params + opt_state: optax.GradientTransformation + random_key: jnp.ndarray + timesteps: int + extras: Mapping[str, jnp.ndarray] + hidden: None + + class Section(object): """ Examples diff --git a/pax/watchers.py b/pax/watchers.py index b12e1a4f..2a0ddbc5 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -19,28 +19,82 @@ ALL_STATES = [START, CC, DC, CD, DD] +# General policy logger def policy_logger(agent) -> dict: - weights = agent.actor_optimizer.target["Dense_0"][ - "kernel" - ] # [layer_name]['w'] - log_pi = nn.softmax(weights) + weights = agent._state.params["categorical_value_head/~/linear"]["w"] + pi = nn.softmax(weights) + # sgd_steps = agent._total_steps / agent._num_steps + sgd_steps = agent._logger.metrics["sgd_steps"] probs = { - "policy/" + str(s): p[0] for (s, p) in zip(State, log_pi) - } # probability of cooperating is p[0] + f"policy/{agent.player_id}/{str(s)}": p[0] for (s, p) in zip(State, pi) + } + probs.update({"policy/total_steps": sgd_steps}) return probs +# General value logger def value_logger(agent) -> dict: - weights = agent.critic_optimizer.target["Dense_0"]["kernel"] - values = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) + weights = agent._state.params["categorical_value_head/~/linear_1"]["w"] + # sgd_steps = agent._total_steps / agent._num_steps + sgd_steps = agent._logger.metrics["sgd_steps"] + probs = { + f"value/{agent.player_id}/{str(s)}": p[0] + for (s, p) in zip(State, weights) } - values.update( - {f"value/{str(s)}.defect": p[1] for (s, p) in zip(State, weights)} - ) - return values + probs.update({"value/total_steps": sgd_steps}) + return probs + + +# General loss logger +def losses_logger(agent) -> None: + sgd_steps = agent._logger.metrics["sgd_steps"] + loss_total = agent._logger.metrics["loss_total"] + loss_policy = agent._logger.metrics["loss_policy"] + loss_value = agent._logger.metrics["loss_value"] + losses = { + "sgd_steps": sgd_steps, + f"loss/{agent.player_id}/total": loss_total, + f"loss/{agent.player_id}/policy": loss_policy, + f"loss/{agent.player_id}/value": loss_value, + } + return losses + + +# Loss logger for PPO (has additional terms) +def losses_ppo(agent: PPO) -> dict: + sgd_steps = agent._logger.metrics["sgd_steps"] + loss_total = agent._logger.metrics["loss_total"] + loss_policy = agent._logger.metrics["loss_policy"] + loss_value = agent._logger.metrics["loss_value"] + loss_entropy = 
agent._logger.metrics["loss_entropy"] + entropy_cost = agent._logger.metrics["entropy_cost"] + losses = { + "sgd_steps": sgd_steps, + f"loss/{agent.player_id}/total": loss_total, + f"loss/{agent.player_id}/policy": loss_policy, + f"loss/{agent.player_id}/value": loss_value, + f"loss/{agent.player_id}/entropy": loss_entropy, + f"loss/{agent.player_id}/entropy_coefficient": entropy_cost, + } + return losses + + +# General policy logger for networks with more than on layer +def policy_logger_ppo_with_memory(agent) -> dict: + """Calculate probability of cooperation""" + params = agent._state.params + hidden = agent._state.hidden + episode = int(agent._logger.metrics["total_steps"] / agent._num_steps) + cooperation_probs = {"episode": episode} + for state, state_name in zip(ALL_STATES, STATE_NAMES): + (dist, _), hidden = agent.forward(params, state, hidden) + cooperation_probs[f"policy/{agent.player_id}/{state_name}"] = float( + dist.probs[0][0] + ) + return cooperation_probs +# TODO: Clean these up (possibly redundant) def policy_logger_dqn(agent) -> None: # this assumes using a linear layer, so this logging won't work using MLP weights = agent._state.target_params["linear"]["w"] # 5 x 2 matrix @@ -79,85 +133,6 @@ def value_logger_dqn(agent) -> dict: return values -def policy_logger_ppo(agent: PPO) -> dict: - weights = agent._state.params["categorical_value_head/~/linear"]["w"] - pi = nn.softmax(weights) - sgd_steps = agent._total_steps / agent._num_steps - probs = {f"policy/{str(s)}.cooperate": p[0] for (s, p) in zip(State, pi)} - probs.update({"policy/total_steps": sgd_steps}) - return probs - - -def value_logger_ppo(agent: PPO) -> dict: - weights = agent._state.params["categorical_value_head/~/linear_1"][ - "w" - ] # 5 x 1 matrix - sgd_steps = agent._total_steps / agent._num_steps - probs = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) - } - probs.update({"value/total_steps": sgd_steps}) - return probs - - -def value_logger_naive(agent: NaiveLearner) -> dict: - weights = agent._state.params["categorical_value_head/~/linear_1"][ - "w" - ] # 5 x 1 matrix - sgd_steps = agent._total_steps / agent._num_steps - probs = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) - } - probs.update({"value/total_steps": sgd_steps}) - return probs - - -def value_logger_lola(agent: LOLA) -> dict: - weights = agent._state.params["categorical_value_head/~/linear_1"][ - "w" - ] # 5 x 1 matrix - sgd_steps = agent._total_steps / agent._num_steps - probs = { - f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights) - } - probs.update({"value/total_steps": sgd_steps}) - return probs - - -def policy_logger_ppo_with_memory(agent) -> dict: - """Calculate probability of coopreation""" - # n = 5 - params = agent._state.params - hidden = agent._state.hidden - # episode = int( - # agent._logger.metrics["total_steps"] - # / (agent._num_steps * agent._num_envs) - # ) - episode = int(agent._logger.metrics["total_steps"] / agent._num_steps) - cooperation_probs = {"episode": episode} - - # TODO: Figure out how to JIT the forward function - # Works when the forward function is not jitted. 
- for state, state_name in zip(ALL_STATES, STATE_NAMES): - (dist, _), hidden = agent.forward(params, state, hidden) - cooperation_probs[f"policy/{state_name}"] = float(dist.probs[0][0]) - return cooperation_probs - - -def naive_pg_losses(agent) -> None: - sgd_steps = agent._logger.metrics["sgd_steps"] - loss_total = agent._logger.metrics["loss_total"] - loss_policy = agent._logger.metrics["loss_policy"] - loss_value = agent._logger.metrics["loss_value"] - losses = { - "sgd_steps": sgd_steps, - "train/total": loss_total, - "train/policy": loss_policy, - "train/value": loss_value, - } - return losses - - def logger_hyper(agent: HyperPPO) -> dict: episode = int( agent._logger.metrics["total_steps"] @@ -167,78 +142,6 @@ def logger_hyper(agent: HyperPPO) -> dict: return cooperation_probs -def naive_losses(agent) -> None: - sgd_steps = agent._logger.metrics["sgd_steps"] - loss_total = agent._logger.metrics["loss_total"] - loss_policy = agent._logger.metrics["loss_policy"] - loss_value = agent._logger.metrics["loss_value"] - losses = { - "sgd_steps": sgd_steps, - "train/total": loss_total, - "train/policy": loss_policy, - "train/value": loss_value, - } - return losses - - -def losses_lola(agent) -> None: - sgd_steps = agent._logger.metrics["sgd_steps"] - loss_total = agent._logger.metrics["loss_total"] - loss_policy = agent._logger.metrics["loss_policy"] - loss_value = agent._logger.metrics["loss_value"] - losses = { - "sgd_steps": sgd_steps, - "train/total": loss_total, - "train/policy": loss_policy, - "train/value": loss_value, - } - return losses - - -def losses_ppo(agent: PPO) -> dict: - pid = agent.player_id - sgd_steps = agent._logger.metrics["sgd_steps"] - loss_total = agent._logger.metrics["loss_total"] - loss_policy = agent._logger.metrics["loss_policy"] - loss_value = agent._logger.metrics["loss_value"] - loss_entropy = agent._logger.metrics["loss_entropy"] - entropy_cost = agent._logger.metrics["entropy_cost"] - losses = { - f"train/ppo_{pid}/sgd_steps": sgd_steps, - f"train/ppo_{pid}/total": loss_total, - f"train/ppo_{pid}/policy": loss_policy, - f"train/ppo_{pid}/value": loss_value, - f"train/ppo_{pid}/entropy": loss_entropy, - f"train/ppo_{pid}/entropy_coefficient": entropy_cost, - } - return losses - - -def policy_logger_naive(agent) -> None: - weights = agent._state.params["categorical_value_head/~/linear"]["w"] - pi = nn.softmax(weights) - sgd_steps = agent._total_steps / agent._num_steps - probs = { - f"policy/{str(s)}/{agent.player_id}/player_{agent.player_id}.cooperate": p[ - 0 - ] - for (s, p) in zip(State, pi) - } - probs.update({"policy/total_steps": sgd_steps}) - return probs - - -def policy_logger_lola(agent) -> None: - weights = agent._state.params["categorical_value_head/~/linear"]["w"] - pi = nn.softmax(weights) - sgd_steps = agent._total_steps / agent._num_steps - probs = { - f"policy/{agent.player_id}/{str(s)}": p[0] for (s, p) in zip(State, pi) - } - probs.update({"policy/total_steps": sgd_steps}) - return probs - - def losses_naive(agent: NaiveLearnerEx) -> dict: pid = agent.player_id sgd_steps = agent._logger.metrics["sgd_steps"] From 11b98c6fd98a80eb462d01ab33cdd3c363fa0d3c Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Sat, 30 Jul 2022 02:29:59 +0100 Subject: [PATCH 22/29] tidy up watchers, fix naive learner, LOLA getting exploited hard .... 
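For orientation before the next two patches: the LOLA loss in pax/lola/lola.py is the DiCE ("magic box") surrogate, built from magic_box(x) = exp(x - stop_gradient(x)) applied to the cumulative log-probabilities of both players' actions. The version shown earlier minimises dice_objective + loss_value, while the follow-up patch rewrites it as -dice_objective + value_objective, i.e. it maximises the DiCE return estimate and minimises the value error separately; that sign convention is one plausible factor in LOLA getting exploited here, though the commits do not say so explicitly. Below is a minimal sketch of the objective, assuming time-major [T, B] trajectories and the use_baseline variance-reduction term; names are illustrative rather than pax's exact API.

    import jax
    import jax.numpy as jnp

    def magic_box(x):
        # Equals 1 in the forward pass, but its derivative is magic_box(x) * dx,
        # which is what lets DiCE estimate (higher-order) policy gradients.
        return jnp.exp(x - jax.lax.stop_gradient(x))

    def dice_objective(self_log_prob, other_log_prob, rewards, values, gamma=0.96):
        # Per-timestep discount gamma^t, broadcast over the batch dimension.
        cum_discount = jnp.cumprod(gamma * jnp.ones_like(rewards), axis=0) / gamma
        discounted_rewards = rewards * cum_discount

        # Every stochastic decision (both players) up to and including step t.
        dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=0)
        objective = jnp.mean(
            jnp.sum(magic_box(dependencies) * discounted_rewards, axis=0)
        )

        # Baseline term for variance reduction (the use_baseline branch).
        stochastic_nodes = self_log_prob + other_log_prob
        baseline = jnp.mean(
            jnp.sum((1 - magic_box(stochastic_nodes)) * values * cum_discount, axis=0)
        )
        # The training code minimises the negative of this plus a value loss.
        return objective + baseline
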
--- pax/conf/config.yaml | 12 ++---------- pax/conf/experiment/lola.yaml | 2 +- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index 2196f0c7..e542f947 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -75,19 +75,11 @@ naive: num_epochs: 1 gamma: 0.96 gae_lambda: 0.95 - ppo_clipping_epsilon: 0.2 - value_coeff: 0.5 clip_value: True - max_gradient_norm: 0.5 - anneal_entropy: True - entropy_coeff_start: 0.2 - entropy_coeff_horizon: 500_000 - # ^this should 1/2 of the total timesteps - entropy_coeff_end: 0.01 - lr_scheduling: True + max_gradient_norm: 1 + lr_scheduling: False learning_rate: 1 adam_epsilon: 1e-5 - with_memory: False # LOLA agent parameters lola: diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index f4256db9..a0fc9647 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -2,7 +2,7 @@ # Agents agent1: 'LOLA' -agent2: 'LOLA' +agent2: 'NaivePG' centralized: True # Environment From aeb6426047d09a6598303efb990765ff1d591ccc Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Mon, 1 Aug 2022 01:23:11 +0100 Subject: [PATCH 23/29] lola compiles, move TrainingState to utils --- pax/conf/config.yaml | 7 +- pax/conf/experiment/lola.yaml | 13 ++- pax/experiment.py | 2 +- pax/lola/lola.py | 210 +++++++++++++++++++--------------- pax/lola/network.py | 8 +- pax/naive/naive.py | 1 + pax/runner.py | 39 ++++++- pax/utils.py | 17 +++ 8 files changed, 186 insertions(+), 111 deletions(-) diff --git a/pax/conf/config.yaml b/pax/conf/config.yaml index e542f947..ebf542d8 100644 --- a/pax/conf/config.yaml +++ b/pax/conf/config.yaml @@ -9,7 +9,7 @@ hydra: level: INFO # Global variables -seed: 0 +seed: 25 save_dir: "./exp/${wandb.group}/${wandb.name}" debug: False @@ -52,7 +52,7 @@ dqn: ppo: num_minibatches: 10 num_epochs: 4 - gamma: 0.75 + gamma: 0.96 gae_lambda: 0.95 ppo_clipping_epsilon: 0.2 value_coeff: 0.5 @@ -60,7 +60,7 @@ ppo: max_gradient_norm: 0.5 anneal_entropy: True entropy_coeff_start: 0.2 - entropy_coeff_horizon: 500_000 + entropy_coeff_horizon: 5_000_000 # for halfway, the horizon should (1/2) * (total_timesteps / num_envs) entropy_coeff_end: 0.01 lr_scheduling: True @@ -75,7 +75,6 @@ naive: num_epochs: 1 gamma: 0.96 gae_lambda: 0.95 - clip_value: True max_gradient_norm: 1 lr_scheduling: False learning_rate: 1 diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index a0fc9647..6df9a6f7 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -2,7 +2,7 @@ # Agents agent1: 'LOLA' -agent2: 'NaivePG' +agent2: 'LOLA' centralized: True # Environment @@ -16,8 +16,8 @@ payoff: [[-1,-1], [-3,0], [0,-3], [-2,-2]] # Training hyperparameters num_envs: 128 num_steps: 150 # number of steps per episode -total_timesteps: 10_000_000 -eval_every: 1_000_000 # timesteps +total_timesteps: 4_000_000 +eval_every: 4_000_000 # timesteps # Useful information # num_episodes = total_timesteps / num_steps @@ -28,14 +28,15 @@ eval_every: 1_000_000 # timesteps lola: use_baseline: True adam_epsilon: 1e-5 - lr_in: 1 + lr_in: 0.3 lr_out: 0.2 gamma: 0.96 + num_lookaheads: 0 # Logging setup wandb: entity: "ucl-dark" project: ipd group: 'LOLA-vs-${agent2}-${game}' - name: run-seed-${seed} - log: True + name: run-seed-${seed}-${lola.num_lookaheads}-lookaheads + log: False diff --git a/pax/experiment.py b/pax/experiment.py index a29295f8..d8e4f8d2 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -288,7 +288,7 @@ def get_naive_learner(seed, player_id): 
agent_0 = strategies[args.agent1](seeds[0], pids[0]) # player 1 agent_1 = strategies[args.agent2](seeds[1], pids[1]) # player 2 - logger.info(f"PPO with memory: {args.ppo.with_memory}") + # logger.info(f"PPO with memory: {args.ppo.with_memory}") logger.info(f"Agent Pair: {args.agent1} | {args.agent2}") logger.info(f"Agent seeds: {seeds[0]} | {seeds[1]}") diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 6d20798e..045919af 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -61,8 +61,6 @@ def __init__( obs_spec: Tuple, num_envs: int = 4, num_steps: int = 150, - num_minibatches: int = 1, - num_epochs: int = 1, use_baseline: bool = True, gamma: float = 0.96, ): @@ -86,26 +84,79 @@ def policy( ) return actions, state - def loss(params, other_params, samples): - # Stacks so that the dimension is now (num_envs, num_steps) - + def outer_loss(params, other_params, samples): + """Used for the outer rollout""" + # Unpack the samples obs_1 = samples.obs_self obs_2 = samples.obs_other - rewards = samples.rewards_self - # r_1 = samples.rewards_self - # r_2 = samples.rewards_other - actions_1 = samples.actions_self actions_2 = samples.actions_other - # distribution, values_self = self.network.apply(params, obs_1) + # Get distribution and value using my network distribution, values = self.network.apply(params, obs_1) self_log_prob = distribution.log_prob(actions_1) - distribution, values_others = self.network.apply( - other_params, obs_2 + # Get distribution and value using other player's network + distribution, _ = self.other_network.apply(other_params, obs_2) + other_log_prob = distribution.log_prob(actions_2) + + # apply discount: + cum_discount = ( + jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) + / self.gamma ) + + discounted_rewards = rewards * cum_discount + discounted_values = values * cum_discount + + # stochastics nodes involved in rewards dependencies: + dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=1) + + # logprob of each stochastic nodes: + stochastic_nodes = self_log_prob + other_log_prob + + # dice objective: + dice_objective = jnp.mean( + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) + ) + + if use_baseline: + # variance_reduction: + baseline_term = jnp.mean( + jnp.sum( + (1 - magic_box(stochastic_nodes)) * discounted_values, + axis=1, + ) + ) + dice_objective = dice_objective + baseline_term + + # want to minimize this value + value_objective = jnp.mean((rewards - values) ** 2) + + # want to maximize this objective + loss_total = -dice_objective + value_objective + + return loss_total, { + "loss_total": -dice_objective + value_objective, + "loss_policy": -dice_objective, + "loss_value": value_objective, + } + + def inner_loss(params, other_params, samples): + """Used for the inner rollout""" + obs_1 = samples.obs_self + obs_2 = samples.obs_other + rewards = samples.rewards_self + actions_1 = samples.actions_self + actions_2 = samples.actions_other + + # Get distribution and value using other player's network + distribution, values = self.other_network.apply(params, obs_1) + self_log_prob = distribution.log_prob(actions_1) + + # Get distribution and value using my network + distribution, _ = self.network.apply(other_params, obs_2) other_log_prob = distribution.log_prob(actions_2) # apply discount: @@ -137,20 +188,16 @@ def loss(params, other_params, samples): ) dice_objective = dice_objective + baseline_term - loss_value = jnp.mean((rewards - values) ** 2) - # loss_total = -dice_objective + loss_value - loss_total = dice_objective + 
loss_value + # want to minimize this value + value_objective = jnp.mean((rewards - values) ** 2) + + # want to maximize this objective + loss_total = -dice_objective + value_objective - # want to minimize -objective - # return loss_total, { - # "loss_total": -dice_objective + loss_value, - # "loss_policy": -dice_objective, - # "loss_value": loss_value, - # } return loss_total, { - "loss_total": dice_objective + loss_value, - "loss_policy": dice_objective, - "loss_value": loss_value, + "loss_total": -dice_objective + value_objective, + "loss_policy": -dice_objective, + "loss_value": value_objective, } def sgd_step( @@ -198,7 +245,9 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: ) # self.grad_fn = jax.grad(loss, has_aux=True) - self.grad_fn = jax.jit(jax.grad(loss, has_aux=True)) + self.grad_fn_inner = jax.jit(jax.grad(inner_loss, has_aux=True)) + self.grad_fn_outer = jax.jit(jax.grad(outer_loss, has_aux=True)) + # self.grad_fn_outer = jax.grad(outer_loss, has_aux=True) # Set up counters and logger self._logger = Logger() @@ -226,8 +275,6 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self._num_envs = num_envs # number of environments self._num_steps = num_steps # number of steps per environment self._batch_size = int(num_envs * num_steps) # number in one batch - self._num_minibatches = num_minibatches # number of minibatches - self._num_epochs = num_epochs # number of epochs to use sample self._obs_spec = obs_spec def select_action(self, t: TimeStep): @@ -237,19 +284,15 @@ def select_action(self, t: TimeStep): ) return utils.to_numpy(actions) - def in_lookahead(self, env, other_agents, env_rollout): + def in_lookahead(self, env, env_rollout): """ Performs a rollout using the current parameters of both agents and simulates a naive learning update step for the other agent INPUT: env: SequentialMatrixGame, an environment object of the game being played - other_agents: list, a list of objects of the other agents """ - # get other agent - other_agent = other_agents[0] - # my state my_state = TrainingState( params=self._state.params, @@ -262,34 +305,19 @@ def in_lookahead(self, env, other_agents, env_rollout): }, hidden=None, ) - # other player's state - init_params = other_agent._state.params - other_opt_state = self._inner_optimizer.init(other_agent._state.params) - # other_state = other_agent.state.copy() - other_state = TrainingState( - params=init_params, - opt_state=other_opt_state, - random_key=other_agent._state.random_key, - timesteps=other_agent._state.timesteps, - extras={ - "values": jnp.zeros(self._num_envs), - "log_probs": jnp.zeros(self._num_envs), - }, - hidden=None, - ) - # other_state = other_agent.reset_memory() # do a full rollout t_init = env.reset() - vals, trajectories = jax.lax.scan( + _, trajectories = jax.lax.scan( env_rollout, - (t_init[0], t_init[1], my_state, other_state), + (t_init[0], t_init[1], my_state, self.other_state), None, length=env.episode_length, ) traj_batch_0 = trajectories[0] traj_batch_1 = trajectories[1] + # flip the order of the trajectories # assuming we're the other player sample = Sample( @@ -303,26 +331,28 @@ def in_lookahead(self, env, other_agents, env_rollout): ) # get gradients of opponent - gradients, _ = self.grad_fn( - other_state.params, my_state.params, sample + gradients, _ = self.grad_fn_inner( + self.other_state.params, my_state.params, sample ) # Update the optimizer updates, opt_state = self._inner_optimizer.update( - gradients, other_state.opt_state + gradients, 
self.other_state.opt_state ) # apply the optimizer updates - params = optax.apply_updates(other_state.params, updates) + params = optax.apply_updates(self.other_state.params, updates) + + # self._other_state = other_state # replace the other player's current parameters with a simulated update - self._other_state = TrainingState( + self.other_state = TrainingState( params=params, opt_state=opt_state, - random_key=other_state.random_key, - timesteps=other_state.timesteps, - extras=other_state.extras, + random_key=self.other_state.random_key, + timesteps=self.other_state.timesteps, + extras=self.other_state.extras, hidden=None, ) @@ -335,7 +365,7 @@ def out_lookahead(self, env, env_rollout): env: SequentialMatrixGame, an environment object of the game being played other_agents: list, a list of objects of the other agents """ - # get a copy of the agent's state + # make my own state my_state = TrainingState( params=self._state.params, opt_state=self._state.opt_state, @@ -347,35 +377,42 @@ def out_lookahead(self, env, env_rollout): }, hidden=None, ) - # get a copy of the other opponent's state - # TODO: Do I need to reset this? Maybe... - other_state = self._other_state + # a reference to the other agent's state (from runner) + other_state = TrainingState( + params=self.other_state.params, + opt_state=self.other_state.opt_state, + random_key=self.other_state.random_key, + timesteps=self.other_state.timesteps, + extras={ + "values": jnp.zeros(self._num_envs), + "log_probs": jnp.zeros(self._num_envs), + }, + hidden=None, + ) + # self.other_state # do a full rollout t_init = env.reset() - vals, trajectories = jax.lax.scan( + _, trajectories = jax.lax.scan( env_rollout, (t_init[0], t_init[1], my_state, other_state), None, length=env.episode_length, ) - traj_batch_0 = trajectories[0] - traj_batch_1 = trajectories[1] - # Now keep the same order. 
sample = Sample( - obs_self=traj_batch_0.observations, - obs_other=traj_batch_1.observations, - actions_self=traj_batch_0.actions, - actions_other=traj_batch_1.actions, - dones=traj_batch_0.dones, - rewards_self=traj_batch_0.rewards, - rewards_other=traj_batch_1.rewards, + obs_self=trajectories[0].observations, + obs_other=trajectories[1].observations, + actions_self=trajectories[0].actions, + actions_other=trajectories[1].actions, + dones=trajectories[0].dones, + rewards_self=trajectories[0].rewards, + rewards_other=trajectories[1].rewards, ) # calculate the gradients - gradients, results = self.grad_fn( + gradients, results = self.grad_fn_outer( my_state.params, other_state.params, sample ) @@ -392,14 +429,13 @@ def out_lookahead(self, env, env_rollout): self._logger.metrics["total_steps"] += self._num_envs self._state._replace(timesteps=self._total_steps) - self._logger.metrics["sgd_steps"] += ( - self._num_minibatches * self._num_epochs - ) + # Logging + self._logger.metrics["sgd_steps"] += 1 self._logger.metrics["loss_total"] = results["loss_total"] self._logger.metrics["loss_policy"] = results["loss_policy"] self._logger.metrics["loss_value"] = results["loss_value"] - # replace the other player's current parameters with a simulated update + # replace the player's current parameters with a real update self._state = TrainingState( params=params, opt_state=opt_state, @@ -430,18 +466,12 @@ def update( def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): """Make Naive Learner Policy Gradient agent""" - + # Create Haiku network network = make_network(action_spec) - - inner_optimizer = optax.chain( - optax.scale_by_adam(eps=args.lola.adam_epsilon), - optax.scale(-args.lola.lr_in), - ) - outer_optimizer = optax.chain( - optax.scale_by_adam(eps=args.lola.adam_epsilon), - optax.scale(-args.lola.lr_out), - ) - + # Inner optimizer uses SGD + inner_optimizer = optax.sgd(args.lola.lr_in) + # Outer optimizer uses Adam + outer_optimizer = optax.adam(args.lola.lr_out) # Random key random_key = jax.random.PRNGKey(seed=seed) @@ -454,8 +484,6 @@ def make_lola(args, obs_spec, action_spec, seed: int, player_id: int): player_id=player_id, num_envs=args.num_envs, num_steps=args.num_steps, - num_minibatches=args.ppo.num_minibatches, - num_epochs=args.ppo.num_epochs, use_baseline=args.lola.use_baseline, gamma=args.lola.gamma, ) diff --git a/pax/lola/network.py b/pax/lola/network.py index 5ecdc3be..8c87aecd 100644 --- a/pax/lola/network.py +++ b/pax/lola/network.py @@ -18,14 +18,14 @@ def __init__( super().__init__(name=name) self._logit_layer = hk.Linear( num_values, - # w_init=hk.initializers.Constant(0.5), - w_init=hk.initializers.RandomNormal(), + w_init=hk.initializers.Constant(0), + # w_init=hk.initializers.RandomNormal(), with_bias=False, ) self._value_layer = hk.Linear( 1, - # w_init=hk.initializers.Constant(0.5), - w_init=hk.initializers.RandomNormal(), + w_init=hk.initializers.Constant(0), + # w_init=hk.initializers.RandomNormal(), with_bias=False, ) diff --git a/pax/naive/naive.py b/pax/naive/naive.py index 6a7306d4..79d7773a 100644 --- a/pax/naive/naive.py +++ b/pax/naive/naive.py @@ -383,6 +383,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self._rollouts = rollouts self._prepare_batch = jax.jit(prepare_batch) self._sgd_step = sgd_step + self.network = network # Other useful hyperparameters self._num_envs = num_envs # number of environments diff --git a/pax/runner.py b/pax/runner.py index 7ffaf579..807922cd 100644 --- a/pax/runner.py +++ b/pax/runner.py 
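The runner.py hunks that follow wire the two lookahead phases together. A condensed restatement of the LOLA-vs-LOLA branch as a sketch, using the pax method names but omitting the mixed LOLA/non-LOLA branches and logging; copy_state_and_network is the helper added to pax/utils.py in this patch, and num_lookaheads comes from the lola config (0 in lola.yaml here).

    from pax.utils import copy_state_and_network

    def lola_episode(agent1, agent2, env, env_rollout, num_lookaheads):
        # Each agent gets a detached copy of its opponent's parameters and network,
        # so the inner (simulated) updates never touch the opponent's real state.
        agent1.other_state, agent1.other_network = copy_state_and_network(agent2)
        agent2.other_state, agent2.other_network = copy_state_and_network(agent1)

        # Inner rollouts: simulate naive-learning steps of the opponent copy.
        for _ in range(num_lookaheads):
            agent1.in_lookahead(env, env_rollout)
            agent2.in_lookahead(env, env_rollout)

        # Outer rollout: update each agent's real parameters against the
        # looked-ahead opponent copy.
        agent1.out_lookahead(env, env_rollout)
        agent2.out_lookahead(env, env_rollout)
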
@@ -1,13 +1,16 @@ import time from typing import NamedTuple +import copy import jax import jax.numpy as jnp + import wandb from pax.env import IteratedPrisonersDilemma from pax.independent_learners import IndependentLearners from pax.strategies import Defect, TitForTat +from pax.utils import TrainingState, copy_state_and_network # TODO: make these a copy of acme @@ -76,31 +79,52 @@ def _env_rollout(carry, unused): return (tprime_1, tprime_2, a1_state, a2_state), (traj1, traj2) for _ in range(0, max(int(num_episodes / env.num_envs), 1)): + ######################### LOLA ######################### # start of unique lola code - # LOLA updates occur here if self.args.agent1 == "LOLA" and self.args.agent2 == "LOLA": + # copy state and haiku network + ( + agent1.other_state, + agent1.other_network, + ) = copy_state_and_network(agent2) + ( + agent2.other_state, + agent2.other_network, + ) = copy_state_and_network(agent1) # inner rollout for _ in range(self.args.lola.num_lookaheads): - agent1.in_lookahead(env, [agent2], _env_rollout) - agent2.in_lookahead(env, [agent1], _env_rollout) + agent1.in_lookahead(env, _env_rollout) + agent2.in_lookahead(env, _env_rollout) + # outer rollout agent1.out_lookahead(env, _env_rollout) agent2.out_lookahead(env, _env_rollout) elif self.args.agent1 == "LOLA" and self.args.agent2 != "LOLA": + # copy state and haiku network of agent 2 + ( + agent1.other_state, + agent1.other_network, + ) = copy_state_and_network(agent2) # inner rollout for _ in range(self.args.lola.num_lookaheads): - agent1.in_lookahead(env, [agent2], _env_rollout) + agent1.in_lookahead(env, _env_rollout) # outer rollout agent1.out_lookahead(env, _env_rollout) elif self.args.agent1 != "LOLA" and self.args.agent2 == "LOLA": + # copy state and haiku network of agent 1 + ( + agent2.other_state, + agent2.other_network, + ) = copy_state_and_network(agent1) # inner rollout for _ in range(self.args.lola.num_lookaheads): - agent2.in_lookahead(env, [agent1], _env_rollout) + agent2.in_lookahead(env, _env_rollout) # outer rollout agent2.out_lookahead(env, _env_rollout) # end of unique lola code + ######################### LOLA ######################### t_init = env.reset() a1_state = agent1.reset_memory() @@ -133,6 +157,7 @@ def _env_rollout(carry, unused): print( f"Total Episode Reward: {float(rewards_0.mean()), float(rewards_1.mean())}" + f"| Joint reward: {(rewards_0.mean() + rewards_1.mean())*0.5}" ) # print( @@ -151,6 +176,10 @@ def _env_rollout(carry, unused): "train/episode_reward/player_2": float( rewards_1.mean() ), + "train/episode_reward/joint": ( + rewards_0.mean() + rewards_1.mean() + ) + * 0.5, }, ) print() diff --git a/pax/utils.py b/pax/utils.py index 67de3097..56929e55 100644 --- a/pax/utils.py +++ b/pax/utils.py @@ -99,3 +99,20 @@ def add_batch_dim(values): def to_numpy(values): return jax.tree_map(np.asarray, values) + + +def copy_state_and_network(agent): + """Copies an agent state and returns the state""" + state = TrainingState( + params=agent._state.params, + opt_state=agent._state.opt_state, + random_key=agent._state.random_key, + timesteps=agent._state.timesteps, + extras={ + "values": jnp.zeros(agent._num_envs), + "log_probs": jnp.zeros(agent._num_envs), + }, + hidden=None, + ) + network = agent.network + return state, network From c9cd40c97ffc2ca866d76039a86ecee21f0b66ca Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Mon, 1 Aug 2022 11:53:25 +0100 Subject: [PATCH 24/29] lastest lola --- pax/conf/experiment/lola.yaml | 4 ++-- pax/lola/lola.py | 38 +++++++++++++++++-------------- 
pax/runner.py | 43 +++++++++++++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 21 deletions(-) diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 6df9a6f7..73e73969 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -26,7 +26,7 @@ eval_every: 4_000_000 # timesteps # LOLA agent parameters lola: - use_baseline: True + use_baseline: False adam_epsilon: 1e-5 lr_in: 0.3 lr_out: 0.2 @@ -39,4 +39,4 @@ wandb: project: ipd group: 'LOLA-vs-${agent2}-${game}' name: run-seed-${seed}-${lola.num_lookaheads}-lookaheads - log: False + log: True diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 045919af..72a531e9 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -244,10 +244,8 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: num_envs, num_steps, obs_spec ) - # self.grad_fn = jax.grad(loss, has_aux=True) self.grad_fn_inner = jax.jit(jax.grad(inner_loss, has_aux=True)) self.grad_fn_outer = jax.jit(jax.grad(outer_loss, has_aux=True)) - # self.grad_fn_outer = jax.grad(outer_loss, has_aux=True) # Set up counters and logger self._logger = Logger() @@ -306,28 +304,37 @@ def in_lookahead(self, env, env_rollout): hidden=None, ) + other_state = TrainingState( + params=self.other_state.params, + opt_state=self.other_state.opt_state, + random_key=self.other_state.random_key, + timesteps=self.other_state.timesteps, + extras={ + "values": jnp.zeros(self._num_envs), + "log_probs": jnp.zeros(self._num_envs), + }, + hidden=None, + ) + # do a full rollout t_init = env.reset() _, trajectories = jax.lax.scan( env_rollout, - (t_init[0], t_init[1], my_state, self.other_state), + (t_init[0], t_init[1], my_state, other_state), None, length=env.episode_length, ) - traj_batch_0 = trajectories[0] - traj_batch_1 = trajectories[1] - # flip the order of the trajectories # assuming we're the other player sample = Sample( - obs_self=traj_batch_1.observations, - obs_other=traj_batch_0.observations, - actions_self=traj_batch_1.actions, - actions_other=traj_batch_0.actions, - dones=traj_batch_0.dones, - rewards_self=traj_batch_1.rewards, - rewards_other=traj_batch_0.rewards, + obs_self=trajectories[1].observations, + obs_other=trajectories[0].observations, + actions_self=trajectories[1].actions, + actions_other=trajectories[0].actions, + dones=trajectories[0].dones, + rewards_self=trajectories[1].rewards, + rewards_other=trajectories[0].rewards, ) # get gradients of opponent @@ -344,8 +351,6 @@ def in_lookahead(self, env, env_rollout): # apply the optimizer updates params = optax.apply_updates(self.other_state.params, updates) - # self._other_state = other_state - # replace the other player's current parameters with a simulated update self.other_state = TrainingState( params=params, @@ -377,7 +382,7 @@ def out_lookahead(self, env, env_rollout): }, hidden=None, ) - # a reference to the other agent's state (from runner) + # copy the other person's state other_state = TrainingState( params=self.other_state.params, opt_state=self.other_state.opt_state, @@ -389,7 +394,6 @@ def out_lookahead(self, env, env_rollout): }, hidden=None, ) - # self.other_state # do a full rollout t_init = env.reset() diff --git a/pax/runner.py b/pax/runner.py index 807922cd..680f90d2 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -76,6 +76,45 @@ def _env_rollout(carry, unused): tprime_2.last() * jnp.zeros(env.num_envs), a2_state.hidden, ) + return ( + tprime_1, + tprime_2, + a1_state, + a2_state, + ), (traj1, traj2) + + def _env_rollout2(carry, unused): + t1, t2, 
a1_state, a2_state = carry + a1, a1_state = agent2._policy( + a1_state.params, t1.observation, a1_state + ) + a2, a2_state = agent1._policy( + a2_state.params, t2.observation, a2_state + ) + tprime_1, tprime_2 = env.runner_step( + [ + a1, + a2, + ] + ) + traj1 = Sample( + t1.observation, + a1, + tprime_1.reward, + a1_state.extras["log_probs"], + a1_state.extras["values"], + tprime_1.last() * jnp.zeros(env.num_envs), + a1_state.hidden, + ) + traj2 = Sample( + t2.observation, + a2, + tprime_2.reward, + a2_state.extras["log_probs"], + a2_state.extras["values"], + tprime_2.last() * jnp.zeros(env.num_envs), + a2_state.hidden, + ) return (tprime_1, tprime_2, a1_state, a2_state), (traj1, traj2) for _ in range(0, max(int(num_episodes / env.num_envs), 1)): @@ -94,11 +133,11 @@ def _env_rollout(carry, unused): # inner rollout for _ in range(self.args.lola.num_lookaheads): agent1.in_lookahead(env, _env_rollout) - agent2.in_lookahead(env, _env_rollout) + agent2.in_lookahead(env, _env_rollout2) # outer rollout agent1.out_lookahead(env, _env_rollout) - agent2.out_lookahead(env, _env_rollout) + agent2.out_lookahead(env, _env_rollout2) elif self.args.agent1 == "LOLA" and self.args.agent2 != "LOLA": # copy state and haiku network of agent 2 From dd29be02ddba354e7b0ffdee90bdaee682d09f74 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Mon, 1 Aug 2022 13:15:48 +0100 Subject: [PATCH 25/29] fix axis --- pax/conf/experiment/lola.yaml | 2 +- pax/lola/lola.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 73e73969..43b74553 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -39,4 +39,4 @@ wandb: project: ipd group: 'LOLA-vs-${agent2}-${game}' name: run-seed-${seed}-${lola.num_lookaheads}-lookaheads - log: True + log: True diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 72a531e9..6af6f317 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -103,22 +103,23 @@ def outer_loss(params, other_params, samples): # apply discount: cum_discount = ( - jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) + jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=0) / self.gamma ) + # print(cum_discount) discounted_rewards = rewards * cum_discount discounted_values = values * cum_discount # stochastics nodes involved in rewards dependencies: - dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=1) + dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=0) # logprob of each stochastic nodes: stochastic_nodes = self_log_prob + other_log_prob # dice objective: dice_objective = jnp.mean( - jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=0) ) if use_baseline: @@ -126,7 +127,7 @@ def outer_loss(params, other_params, samples): baseline_term = jnp.mean( jnp.sum( (1 - magic_box(stochastic_nodes)) * discounted_values, - axis=1, + axis=0, ) ) dice_objective = dice_objective + baseline_term @@ -246,6 +247,7 @@ def make_initial_state(key: Any, obs_spec: Tuple) -> TrainingState: self.grad_fn_inner = jax.jit(jax.grad(inner_loss, has_aux=True)) self.grad_fn_outer = jax.jit(jax.grad(outer_loss, has_aux=True)) + # self.grad_fn_outer = jax.grad(outer_loss, has_aux=True) # Set up counters and logger self._logger = Logger() @@ -414,6 +416,7 @@ def out_lookahead(self, env, env_rollout): rewards_self=trajectories[0].rewards, rewards_other=trajectories[1].rewards, ) + # print("sample.obs_self.shape", 
sample.obs_self.shape) # calculate the gradients gradients, results = self.grad_fn_outer( From 78c8196d39b59082feb86470d30a06c9ee915ec4 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Mon, 1 Aug 2022 13:16:30 +0100 Subject: [PATCH 26/29] fix axis --- pax/lola/lola.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 6af6f317..2108fcf7 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -162,21 +162,21 @@ def inner_loss(params, other_params, samples): # apply discount: cum_discount = ( - jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=1) + jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=0) / self.gamma ) discounted_rewards = rewards * cum_discount discounted_values = values * cum_discount # stochastics nodes involved in rewards dependencies: - dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=1) + dependencies = jnp.cumsum(self_log_prob + other_log_prob, axis=0) # logprob of each stochastic nodes: stochastic_nodes = self_log_prob + other_log_prob # dice objective: dice_objective = jnp.mean( - jnp.sum(magic_box(dependencies) * discounted_rewards, axis=1) + jnp.sum(magic_box(dependencies) * discounted_rewards, axis=0) ) if use_baseline: @@ -184,7 +184,7 @@ def inner_loss(params, other_params, samples): baseline_term = jnp.mean( jnp.sum( (1 - magic_box(stochastic_nodes)) * discounted_values, - axis=1, + axis=0, ) ) dice_objective = dice_objective + baseline_term From c4b2e720eb89c48a6429818118b8784e8ef88201 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Mon, 1 Aug 2022 13:53:19 +0100 Subject: [PATCH 27/29] similar lola --- pax/lola/lola.py | 1 - pax/lola/network.py | 4 +++- pax/runner.py | 4 ++-- pax/utils.py | 4 +++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 2108fcf7..34e5a7dd 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -106,7 +106,6 @@ def outer_loss(params, other_params, samples): jnp.cumprod(self.gamma * jnp.ones(rewards.shape), axis=0) / self.gamma ) - # print(cum_discount) discounted_rewards = rewards * cum_discount discounted_values = values * cum_discount diff --git a/pax/lola/network.py b/pax/lola/network.py index 8c87aecd..f82b0766 100644 --- a/pax/lola/network.py +++ b/pax/lola/network.py @@ -5,6 +5,7 @@ import distrax import haiku as hk import jax.numpy as jnp +import jax class CategoricalValueHead(hk.Module): @@ -30,7 +31,8 @@ def __init__( ) def __call__(self, inputs: jnp.ndarray): - logits = self._logit_layer(inputs) + # logits = self._logit_layer(inputs) + logits = jax.nn.sigmoid(self._logit_layer(inputs)) value = jnp.squeeze(self._value_layer(inputs), axis=-1) return (distrax.Categorical(logits=logits), value) diff --git a/pax/runner.py b/pax/runner.py index 680f90d2..2a620e45 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -133,11 +133,11 @@ def _env_rollout2(carry, unused): # inner rollout for _ in range(self.args.lola.num_lookaheads): agent1.in_lookahead(env, _env_rollout) - agent2.in_lookahead(env, _env_rollout2) + agent2.in_lookahead(env, _env_rollout) # outer rollout agent1.out_lookahead(env, _env_rollout) - agent2.out_lookahead(env, _env_rollout2) + agent2.out_lookahead(env, _env_rollout) elif self.args.agent1 == "LOLA" and self.args.agent2 != "LOLA": # copy state and haiku network of agent 2 diff --git a/pax/utils.py b/pax/utils.py index 56929e55..9f8f8758 100644 --- a/pax/utils.py +++ b/pax/utils.py @@ -102,9 +102,11 @@ def to_numpy(values): def copy_state_and_network(agent): + import copy + 
"""Copies an agent state and returns the state""" state = TrainingState( - params=agent._state.params, + params=copy.deepcopy(agent._state.params), opt_state=agent._state.opt_state, random_key=agent._state.random_key, timesteps=agent._state.timesteps, From 2f302843b9c5894a04b0403af8d5165801a8bf19 Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Mon, 1 Aug 2022 17:21:36 +0100 Subject: [PATCH 28/29] half working lola --- pax/env.py | 2 ++ pax/lola/lola.py | 17 +++++++++++++++-- pax/lola/network.py | 33 +++++++++++++++++++++++++++++++-- pax/watchers.py | 8 ++++++++ 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/pax/env.py b/pax/env.py index 390c1e2e..407f5903 100644 --- a/pax/env.py +++ b/pax/env.py @@ -53,6 +53,8 @@ def step( return self.reset() action_1, action_2 = actions self._num_steps += 1 + # print("action_1.shape", action_1.shape) + # print("action_1", action_1) assert action_1.shape == action_2.shape assert action_1.shape == (self.num_envs,) diff --git a/pax/lola/lola.py b/pax/lola/lola.py index 34e5a7dd..6de4b50f 100644 --- a/pax/lola/lola.py +++ b/pax/lola/lola.py @@ -415,17 +415,25 @@ def out_lookahead(self, env, env_rollout): rewards_self=trajectories[0].rewards, rewards_other=trajectories[1].rewards, ) - # print("sample.obs_self.shape", sample.obs_self.shape) - + # print("Before updating") + # print("---------------------") + # print("params", self._state.params) + # print("opt_state", self._state.opt_state) + # print() # calculate the gradients gradients, results = self.grad_fn_outer( my_state.params, other_state.params, sample ) + # print("Gradients", gradients) + # print() # Update the optimizer updates, opt_state = self._outer_optimizer.update( gradients, my_state.opt_state ) + # print("Updates", updates) + # print("Updated optimizer", opt_state) + # print() # apply the optimizer updates params = optax.apply_updates(my_state.params, updates) @@ -450,6 +458,11 @@ def out_lookahead(self, env, env_rollout): extras={"log_probs": None, "values": None}, hidden=None, ) + # print("After updating") + # print("---------------------") + # print("params", self._state.params) + # print("opt_state", self._state.opt_state) + # print() def reset_memory(self) -> TrainingState: self._state = self._state._replace( diff --git a/pax/lola/network.py b/pax/lola/network.py index f82b0766..e195d715 100644 --- a/pax/lola/network.py +++ b/pax/lola/network.py @@ -31,12 +31,40 @@ def __init__( ) def __call__(self, inputs: jnp.ndarray): - # logits = self._logit_layer(inputs) - logits = jax.nn.sigmoid(self._logit_layer(inputs)) + logits = self._logit_layer(inputs) + # logits = jax.nn.sigmoid(self._logit_layer(inputs)) value = jnp.squeeze(self._value_layer(inputs), axis=-1) return (distrax.Categorical(logits=logits), value) +class BernoulliValueHead(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + # matching the way that they do it. 
+ logits = jnp.squeeze(self._logit_layer(inputs), axis=-1) + probs = jax.nn.sigmoid(logits) + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return (distrax.Bernoulli(probs=1 - probs), value) + + def make_network(num_actions: int): """Creates a hk network using the baseline hyperparameters from OpenAI""" @@ -45,6 +73,7 @@ def forward_fn(inputs): layers.extend( [ CategoricalValueHead(num_values=num_actions), + # BernoulliValueHead(num_values=1), ] ) policy_value_network = hk.Sequential(layers) diff --git a/pax/watchers.py b/pax/watchers.py index 2a0ddbc5..258c429e 100644 --- a/pax/watchers.py +++ b/pax/watchers.py @@ -21,8 +21,13 @@ # General policy logger def policy_logger(agent) -> dict: + # Categorical plotting weights = agent._state.params["categorical_value_head/~/linear"]["w"] pi = nn.softmax(weights) + # ####### # # + # # Bernoulli plotting + # weights = agent._state.params["bernoulli_value_head/~/linear"]["w"] + # pi = nn.sigmoid(weights) # sgd_steps = agent._total_steps / agent._num_steps sgd_steps = agent._logger.metrics["sgd_steps"] probs = { @@ -34,7 +39,10 @@ def policy_logger(agent) -> dict: # General value logger def value_logger(agent) -> dict: + # Categorical plotting weights = agent._state.params["categorical_value_head/~/linear_1"]["w"] + # Bernoulli plotting + # weights = agent._state.params["bernoulli_value_head/~/linear_1"]["w"] # sgd_steps = agent._total_steps / agent._num_steps sgd_steps = agent._logger.metrics["sgd_steps"] probs = { From 4ad1b76ead5405a86f0349d742ddafde28a06fca Mon Sep 17 00:00:00 2001 From: Newton Kwan Date: Tue, 2 Aug 2022 13:43:45 +0100 Subject: [PATCH 29/29] temporary lola --- pax/conf/experiment/lola.yaml | 1 + pax/experiment.py | 6 ++- pax/lola/network.py | 76 +++++++++++++++++++++++++++++++++++ pax/runner.py | 56 +++++++++++++++++++++----- pax/utils.py | 39 +++++++++++++++++- pax/watchers.py | 12 ++++-- 6 files changed, 173 insertions(+), 17 deletions(-) diff --git a/pax/conf/experiment/lola.yaml b/pax/conf/experiment/lola.yaml index 43b74553..82646fcb 100644 --- a/pax/conf/experiment/lola.yaml +++ b/pax/conf/experiment/lola.yaml @@ -30,6 +30,7 @@ lola: adam_epsilon: 1e-5 lr_in: 0.3 lr_out: 0.2 + lr_value: 0.1 gamma: 0.96 num_lookaheads: 0 diff --git a/pax/experiment.py b/pax/experiment.py index d8e4f8d2..ef666915 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -10,7 +10,11 @@ from pax.env import SequentialMatrixGame from pax.hyper.ppo import make_hyper from pax.independent_learners import IndependentLearners -from pax.lola.lola import make_lola + +from pax.lola.lola_two_nets import make_lola + +# from pax.lola.lola import make_lola + from pax.meta_env import InfiniteMatrixGame from pax.naive_exact import NaiveLearnerEx from pax.naive.naive import make_naive_pg diff --git a/pax/lola/network.py b/pax/lola/network.py index e195d715..fdfd8698 100644 --- a/pax/lola/network.py +++ b/pax/lola/network.py @@ -65,6 +65,48 @@ def __call__(self, inputs: jnp.ndarray): return (distrax.Bernoulli(probs=1 - probs), value) +class PolicyHead(hk.Module): + """Network head that produces a categorical distribution.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + logits = jnp.squeeze(self._logit_layer(inputs), axis=-1) + probs = jax.nn.sigmoid(logits) + return distrax.Bernoulli(probs=1 - probs) + + +class ValueHead(hk.Module): 
+ """Network head that produces a value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Constant(0), + with_bias=False, + ) + + def __call__(self, inputs: jnp.ndarray): + + value = jnp.squeeze(self._value_layer(inputs), axis=-1) + return value + + def make_network(num_actions: int): """Creates a hk network using the baseline hyperparameters from OpenAI""" @@ -81,3 +123,37 @@ def forward_fn(inputs): network = hk.without_apply_rng(hk.transform(forward_fn)) return network + + +def make_policy_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + PolicyHead(num_values=1), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + + +def make_value_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + ValueHead(num_values=1), + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network diff --git a/pax/runner.py b/pax/runner.py index 2a620e45..7807f813 100644 --- a/pax/runner.py +++ b/pax/runner.py @@ -10,7 +10,11 @@ from pax.env import IteratedPrisonersDilemma from pax.independent_learners import IndependentLearners from pax.strategies import Defect, TitForTat -from pax.utils import TrainingState, copy_state_and_network +from pax.utils import ( + TrainingState, + copy_state_and_network, + copy_extended_state_and_network, +) # TODO: make these a copy of acme @@ -83,13 +87,19 @@ def _env_rollout(carry, unused): a2_state, ), (traj1, traj2) - def _env_rollout2(carry, unused): + def _env_rollout_extended(carry, unused): t1, t2, a1_state, a2_state = carry a1, a1_state = agent2._policy( - a1_state.params, t1.observation, a1_state + a1_state.policy_params, + a1_state.value_params, + t1.observation, + a1_state, ) a2, a2_state = agent1._policy( - a2_state.params, t2.observation, a2_state + a2_state.policy_params, + a2_state.value_params, + t2.observation, + a2_state, ) tprime_1, tprime_2 = env.runner_step( [ @@ -122,22 +132,39 @@ def _env_rollout2(carry, unused): # start of unique lola code if self.args.agent1 == "LOLA" and self.args.agent2 == "LOLA": # copy state and haiku network + # ( + # agent1.other_state, + # agent1.other_network, + # ) = copy_state_and_network(agent2) + + # ( + # agent2.other_state, + # agent2.other_network, + # ) = copy_state_and_network(agent1) + ( agent1.other_state, - agent1.other_network, - ) = copy_state_and_network(agent2) + agent1.other_policy_network, + agent1.other_value_network, + ) = copy_extended_state_and_network(agent2) + ( agent2.other_state, - agent2.other_network, - ) = copy_state_and_network(agent1) + agent2.other_policy_network, + agent2.other_value_network, + ) = copy_extended_state_and_network(agent1) + # inner rollout for _ in range(self.args.lola.num_lookaheads): agent1.in_lookahead(env, _env_rollout) agent2.in_lookahead(env, _env_rollout) # outer rollout - agent1.out_lookahead(env, _env_rollout) - agent2.out_lookahead(env, _env_rollout) + # agent1.out_lookahead(env, _env_rollout) + # agent2.out_lookahead(env, _env_rollout) + + agent1.out_lookahead(env, _env_rollout_extended) + 
+            agent2.out_lookahead(env, _env_rollout_extended)
 
         elif self.args.agent1 == "LOLA" and self.args.agent2 != "LOLA":
             # copy state and haiku network of agent 2
@@ -173,9 +200,16 @@ def _env_rollout2(carry, unused):
             # unique naive-learner code
             a2_state = agent2.make_initial_state(t_init[1])
 
+        # Original
         # rollout episode
+        # vals, trajectories = jax.lax.scan(
+        #     _env_rollout,
+        #     (*t_init, a1_state, a2_state),
+        #     None,
+        #     length=env.episode_length,
+        # )
         vals, trajectories = jax.lax.scan(
-            _env_rollout,
+            _env_rollout_extended,
            (*t_init, a1_state, a2_state),
             None,
             length=env.episode_length,
diff --git a/pax/utils.py b/pax/utils.py
index 9f8f8758..543d3ac9 100644
--- a/pax/utils.py
+++ b/pax/utils.py
@@ -7,11 +7,24 @@
 import numpy as np
 
 
+# class TrainingState(NamedTuple):
+#     """Training state consists of network parameters, optimiser state, random key, timesteps, and extras."""
+
+#     params: hk.Params
+#     opt_state: optax.GradientTransformation
+#     random_key: jnp.ndarray
+#     timesteps: int
+#     extras: Mapping[str, jnp.ndarray]
+#     hidden: None
+
+
 class TrainingState(NamedTuple):
     """Training state consists of network parameters, optimiser state, random key, timesteps, and extras."""
 
-    params: hk.Params
-    opt_state: optax.GradientTransformation
+    policy_params: hk.Params
+    value_params: hk.Params
+    policy_opt_state: optax.GradientTransformation
+    value_opt_state: optax.GradientTransformation
     random_key: jnp.ndarray
     timesteps: int
     extras: Mapping[str, jnp.ndarray]
@@ -118,3 +131,25 @@ def copy_state_and_network(agent):
     )
     network = agent.network
     return state, network
+
+
+def copy_extended_state_and_network(agent):
+    """Copies an agent's state; returns the copy plus its policy and value networks."""
+    import copy
+
+    state = TrainingState(
+        policy_params=copy.deepcopy(agent._state.policy_params),
+        value_params=copy.deepcopy(agent._state.value_params),
+        policy_opt_state=agent._state.policy_opt_state,
+        value_opt_state=agent._state.value_opt_state,
+        random_key=agent._state.random_key,
+        timesteps=agent._state.timesteps,
+        extras={
+            "values": jnp.zeros(agent._num_envs),
+            "log_probs": jnp.zeros(agent._num_envs),
+        },
+        hidden=None,
+    )
+    policy_network = agent.policy_network
+    value_network = agent.value_network
+    return state, policy_network, value_network
diff --git a/pax/watchers.py b/pax/watchers.py
index 258c429e..294ebc4a 100644
--- a/pax/watchers.py
+++ b/pax/watchers.py
@@ -22,12 +22,15 @@
 
 # General policy logger
 def policy_logger(agent) -> dict:
     # Categorical plotting
-    weights = agent._state.params["categorical_value_head/~/linear"]["w"]
-    pi = nn.softmax(weights)
+    # weights = agent._state.params["categorical_value_head/~/linear"]["w"]
+    # pi = nn.softmax(weights)
     # ####### # #
     # # Bernoulli plotting
     # weights = agent._state.params["bernoulli_value_head/~/linear"]["w"]
     # pi = nn.sigmoid(weights)
+    # # Bernoulli plotting
+    weights = agent._state.policy_params["policy_head/~/linear"]["w"]
+    pi = nn.sigmoid(weights)
     # sgd_steps = agent._total_steps / agent._num_steps
     sgd_steps = agent._logger.metrics["sgd_steps"]
     probs = {
@@ -40,10 +43,13 @@ def policy_logger(agent) -> dict:
 
 # General value logger
 def value_logger(agent) -> dict:
     # Categorical plotting
-    weights = agent._state.params["categorical_value_head/~/linear_1"]["w"]
+    # weights = agent._state.params["categorical_value_head/~/linear_1"]["w"]
     # Bernoulli plotting
     # weights = agent._state.params["bernoulli_value_head/~/linear_1"]["w"]
     # sgd_steps = agent._total_steps / agent._num_steps
+    # Bernoulli plotting
+    weights = agent._state.value_params["value_head/~/linear"]["w"]
+    # sgd_steps = agent._total_steps / agent._num_steps  # unused: overwritten by the logger metric below
     sgd_steps = agent._logger.metrics["sgd_steps"]
     probs = {
         f"value/{agent.player_id}/{str(s)}": p[0]