add readmes and fix modules #5

Open
wants to merge 3 commits into main
Empty file added data/README.md
Empty file.
Empty file added embeddings/README.md
Empty file.
Empty file added heuristics/README.md
Empty file.
Empty file added inference/README.md
Empty file.
5 changes: 4 additions & 1 deletion inference/server.py
@@ -36,9 +36,12 @@ async def predict(request: Request):
pprint_nparray(
np.array(ex.moveset[..., MovesetFeature.MOVESET_FEATURE__SPECIES_ID])
)
pprint_nparray(np.array(response.pi))
print()
pprint_nparray(np.array(response.logit))
pprint_nparray(np.array(response.log_pi))
print()
pprint_nparray(np.array(response.pi))
print()
pprint_nparray(np.array(response.v))
pprint_nparray(np.array(response.action))

Empty file added ml/README.md
Empty file.
6 changes: 3 additions & 3 deletions ml/arch/heads.py
@@ -4,7 +4,7 @@
import jax.numpy as jnp
from ml_collections import ConfigDict

from ml.arch.modules import Logits, TransformerEncoder, softcap
from ml.arch.modules import Logits, TransformerEncoder
from ml.func import legal_log_policy, legal_policy


@@ -35,7 +35,7 @@ def setup(self):
self.encoder = TransformerEncoder(**self.cfg.transformer.to_dict())
self.queries = self.param(
"queries",
nn.initializers.truncated_normal(0.02),
nn.initializers.truncated_normal(),
(4, self.cfg.transformer.model_size),
)
self.logits = Logits(**self.cfg.logits.to_dict())
@@ -51,6 +51,6 @@ def __call__(self, embeddings: chex.Array, mask: chex.Array):
state_embedding = state_embedding.reshape(-1)

logits = self.logits(state_embedding)
logits = softcap(logits, max_value=3)
# logits = softcap(logits, max_value=3)

return logits.reshape(-1)
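
For context on the heads.py change: the soft cap on the policy-head logits is commented out, so the head now returns unbounded logit magnitudes. Below is a minimal sketch of a tanh-based soft cap; it assumes `ml.arch.modules.softcap` follows the common `max_value * tanh(x / max_value)` form, which is not shown in this diff.

```python
# Hypothetical sketch of a tanh-based soft cap (assumed form of
# ml.arch.modules.softcap; the real implementation is not shown in this diff).
import jax.numpy as jnp

def softcap(x: jnp.ndarray, max_value: float) -> jnp.ndarray:
    # Smoothly bounds values into (-max_value, max_value) while staying
    # differentiable, unlike a hard clip.
    return max_value * jnp.tanh(x / max_value)

print(softcap(jnp.array([-10.0, 0.5, 10.0]), max_value=3.0))  # ~[-2.99, 0.50, 2.99]
```

With the cap removed, downstream losses see the raw logit magnitudes instead of values bounded to roughly (-3, 3).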
25 changes: 6 additions & 19 deletions ml/arch/modules.py
@@ -238,11 +238,12 @@ def l2_normalize(x: chex.Array, epsilon: float = 1e-6):
return x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + epsilon)


def escort_transform(x: chex.Array, mask: chex.Array, p: int = 2, axis: int = -1):
def escort_transform(
x: chex.Array, mask: chex.Array, p: int = 2, axis: int = -1, eps: float = 1e-8
):
abs_x = jnp.power(jnp.abs(x), p)
denom = abs_x.sum(axis=axis, where=mask, keepdims=True)
denom = jnp.where(denom == 0, 1, denom)
return abs_x / denom
return abs_x / (denom + eps)


class MultiHeadAttention(nn.Module):
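
A standalone sketch of the updated `escort_transform` for reference: the `denom == 0` guard is replaced by an additive `eps`, and the function still normalises |x|^p by its sum over the masked (legal) entries. The example values are illustrative only.

```python
import jax.numpy as jnp

def escort_transform(x, mask, p: int = 2, axis: int = -1, eps: float = 1e-8):
    # Normalise |x|^p by its sum over the masked entries; eps avoids a
    # division by zero when nothing is legal along the axis.
    abs_x = jnp.power(jnp.abs(x), p)
    denom = abs_x.sum(axis=axis, where=mask, keepdims=True)
    return abs_x / (denom + eps)

x = jnp.array([[1.0, -2.0, 3.0]])
mask = jnp.array([[True, True, False]])
print(escort_transform(x, mask))  # [[0.2, 0.8, 1.8]]
```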
@@ -809,10 +810,7 @@ class SumEmbeddings(nn.Module):
def __call__(self, encodings: List[chex.Array]) -> jnp.ndarray:
# Sum the transformed embeddings using parameter weights

num_embeddings = len(encodings)

bias1 = self.param("bias1", nn.initializers.zeros_init(), (self.output_size,))
bias2 = self.param("bias2", nn.initializers.zeros_init(), (self.output_size,))
bias = self.param("bias", nn.initializers.zeros_init(), (self.output_size,))

def _transform_encoding(encoding: chex.Array, index: int):
return nn.Dense(
@@ -825,18 +823,7 @@ def _transform_encoding(encoding: chex.Array, index: int):
_transform_encoding(encoding, i) for i, encoding in enumerate(encodings)
]

output = sum(transformed_encodings) / (num_embeddings**0.5) + bias1
if self.use_layer_norm:
output = layer_norm(output)
output = nn.relu(output)

weights = nn.Dense(num_embeddings)(output)
scores = jax.nn.softmax(weights, axis=-1)

embeddings = [
score * embedding for score, embedding in zip(scores, transformed_encodings)
]
output = sum(embeddings) + bias2
output = sum(transformed_encodings) + bias

if self.use_layer_norm:
output = layer_norm(output)
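
The SumEmbeddings change drops the softmax-gated, two-bias combination in favour of a plain sum of per-input projections plus a single bias. A minimal Flax sketch of the new forward pass follows; the Dense arguments and the layer-norm call are assumptions for illustration, since the diff truncates `_transform_encoding`.

```python
# Minimal Flax sketch of the simplified SumEmbeddings forward pass after this
# PR. output_size / use_layer_norm mirror the fields used in the diff; the
# Dense configuration and nn.LayerNorm are stand-ins for the repo's helpers.
from typing import List

import chex
import flax.linen as nn
import jax.numpy as jnp


class SumEmbeddingsSketch(nn.Module):
    output_size: int
    use_layer_norm: bool = True

    @nn.compact
    def __call__(self, encodings: List[chex.Array]) -> jnp.ndarray:
        bias = self.param("bias", nn.initializers.zeros_init(), (self.output_size,))
        # One projection per input encoding, then a plain (unweighted) sum.
        transformed = [
            nn.Dense(self.output_size, name=f"proj_{i}")(e)
            for i, e in enumerate(encodings)
        ]
        output = sum(transformed) + bias
        if self.use_layer_norm:
            output = nn.LayerNorm()(output)
        return output
```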
8 changes: 6 additions & 2 deletions ml/learners/func.py
@@ -125,8 +125,12 @@ def collect_regularisation_telemetry_data(
legal_mask: chex.Array,
state_mask: chex.Array,
) -> dict[str, Any]:
raw_reg_rewards = regularisation_policy.mean(where=legal_mask)
norm_reg_rewards = (policy * regularisation_policy).mean(where=state_mask)
raw_reg_rewards = regularisation_policy.mean(where=legal_mask, axis=-1).mean(
where=state_mask
)
norm_reg_rewards = jnp.squeeze((policy * regularisation_policy).sum(axis=-1)).mean(
where=state_mask
)
return {
"raw_reg_rewards": raw_reg_rewards,
"norm_reg_rewards": norm_reg_rewards,
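
For reference, a small shape sketch of the updated telemetry reduction (illustrative sizes, not from the repo): the regularisation term is now averaged over legal actions per state and then over valid states, and the policy-weighted term becomes a per-state expectation under the policy before the valid-state mean.

```python
import jax.numpy as jnp

T, B, A = 4, 2, 10                        # time, batch, actions (hypothetical sizes)
policy = jnp.full((T, B, A), 1.0 / A)      # [T, B, A]
reg_policy = jnp.zeros((T, B, A))          # e.g. log pi - log pi_target
legal = jnp.ones((T, B, A), dtype=bool)
valid = jnp.ones((T, B), dtype=bool)

# Per-state mean over legal actions, then mean over valid states.
raw = reg_policy.mean(where=legal, axis=-1).mean(where=valid)    # scalar
# Per-state expectation under the policy, then mean over valid states
# (the diff additionally squeezes a trailing singleton dim before the mean).
norm = (policy * reg_policy).sum(axis=-1).mean(where=valid)      # scalar
```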
50 changes: 23 additions & 27 deletions ml/learners/vtrace.py
@@ -37,16 +37,16 @@ class NerdConfig:
"""Nerd related params."""

beta: float = 3
clip: float = 6
clip: float = 100


@chex.dataclass(frozen=True)
class VtraceConfig(ActorCriticConfig):
entropy_loss_coef: float = 1e-2
target_network_avg: float = 1e-2
target_network_avg: float = 5e-3

nerd: NerdConfig = NerdConfig()
clip_gradient: float = 50
clip_gradient: float = 5


def get_config():
@@ -55,7 +55,7 @@ def get_config():

class TrainState(train_state.TrainState):

params_reg: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
# params_reg: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
params_target: core.FrozenDict[str, Any] = struct.field(pytree_node=True)

actor_steps: int = 0
@@ -67,7 +67,7 @@ def create_train_state(module: nn.Module, rng: PRNGKey, config: ActorCriticConfi

params = module.init(rng, ex, hx)
params_target = module.init(rng, ex, hx)
params_reg = module.init(rng, ex, hx)
# params_reg = module.init(rng, ex, hx)

tx = optax.chain(
optax.adam(
@@ -84,7 +84,7 @@ def create_train_state(module: nn.Module, rng: PRNGKey, config: ActorCriticConfi
apply_fn=module.apply,
params=params,
params_target=params_target,
params_reg=params_reg,
# params_reg=params_reg,
tx=tx,
)

@@ -95,7 +95,7 @@ def save(state: TrainState):
dict(
params=state.params,
params_target=state.params_target,
params_reg=state.params_reg,
# params_reg=state.params_reg,
opt_state=state.opt_state,
step=state.step,
),
@@ -118,9 +118,9 @@ def load(state: TrainState, path: str):
print(f"Learner steps: {step_no:08}")
print(f"Loading target and regularisation nets")
print(f"Loading optimizer state")
state.replace(
state = state.replace(
params_target=step["params_target"],
params_reg=step["params_reg"],
# params_reg=step["params_reg"],
opt_state=step["opt_state"],
)

@@ -133,22 +133,20 @@ def train_step(state: TrainState, batch: TimeStep, config: VtraceConfig):

def loss_fn(params: Params):
# Define a checkpointed function
def rollout_fn(model_params, env, history):
def rollout_fn(model_params):
return jax.vmap(jax.vmap(state.apply_fn, (None, 0, 0)), (None, 0, 0))(
model_params, env, history
model_params, batch.env, batch.history
)

pred: ModelOutput = rollout_fn(params, batch.env, batch.history)
pred_targ: ModelOutput = rollout_fn(
state.params_target, batch.env, batch.history
)
pred_reg: ModelOutput = rollout_fn(state.params_reg, batch.env, batch.history)
pred: ModelOutput = rollout_fn(params)
pred_targ: ModelOutput = rollout_fn(state.params_target)
# pred_reg: ModelOutput = rollout_fn(state.params_reg)

logs = {}

policy_pprocessed = config.finetune(pred.pi, batch.env.legal, state.step)

log_policy_reg = pred.log_pi - pred_reg.log_pi
log_policy_reg = pred.log_pi - pred_targ.log_pi
logs.update(
collect_regularisation_telemetry_data(
pred.pi, log_policy_reg, batch.env.legal, batch.env.valid
@@ -160,9 +158,7 @@ def rollout_fn(model_params, env, history):
v_target_list, has_played_list, v_trace_policy_target_list = [], [], []
action_oh = jax.nn.one_hot(batch.actor.action, batch.actor.policy.shape[-1])

rewards = (
batch.actor.rewards.win_rewards + batch.actor.rewards.fainted_rewards / 6
)
rewards = batch.actor.rewards.scaled_fainted_rewards

for player in range(config.num_players):
reward = rewards[:, :, player] # [T, B, Player]
@@ -180,7 +176,7 @@ def rollout_fn(model_params, env, history):
lambda_=1.0,
c=config.c_vtrace,
rho=jnp.inf,
eta=0.1,
eta=0.5,
gamma=config.gamma,
)
v_target_list.append(jax.lax.stop_gradient(v_target_))
@@ -252,14 +248,14 @@ def rollout_fn(model_params, env, history):
old_tensors=state.params_target,
step_size=ema_val,
)
params_reg = optax.incremental_update(
new_tensors=state.params_target,
old_tensors=state.params_reg,
step_size=ema_val,
)
# params_reg = optax.incremental_update(
# new_tensors=state.params_target,
# old_tensors=state.params_reg,
# step_size=ema_val,
# )
state = state.replace(
params_target=params_target,
params_reg=params_reg,
# params_reg=params_reg,
actor_steps=state.actor_steps + batch.env.valid.sum(),
)

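
The vtrace changes drop the separate `params_reg` copy: the regularisation term is now `pred.log_pi - pred_targ.log_pi`, i.e. measured against the EMA target network, and only `params_target` is updated each step. A minimal sketch of the `optax.incremental_update` EMA on toy pytrees (the parameter values are stand-ins):

```python
import jax.numpy as jnp
import optax

# Toy parameter pytrees standing in for the real model params.
params = {"w": jnp.ones((3,))}
params_target = {"w": jnp.zeros((3,))}

ema_val = 5e-3  # config.target_network_avg after this PR

# EMA update: params_target <- ema_val * params + (1 - ema_val) * params_target
params_target = optax.incremental_update(
    new_tensors=params, old_tensors=params_target, step_size=ema_val
)
print(params_target)  # {'w': [0.005, 0.005, 0.005]}
```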
Empty file added proto/README.md
Empty file.
Empty file added replays/README.md
Empty file.
Empty file added rlenv/README.md
Empty file.
14 changes: 7 additions & 7 deletions rlenv/main.py
@@ -13,7 +13,7 @@
import websockets
from tqdm import tqdm

from ml.arch.model import get_model
from ml.arch.model import get_dummy_model
from ml.config import FineTuning
from ml.learners.func import collect_batch_telemetry_data
from ml.utils import Params
@@ -131,7 +131,7 @@ async def initialize_players(self):
self.current_player = self.players[0] # Start with the first player

def is_done(self):
return self.dones.sum() == 2
return self.dones.all()

async def _reset(self):
"""Reset both players, enqueue their initial states, and return the first state to act on."""
@@ -163,13 +163,13 @@ async def _step(self, action: int):
async def _perform_action(self, player: SinglePlayerEnvironment, action: int):
"""Helper method to send the action to the player and enqueue the resulting state."""
# Perform the step and add the resulting state along with the player back into the queue
state = await player._step(action)
if not self.is_done():
if not player.is_done():
state = await player._step(action)
self.dones[int(state.info.player_index)] = state.info.done
if not state.info.done:
await self.state_queue.put((player, state))

if self.is_done():
await self.state_queue.put((player, state))
await self.state_queue.put((self.current_player, self.current_state))


class BatchEnvironment(ABC):
@@ -387,7 +387,7 @@ def main():
evaluation_progress = tqdm(desc="evaluation: ")

num_envs = 8
network = get_model()
network = get_dummy_model()
training_env = SingleTrajectoryTrainingBatchCollector(network, num_envs)
evaluation_env = EvalBatchCollector(network, 4)

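
The `_perform_action` change only steps a player that is not already done and, once every player is finished, re-queues the last observed (player, state) pair so the consumer loop can exit instead of waiting on an empty queue. A self-contained asyncio sketch of that pattern with a stub player (all names hypothetical):

```python
import asyncio
import numpy as np


class StubPlayer:
    """Stand-in for SinglePlayerEnvironment; counts down a fixed number of steps."""

    def __init__(self, idx: int, length: int):
        self.idx, self._left = idx, length

    def is_done(self) -> bool:
        return self._left == 0

    async def _step(self, action: int):
        self._left -= 1
        return {"player_index": self.idx, "done": self._left == 0}


async def perform_action(player, action, dones, queue, current):
    # Only step players that still have states left.
    if not player.is_done():
        state = await player._step(action)
        dones[state["player_index"]] = state["done"]
        if not state["done"]:
            await queue.put((player, state))
    # Once every player is done, hand the last (player, state) back to the
    # consumer so it is not left blocked on queue.get().
    if dones.all():
        await queue.put(current)


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    dones = np.zeros(2, dtype=bool)
    p0, p1 = StubPlayer(0, 1), StubPlayer(1, 1)
    current = (p0, {"player_index": 0, "done": False})
    await perform_action(p0, 0, dones, queue, current)
    await perform_action(p1, 0, dones, queue, current)
    print(queue.qsize())  # 1: only the final hand-off remains


asyncio.run(main())
```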
Empty file added scripts/README.md
Empty file.
Empty file added service/README.md
Empty file.
58 changes: 28 additions & 30 deletions service/src/server/game.ts
@@ -7,6 +7,7 @@ import { Action, GameState } from "../../protos/service_pb";
import { MessagePort } from "worker_threads";
import { EVAL_GAME_ID_OFFSET } from "./data";
import { getEvalAction } from "./eval";
import { Rewards } from "../../protos/state_pb";

const formatId = "gen3randombattle";
const generator = TeamGenerators.getTeamGenerator(formatId);
@@ -25,9 +26,6 @@ export class Game {
playerIds: number[];
maxPlayers: number;

tracker: Tracker;
prevTurn: number;

constructor(gameId: number, workerIndex: number, port: MessagePort) {
this.gameId = gameId;
this.workerIndex = workerIndex;
Expand All @@ -38,9 +36,6 @@ export class Game {
this.maxPlayers = this.gameId < EVAL_GAME_ID_OFFSET ? 2 : 1;
this.resetCount = 0;
this.playerIds = [];

this.tracker = new Tracker();
this.prevTurn = 0;
}

addPlayerId(playerId: number) {
@@ -66,7 +61,6 @@
console.error("No players have been added");
}
this.resetCount = 0;
this.tracker.reset();
this.tasks.reset();
this._reset(options);
}
@@ -118,38 +112,42 @@
>player p2 ${JSON.stringify(p2spec)}`);

const battle = stream.battle!;
this.tracker.setBattle(battle);

const tracker = new Tracker();
tracker.setBattle(battle);
tracker.reset();

const sendFn: sendFnType = async (player) => {
const gameState = new GameState();

const state = player.createState();
const info = state.getInfo()!;
const currentTurn = info.getTurn()!;

if (currentTurn > this.prevTurn) {
const rewards = info.getRewards()!;
this.tracker.update();
const {
faintedReward,
hpReward,
scaledHpReward,
scaledFaintedReward,
} = this.tracker.getReward();

rewards.setHpReward(hpReward);
rewards.setFaintedReward(faintedReward);
rewards.setScaledHpReward(scaledHpReward);
rewards.setScaledFaintedReward(scaledFaintedReward);

this.prevTurn = currentTurn;
}
const isDone = info.getDone()!;
const playerIndex = +state.getInfo()!.getPlayerIndex();

const rewards = new Rewards();
tracker.update(playerIndex);
const {
faintedReward,
hpReward,
scaledHpReward,
scaledFaintedReward,
winReward,
} = tracker.getReward();

rewards.setHpReward(hpReward);
rewards.setFaintedReward(faintedReward);
rewards.setScaledHpReward(scaledHpReward);
rewards.setScaledFaintedReward(scaledFaintedReward);
rewards.setWinReward(winReward);

info.setRewards(rewards);
state.setInfo(info);

gameState.setState(state.serializeBinary());
const playerId =
this.playerIds[+state.getInfo()!.getPlayerIndex()] ?? 1;
const playerId = this.playerIds[playerIndex] ?? 1;
let rqid = -1;
if (!state.getInfo()!.getDone()) {
if (!isDone) {
rqid = this.tasks.createJob();
}
gameState.setRqid(rqid);