
Picodo: fast Transformer decoder training in JAX/NNX

  • Picodo has <200 SLOC
  • can run on GPUs, TPUs, Google Colab, or even locally on a Mac
  • achieves 64% MFU on TPU v2-8 when training GPT2-small (124M)
  • supports Fully Sharded Data Parallel (FSDP) training (see the sketch after this list)
  • uses the new Flax NNX API
  • uses Hydra for experiment management
  • uses Weights & Biases for experiment tracking
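
As a rough illustration of the last few points, here is a minimal sketch of FSDP-style parameter sharding with JAX and Flax NNX, following the pattern from the Flax NNX sharding guide. This is not Picodo's actual model code: the module, axis names, and sizes are illustrative, and the sharded dimensions must be divisible by the device count.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from flax import nnx

# 1-D device mesh: every parameter is sharded across all devices ("fsdp" axis).
mesh = Mesh(jax.devices(), axis_names=("fsdp",))

class MLP(nnx.Module):
    def __init__(self, d_model: int, d_ff: int, rngs: nnx.Rngs):
        # Annotate each kernel with the mesh axis one of its dimensions is sharded over.
        init = nnx.initializers.lecun_normal()
        self.up = nnx.Linear(
            d_model, d_ff,
            kernel_init=nnx.with_partitioning(init, (None, "fsdp")), rngs=rngs)
        self.down = nnx.Linear(
            d_ff, d_model,
            kernel_init=nnx.with_partitioning(init, ("fsdp", None)), rngs=rngs)

    def __call__(self, x):
        return self.down(jax.nn.gelu(self.up(x)))

@nnx.jit
def create_sharded_model():
    model = MLP(512, 2048, rngs=nnx.Rngs(0))
    state = nnx.state(model)
    specs = nnx.get_partition_spec(state)               # read the annotations above
    state = jax.lax.with_sharding_constraint(state, specs)
    nnx.update(model, state)                            # params now live sharded on the mesh
    return model

with mesh:
    model = create_sharded_model()
    y = model(jnp.ones((8, 512)))                       # forward pass with sharded params
```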

Training


Picodo requires a pretokenized dataset for training, following the same format as nanoGPT. This speeds up training and simplifies the codebase. To get started, I prepared a pretokenized sample of 2.5B tokens from fineweb-edu here: train.bin, valid.bin.
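
For reference, here is a minimal sketch of reading such a file, assuming the nanoGPT convention of a flat binary array of uint16 token IDs. The batch-sampling logic is illustrative, not Picodo's actual data loader.

```python
import numpy as np

def sample_batch(path: str, batch_size: int, seq_len: int, rng: np.random.Generator):
    # memory-map the token file so we never load the whole dataset into RAM
    tokens = np.memmap(path, dtype=np.uint16, mode="r")
    starts = rng.integers(0, len(tokens) - seq_len - 1, size=batch_size)
    x = np.stack([tokens[i : i + seq_len] for i in starts]).astype(np.int32)
    y = np.stack([tokens[i + 1 : i + 1 + seq_len] for i in starts]).astype(np.int32)
    return x, y  # inputs and next-token targets

x, y = sample_batch("train.bin", batch_size=8, seq_len=1024, rng=np.random.default_rng(0))
```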

The simplest way to use this codebase is by using the provided Colab notebook, which automatically installs requirements, downloads the dataset, and starts training a model.

To train a model from the command line, set the config name and any overrides:

python main.py -cn colab opt.peak_lr=0.004

You can also run main.py directly, which uses the local.yaml config by default (designed for local development).
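
For context, this is what a Hydra entry point like main.py typically looks like. It is a hypothetical sketch, assuming configs live in a configs/ directory with local.yaml as the default; Picodo's actual config layout and field names may differ.

```python
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="configs", config_name="local", version_base=None)
def main(cfg: DictConfig):
    # cfg is composed from local.yaml (or the config chosen with -cn) plus any CLI overrides
    print(OmegaConf.to_yaml(cfg))
    # train(cfg) ...

if __name__ == "__main__":
    main()
```

With this setup, `-cn colab` swaps in a different config file and `opt.peak_lr=0.004` overrides a single field, which is how the command above composes its run configuration.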

Inspiration

This repository started as a fork of deepmind/NanoDO, but it no longer shares any lines of code with it. Some notable changes:

  • NanoDO has ~1800 SLOC while Picodo only has ~200 SLOC
  • Picodo doesn't rely on grain for data loading, so it can run locally on a Mac
  • Picodo uses the new Flax NNX API
  • Picodo uses Hydra and Weights & Biases instead of Google's ConfigDict / TensorBoard

About

FSDP Transformer in JAX/NNX
