This repository is an easy-to-read `jax` implementation of GPT. It is designed to be understandable and hackable, and can be used as a starting point to implement your own transformer models.
The code also implements an alternative version of the standard attention mechanism that I call selective attention. The idea is that the keys and values are generated per-query. This means that while the queries keep their usual shape of `[q_seq, d_model]`, the keys and values have shape `[q_seq, kv_seq, d_model]`. The initial motivation was that the usual keys and values are very general and identical for every query, which means that information is shared inefficiently. The selective attention mechanism allows for more fine-grained control over the information being shared.
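For concreteness, here is a minimal sketch (my own illustration, not the repository's code) of what attention looks like once every query carries its own keys and values of shape `[q_seq, kv_seq, d_model]`:

```python
import jax
import jax.numpy as jnp

q_seq, kv_seq, d_model = 8, 8, 16
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)

q = jax.random.normal(kq, (q_seq, d_model))          # usual queries
k = jax.random.normal(kk, (q_seq, kv_seq, d_model))  # per-query keys
v = jax.random.normal(kv, (q_seq, kv_seq, d_model))  # per-query values

# Each query i only attends to its own slice k[i], v[i].
scores = jnp.einsum("qd,qkd->qk", q, k) / jnp.sqrt(d_model)  # [q_seq, kv_seq]
weights = jax.nn.softmax(scores, axis=-1)
out = jnp.einsum("qk,qkd->qd", weights, v)                   # [q_seq, d_model]
```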
But it turns out that the selective attention mechanism does not bring any improvement over the standard attention mechanism (while being much slower). This is a bit disappointing, but I think it is still an interesting idea to explore, and a good `jax` tutorial (at least for me). It turns out this kind of implementation makes heavy use of `jax`'s `vmap` and `jit` functionalities.
Let's recall that the usual attention can be written as follows:

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

Where

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V$$

Where $X$ is the input sequence of shape `[seq, d_model]`.
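Written in the same notation, here is my reading of the selective variant: each query $q_i$ gets its own keys $K_i$ and values $V_i$, so

$$\operatorname{Attention}(q_i) = \operatorname{softmax}\!\left(\frac{q_i K_i^\top}{\sqrt{d_k}}\right) V_i \qquad \text{with} \quad K_i, V_i \in \mathbb{R}^{\text{kv\_seq} \times d_\text{model}}$$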
Now, one important detail is the way the keys and values are generated. I decided to go the simple route of concatenating the queries with the keys and values and passing them through a linear layer. That way, I keep the parameter count close to that of the standard attention mechanism.
To implement this, I notably have to build every concatenated (query, key/value) pair. This is done using `jax.vmap` as follows (`./src/model/mha.py`):
```python
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


def cross_product_matching(
    query: Float[Array, "q_seq q_size"], other: Float[Array, "o_seq o_size"]
) -> Float[Array, "q_seq o_seq q_size+o_size"]:
    """Concatenate every query element to every other element."""
    cat = lambda v1, v2: jnp.concatenate((v1, v2), axis=0)
    cat = jax.vmap(cat, in_axes=(None, 0))  # for one query, concat with every `other` element
    cat = jax.vmap(cat, in_axes=(0, None))  # then map over every query element
    return cat(query, other)
```
I'm putting it here because I think it's a good example of how to use `vmap` to build complex operations. The output of this function can then be passed through a linear layer to produce the keys and values for each query.
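As a hypothetical illustration (the weight name is mine, not the repository's), that projection could look like this:

```python
import jax
import jax.numpy as jnp

d_model, q_seq, kv_seq = 32, 6, 6
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

queries = jax.random.normal(k1, (q_seq, d_model))
tokens = jax.random.normal(k2, (kv_seq, d_model))

pairs = cross_product_matching(queries, tokens)        # [q_seq, kv_seq, 2*d_model]
w_key = jax.random.normal(k3, (2 * d_model, d_model))  # linear projection weights
keys = pairs @ w_key                                   # [q_seq, kv_seq, d_model]
```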
The training dataset is the classic Shakespeare plays. No complex tokenization scheme is used: every character is mapped to its own integer. This is not efficient, but it stays simple, which is the goal of this repository.
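A character-level mapping of this kind fits in a few lines; this is an illustrative sketch (the file name is hypothetical, not the repository's loader):

```python
text = open("shakespeare.txt").read()  # hypothetical file name

vocab = sorted(set(text))
char_to_id = {c: i for i, c in enumerate(vocab)}
id_to_char = {i: c for c, i in char_to_id.items()}

encode = lambda s: [char_to_id[c] for c in s]
decode = lambda ids: "".join(id_to_char[i] for i in ids)

tokens = encode(text)  # one integer per character
```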
RoPE is used as positional encoding.
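For reference, a minimal rotate-half formulation of RoPE in `jax` looks like this (one common variant; the repository's implementation may differ in layout or base frequency):

```python
import jax.numpy as jnp


def rope(x, positions, base: float = 10_000.0):
    """Rotate x of shape [seq, d] (d even), e.g. rope(x, jnp.arange(x.shape[0]))."""
    half = x.shape[-1] // 2
    freqs = base ** (-jnp.arange(half) / half)    # [half]
    angles = positions[:, None] * freqs[None, :]  # [seq, half]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```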
The shapes are checked using `beartype` and `jaxtyping`.
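In case it helps, this is a typical way such checks are wired up (illustrative only; the repository's decorators may be organised differently):

```python
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype)
def project(
    x: Float[Array, "seq d_model"], w: Float[Array, "d_model d_out"]
) -> Float[Array, "seq d_out"]:
    # At runtime, the decorator checks dtypes and that the named dimensions
    # (seq, d_model, d_out) are consistent across arguments and the output.
    return x @ w
```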
I have trained two models with the same hyperparameters (which you can find in `./configs/default.yaml`); the `equinox` implementation is included as a reference. The main difference between the two runs is the training time, selective attention being much slower than the standard attention mechanism. The results are as follows:
| | standard | selective | equinox |
|---|---|---|---|
| loss | 1.475 | 1.448 | 1.632 |
| top-1 accuracy | 0.543 | 0.554 | 0.504 |
| training time | 3h45 | 14h50 | 3h30 |
| parameters | 191,000 | 207,000 | 224,000 |
Selective attention is slightly better, but this may be explained solely by its higher parameter count rather than by the attention mechanism itself. Also, note that my `equinox` version does not use any positional encoding, which is why it performs so poorly.
To validate my own standard attention, I compared a run of my implementation without RoPE against a run of the `equinox` implementation (also without RoPE). The training runs are similar:
| | standard | equinox |
|---|---|---|
| loss | 1.645 | 1.632 |
| top-1 accuracy | 0.499 | 0.504 |
| training time | 3h20 | 3h30 |
| parameters | 191,000 | 224,000 |
All training curves:
With `nix` (flakes):

```sh
nix develop
python3 -m venv .venv
source .venv/bin/activate
pdm sync
```

Otherwise:

```sh
python3 -m venv .venv
source .venv/bin/activate
pip install pdm
pdm sync
```
Note that this will install `jax` for Nvidia GPUs, with CUDA installed through pip. If you want to use another configuration of `jax`, you can find the installation instructions here.
All dependencies are listed in the `pyproject.toml`. I use `pdm` as a package manager, but you can use whichever you want. I used Python 3.12, but this code should run fine with any Python 3.11 or above.
If you want to log the experiments, make sure you have a wandb account. The repo uses hydra. The basic training (using the default hyperparameters) can be run with:
```sh
python3 main.py mode=offline
```
Most of the options are exposed in the `./configs/default.yaml` configuration file, which is read by hydra. You can modify this file or pass the options directly on the command line.
Specify whether you want to log your experiments with `mode=online` or `mode=offline`. You can choose the attention mechanism with `model.mha_type={standard,selective,equinox}`. The `equinox` version is used as a reference in this repository, but note that it does not use any positional encoding scheme. For `standard` and `selective`, you can specify whether or not to use `rope`.
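For example (this only uses the options named above; the exact key controlling RoPE is in `./configs/default.yaml`):

```sh
python3 main.py mode=online model.mha_type=selective
```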