a jax-native training and inference framework for alphazero, built around mctx.
features:
- comprehensive SIMD and SPMD support; parallel data collection and training across XLA devices.
- fully JIT-composable.
- circular buffer to hold memory samples (replay buffer); see the sketch after this list.
- a lineup of evaluation functions:
  - AZ Net v. Random
  - AZ v. Random
  - AZ v. MCTS
  - AZ v. AZ (league based)
- stochastic alphazero with chance nodes (on branch stochastic_alpha_zero).
- resume training from checkpoint.
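
for reference, a toy sketch of the circular-buffer idea (illustrative only; ReplayBuffer here is a made-up name, not the repo's actual class):

```python
import jax
import jax.numpy as jnp

class ReplayBuffer:
    """Toy ring buffer: once full, new samples overwrite the oldest ones."""
    def __init__(self, capacity: int, sample_shape: tuple):
        self.data = jnp.zeros((capacity, *sample_shape))
        self.capacity = capacity
        self.ptr = 0    # next write index
        self.size = 0   # number of valid samples stored so far

    def add(self, sample: jnp.ndarray):
        self.data = self.data.at[self.ptr].set(sample)   # overwrite the oldest slot
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, key: jnp.ndarray, batch_size: int) -> jnp.ndarray:
        idx = jax.random.randint(key, (batch_size,), 0, self.size)
        return self.data[idx]
```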
installation
- install poetry
- inside the repo: poetry install, then poetry shell
- install JAX 0.4.30 with poetry run pip3 install ... (system dependent)
training
python3 src/train.py name=test env_id=connect_four
inference
python3 src/inference.py ckpt="checkpoints/connect_four_test/000000.ckpt" eval_simulations=1600
stochastic alphazero
the recurrent_fn expects the State to hold two extra attributes:

```python
class State:
    # everything else ...
    _chance_probs: jnp.ndarray  # chance-node probabilities across all actions
    is_chance: bool             # whether the current node is a chance node
```
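
as a rough illustration (not the repo's actual recurrent_fn), this is one way the two fields could be consumed when building the priors for the search, assuming batched state fields: at a chance node the environment's chance probabilities replace the network's policy head.

```python
import jax.numpy as jnp

def select_prior_logits(policy_logits: jnp.ndarray, state) -> jnp.ndarray:
    chance_logits = jnp.log(state._chance_probs + 1e-9)   # probs -> logits
    is_chance = state.is_chance[..., None]                 # broadcast over the action axis
    return jnp.where(is_chance, chance_logits, policy_logits)
```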
jax
jax is volatile and gets updated frequently; things will most likely crash if you don't use the intended version. i use the now-deprecated jax.pmap for explicit device parallelism; i may swap to shard_map in the future.
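
roughly, the pattern looks like this (a toy sketch of the jax.pmap idiom with a stand-in loss, not the repo's actual training step):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # toy quadratic loss, stand-in for the real policy/value loss
    return jnp.mean((batch - params) ** 2)

def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name="devices")   # average gradients across devices
    return params - 1e-2 * grads                        # plain SGD update

# one replica of params and one shard of data per local device
p_train_step = jax.pmap(train_step, axis_name="devices")
n_dev = jax.local_device_count()
params = jnp.zeros((n_dev, 4))
batch = jnp.ones((n_dev, 4))
params = p_train_step(params, batch)
```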
environment
this is built for PGX environments. that being said, swapping in your own completely custom environment should be straightforward as long as you implement the following:
```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import jax.numpy as jnp


class State(ABC):
    """
    Base state class for Pgx game environments.
    Key attributes:
    - current_player: ID of the agent to play
    - observation: current state observation
    - rewards: intermediate rewards for each agent
    - terminated: whether the state is terminal
    - truncated: whether the episode was truncated
    - legal_action_mask: boolean array of legal actions
    """
    current_player: jnp.ndarray
    observation: jnp.ndarray
    rewards: jnp.ndarray
    terminated: jnp.ndarray
    truncated: jnp.ndarray
    legal_action_mask: jnp.ndarray


class Env(ABC):
    """
    Base environment class for Pgx games.
    Key properties:
    - id: environment identifier
    - num_actions: size of the action space
    - num_players: number of players
    - observation_shape: shape of the observation
    - version: environment version
    Key methods:
    - init: initialize the environment state
    - step: perform an action and get the next state
    - observe: get the observation for a specific player
    """

    @property
    @abstractmethod
    def id(self) -> str:
        pass

    @property
    @abstractmethod
    def num_actions(self) -> int:
        pass

    @property
    @abstractmethod
    def num_players(self) -> int:
        pass

    @property
    @abstractmethod
    def observation_shape(self) -> Tuple[int, ...]:
        pass

    @property
    @abstractmethod
    def version(self) -> str:
        pass

    @abstractmethod
    def init(self, key: jnp.ndarray) -> 'State':
        pass

    @abstractmethod
    def step(self, state: 'State', action: jnp.ndarray, key: Optional[jnp.ndarray] = None) -> 'State':
        pass

    @abstractmethod
    def observe(self, state: 'State', player_id: int) -> jnp.ndarray:
        pass
```
See docs for more info, or the Tic-Tac-Toe implementation.
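
for reference, the usual batched Pgx usage pattern that this interface supports (illustrative only; assumes pgx is installed and uses the connect_four env from the training example above):

```python
import jax
import jax.numpy as jnp
import pgx

env = pgx.make("connect_four")
init = jax.jit(jax.vmap(env.init))
step = jax.jit(jax.vmap(env.step))

keys = jax.random.split(jax.random.PRNGKey(0), 8)   # a batch of 8 games
state = init(keys)

# play one uniformly-random legal move in every game
logits = jnp.where(state.legal_action_mask, 0.0, -jnp.inf)
action = jax.random.categorical(jax.random.PRNGKey(1), logits)
state = step(state, action)
```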