LAMB16 is an optimizer proposed by Chunqing "Jyn" Shan, based on the LAMB[1] optimizer. It keeps the optimizer state in float8 while weights and backpropagated gradients stay in float32, which significantly reduces memory size and bandwidth requirements. LAMB16 also adapts the trust ratio to the learning rate, which rescales the trust ratio to a reasonable range when the learning rate is not typical (e.g., 0.01).
These modifications make LAMB16 more stable and let it converge faster than both Adam and LAMB, reaching the lowest final loss of the three in the MNIST experiments below. By requiring only 16 bits of optimizer state for each parameter (hence the name LAMB16), it uses 1/4 the optimizer-state memory and bandwidth of Adam or LAMB. With optimized kernels (e.g., performing float decompression in the kernel and keeping FP32 in L1/shared memory), it can run even faster.
LAMB16 works transparently like an FP32 optimizer with 1/4 the optimizer-state memory footprint, without the need for AMP or changes to the ML workflow. It stores per-element adaptive learning rates, avoiding the side effects of other memory-aware optimizers (e.g., Adafactor). Like the LAMB optimizer, it also enables training with much larger batch sizes.
Without data augmentation, using the original 60,000 MNIST images, LAMB16 trains a naive 2-layer CNN to 99.2% test accuracy in 10 epochs at a batch size of 1024, and to 99.3% test accuracy in 5 epochs at a batch size of 128.
LAMB16 calculates the per-layer norm of the first moment estimate (m) and the second moment estimate (v) in addition to the norm of the weights and adam_delta. The m_norm and v_norm are stored as float32 scalar values in the optimizer's state. The normalized m and v are then calculated by dividing m and v by their respective norms and stored in float8_e4m3fn and float8_e5m2 formats in the optimizer's state. This results in a total state size of 50% of the weight size, with 16 bits for each parameter (compared to 64 bits per parameter or 200% of the weight size for Adam or LAMB).
LAMB16 enables training with 1/4 the memory requirement (and 1/4 the bandwidth overhead) for optimizer state compared to Adam or LAMB. It also allows for training with much larger batch sizes, a benefit inherited from the LAMB optimizer. With the same hyperparameters, LAMB16 converges faster than Adam and the original LAMB. Considering its significantly reduced memory bandwidth requirement, it should be much faster than both in practice.
This is a proof-of-concept implementation based on cybertronai's pytorch-lamb. I wrote a new Optimizer for LAMB16 and reused their test_lamb.py (CLI) and their implementation of the original LAMB, so we can compare the performance of Adam, LAMB, and LAMB16.
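The 1/4 figure follows from a per-parameter bit count, which the short calculation below works through. The model size and tensor count are made-up numbers for illustration, and `state_bytes` is a hypothetical helper. The float32 norm scalars are included to show that they are negligible.

```python
def state_bytes(num_params: int, bits_per_param: int,
                num_tensors: int = 0, scalars_per_tensor: int = 0) -> int:
    """Size in bytes of per-parameter data plus optional float32 scalars."""
    payload = num_params * bits_per_param // 8
    scalars = num_tensors * scalars_per_tensor * 4  # float32 = 4 bytes each
    return payload + scalars


num_params = 1_000_000   # hypothetical model size
num_tensors = 4          # hypothetical number of parameter tensors

weights = state_bytes(num_params, 32)                  # float32 weights
adam = state_bytes(num_params, 64)                     # float32 m and v
lamb16 = state_bytes(num_params, 16, num_tensors, 2)   # float8 m, v + 2 norms

print(f"Adam/LAMB state: {adam / weights:.0%} of weight size")
print(f"LAMB16 state:    {lamb16 / weights:.0%} of weight size")
print(f"LAMB16 / Adam:   {lamb16 / adam:.4f}")
```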
The following results compare the Adam, LAMB, and LAMB16 optimizers when training on the MNIST dataset with a batch size of 1024, a learning rate of 0.02, and a weight decay of 0.01.
The red line represents Adam, the green line represents LAMB, and the blue line represents LAMB16.
The following results compare the Adam, LAMB, and LAMB16 optimizers when training on the MNIST dataset with a batch size of 128, a learning rate of 0.01, and a weight decay of 0.01.
The red line represents Adam, the green line represents LAMB, and the blue line represents LAMB16.
I was not aware of the existence of 4-bit/8-bit AdamW[3] when developing LAMB16. The authors did some very interesting numerical analysis. Still, LAMB16 outperforms 4-bit/8-bit AdamW at large batch sizes thanks to its per-layer adaptive trust ratio and its better moment resolution.
Another advantage of LAMB16 over low-bit AdamW[3] is that 4-bit/8-bit AdamW uses a dynamic exponent mapping quantization strategy, which involves mapping and de-mapping values to and from INT4/INT8. That needs a lot more memory bandwidth than LAMB16, which uses a float8<->float32 conversion that can be done per element, with no mapping or dictionary building required.
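To illustrate why the float8 conversion needs no lookup table, the sketch below decodes a float8_e5m2 byte using only bit shifts. It relies on float8_e5m2 having the same sign and exponent layout as IEEE half precision, so one e5m2 byte is the top byte of a float16. The function names are hypothetical, and the encoder truncates rather than rounds to nearest, purely to keep the example short. Real conversions are done by the framework or hardware.

```python
import struct


def e5m2_to_float(byte: int) -> float:
    """Decode one float8_e5m2 byte by widening it to an IEEE half float."""
    return struct.unpack("<e", struct.pack("<H", (byte & 0xFF) << 8))[0]


def float_to_e5m2_truncate(x: float) -> int:
    """Encode by rounding to half precision, then dropping the low 8 bits.

    Truncation is used for brevity; production converters round to nearest.
    """
    return struct.unpack("<H", struct.pack("<e", x))[0] >> 8


for value in (1.0, 1.3, -2.0):
    code = float_to_e5m2_truncate(value)
    print(f"{value:>5} -> 0x{code:02X} -> {e5m2_to_float(code)}")
```

Each element is converted independently of every other element, so this kind of conversion can run inside the optimizer kernel without a separate quantization pass.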
- LAMB: Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
- https://github.com/cybertronai/pytorch-lamb, the original LAMB optimizer implementation.
- Memory Efficient Optimizers with 4-bit States: https://arxiv.org/pdf/2309.01507