Selective Attention Transformer

This repository is an easy-to-read JAX implementation of GPT. It is designed to be easy to understand and hack on, and it can be used as a starting point to implement your own transformer models.

The code also implements an alternative to the standard attention mechanism that I call selective attention. The idea is that the keys and values are generated per query. This means that while the queries have their usual shape of [q_seq, d_model], the keys and values have shape [q_seq, kv_seq, d_model]. The initial motivation was that the usual keys and values are very general: they are the same for every query, so information is shared inefficiently. The selective attention mechanism allows for more fine-grained control over the information being shared.

It turns out, however, that the selective attention mechanism does not bring any improvement over the standard attention mechanism, while being much slower. This is a bit disappointing, but I think it is still an interesting idea to explore, and a good JAX tutorial (at least for me): this kind of implementation makes heavy use of JAX's vmap and jit functionalities.

Selective attention

Let's recall that the usual attention can be written as follows:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

Where $Q$ are the queries, $K$ are the keys, and $V$ are the values. The selective attention is defined by:

$$ \text{Attention}(q_i, K_i, V_i) = \text{softmax}\left(\frac{K_iq_i}{\sqrt{d_k}}\right)V_i $$

Where $q_i$ is the $i$-th query, and $K_i$ and $V_i$ are the keys and values generated for the $i$-th query. In other words, the keys and values are generated per query, and the final attention has to be computed for every query in $Q$.
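As a rough illustration of the two formulas above, here is a minimal single-head sketch in JAX. It is not the code from this repository: the function names are made up for the example, and it leaves out causal masking, multiple heads, and batching. It only assumes the shapes described earlier, [q_seq, d] for the queries and [q_seq, kv_seq, d] for the per-query keys and values.

```python
import jax
import jax.numpy as jnp


def standard_attention(q, k, v):
    """q: [q_seq, d_k], k: [kv_seq, d_k], v: [kv_seq, d_v] -> [q_seq, d_v]."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def selective_attention_single(q_i, k_i, v_i):
    """q_i: [d_k], k_i: [kv_seq, d_k], v_i: [kv_seq, d_v] -> [d_v]."""
    scores = k_i @ q_i / jnp.sqrt(q_i.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v_i


# Each query gets its own keys and values, so map over the leading
# q_seq axis of all three arguments.
selective_attention = jax.vmap(selective_attention_single, in_axes=(0, 0, 0))

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (5, 8))     # [q_seq, d]
k = jax.random.normal(key_k, (5, 7, 8))  # [q_seq, kv_seq, d]
v = jax.random.normal(key_v, (5, 7, 8))  # [q_seq, kv_seq, d]
out = selective_attention(q, k, v)       # [q_seq, d]
```

A convenient sanity check: if every query is given the same keys and values, `selective_attention` should return the same result as `standard_attention`.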

Now, one important detail is the way the keys and values are generated. I decided to go the simple way: concatenate each query with each key and value, and pass the result through a linear layer. That way I keep the parameter count close to that of the standard attention mechanism.

To implement this, I have to build every concatenated (query, key/value) pair. This is done using jax.vmap as follows (./src/model/

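import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

def cross_product_matching(
    query: Float[Array, "q_seq q_size"], other: Float[Array, "o_seq o_size"]
) -> Float[Array, "q_seq o_seq q_size+o_size"]:
    """Concatenate every query element to every other element."""
    # Base case: concatenate two 1-D feature vectors.
    cat = lambda v1, v2: jnp.concatenate((v1, v2), axis=0)
    # Inner vmap: keep one query vector fixed and map over the rows of `other`.
    cat = jax.vmap(cat, in_axes=(None, 0))
    # Outer vmap: map over the rows of `query`, passing `other` whole.
    cat = jax.vmap(cat, in_axes=(0, None))
    return cat(query, other)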

I'm including it here because I think it's a good example of how to use vmap to build complex operations. The output of this function can be passed through a linear layer to produce the keys and values for each query.
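To make that last step concrete, here is a small sketch of how the paired tensor could be projected into per-query keys and values. The pairing function is repeated without its jaxtyping annotations so the snippet runs on its own. The names `tokens`, `w_k`, and `w_v`, the bias-free projections, and the sizes are assumptions made for the example, not the layers actually used in this repository.

```python
import jax
import jax.numpy as jnp


def cross_product_matching(query, other):
    """Same pairing as above, without the jaxtyping annotations."""
    cat = lambda v1, v2: jnp.concatenate((v1, v2), axis=0)
    cat = jax.vmap(cat, in_axes=(None, 0))
    cat = jax.vmap(cat, in_axes=(0, None))
    return cat(query, other)


q_seq, kv_seq, d_model = 4, 6, 8
key_q, key_x, key_wk, key_wv = jax.random.split(jax.random.PRNGKey(0), 4)
queries = jax.random.normal(key_q, (q_seq, d_model))
tokens = jax.random.normal(key_x, (kv_seq, d_model))  # hypothetical key/value inputs

# Hypothetical bias-free linear layers mapping 2 * d_model features to d_model.
w_k = jax.random.normal(key_wk, (2 * d_model, d_model)) / jnp.sqrt(2 * d_model)
w_v = jax.random.normal(key_wv, (2 * d_model, d_model)) / jnp.sqrt(2 * d_model)

pairs = cross_product_matching(queries, tokens)  # [q_seq, kv_seq, 2 * d_model]
keys = pairs @ w_k                               # [q_seq, kv_seq, d_model]
values = pairs @ w_v                             # [q_seq, kv_seq, d_model]
```

Entry `pairs[i, j]` holds query `i` followed by input `j`, so the matrix product applies the same linear layer to every (query, input) pair.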

Implementation details

The training dataset is a simple collection of Shakespeare plays. No complex tokenization scheme is used: every character is mapped to its own integer. This is not efficient, but it remains simple, which is the goal of this repository.
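A character-level mapping of this kind can be sketched in a few lines. This is only an illustration: the sample string stands in for the real corpus, and the names `char_to_id`, `encode`, and `decode` are made up for the example rather than taken from the repository.

```python
import jax.numpy as jnp

text = "To be, or not to be"  # stand-in for the full corpus

# One integer id per distinct character, assigned in sorted order.
vocab = sorted(set(text))
char_to_id = {ch: i for i, ch in enumerate(vocab)}
id_to_char = {i: ch for ch, i in char_to_id.items()}


def encode(s):
    """Map a string to an int32 array of character ids."""
    return jnp.array([char_to_id[ch] for ch in s], dtype=jnp.int32)


def decode(ids):
    """Map an array of character ids back to a string."""
    return "".join(id_to_char[int(i)] for i in ids)


tokens = encode(text)
```

The vocabulary size is simply the number of distinct characters in the corpus, which keeps the embedding table small at the cost of longer sequences.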

RoPE is used as positional encoding.
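For readers unfamiliar with RoPE, here is a minimal sketch of the rotation applied to queries and keys before attention. It assumes the "split halves" pairing convention and a base of 10000. The implementation in this repository may pair features differently, and the function name `rope` is chosen for the example.

```python
import jax
import jax.numpy as jnp


def rope(x, base=10000.0):
    """Rotate feature pairs of x: [seq, d] (d even) by position-dependent angles."""
    seq_len, d = x.shape
    half = d // 2
    # One rotation frequency per feature pair.
    freqs = base ** (-jnp.arange(half) / half)              # [half]
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]  # [seq, half]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    # Pair feature i with feature i + half and apply a 2-D rotation.
    x1, x2 = x[:, :half], x[:, half:]
    return jnp.concatenate((x1 * cos - x2 * sin, x1 * sin + x2 * cos), axis=-1)


x = jax.random.normal(jax.random.PRNGKey(0), (6, 8))
# Repeat one query vector and one key vector at every position.
q = jnp.tile(x[0], (6, 1))
k = jnp.tile(x[1], (6, 1))
scores = rope(q) @ rope(k).T  # [seq, seq]
```

Because each pair is rotated by an angle proportional to its position, the score between a rotated query and a rotated key depends only on their relative offset: in the example, `scores[2, 1]` and `scores[5, 4]` agree up to floating-point error.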

The shapes are checked using beartype and jaxtyping.


I trained the standard and selective models with the same hyperparameters (which you can find in ./configs/default.yaml). The main difference between the two is the training time, selective attention being much slower than the standard attention mechanism. The results are as follows:

| | standard | selective | equinox |
|---|---|---|---|
| loss | 1.475 | 1.448 | 1.632 |
| top-1 accuracy | 0.543 | 0.554 | 0.504 |
| training time | 3h45 | 14h50 | 3h30 |
| parameters | 191,000 | 207,000 | 224,000 |

Selective attention is slightly better, but this may be explained solely by its larger number of parameters rather than by the attention mechanism itself. Also, note that my equinox version does not use any positional encoding, which is why it performs so much worse.

To validate my own standard attention, I compared a run of my implementation without RoPE against a run of the equinox implementation (also without RoPE). The two training runs are similar:

| | standard | equinox |
|---|---|---|
| loss | 1.645 | 1.632 |
| top-1 accuracy | 0.499 | 0.504 |
| training time | 3h20 | 3h30 |
| parameters | 191,000 | 224,000 |

All training curves:


How to use

Install the dependencies

With nix (flakes):

nix develop
python3 -m venv .venv
source .venv/bin/activate
pdm sync

Without nix:


python3 -m venv .venv
source .venv/bin/activate
pip install pdm
pdm sync

Note that this will install JAX for Nvidia GPUs, with CUDA being installed through pip. If you want to use another configuration of JAX, you can find the installation instructions in the official JAX documentation.

All dependencies are listed in pyproject.toml. I use pdm as a package manager, but you can use whichever you want. I used Python 3.12, but this code should run fine on Python 3.11 or above.

Run experiments

If you want to log the experiments, make sure you have a wandb account. The repo uses hydra. The basic training (using the default hyperparameters) can be run with:

python3 mode=offline

Most of the options are exposed in the ./configs/default.yaml configuration file, which is read by hydra. You can modify this file or pass the options directly on the command line.

Specify whether you want to log your experiments with mode=online or mode=offline. You can specify the attention mechanism to use with model.mha_type={standard,selective,equinox}. The equinox version is used as a reference in this repository, but note that it does not use any positional encoding scheme. For standard and selective, you can specify whether or not to use RoPE.

