
LAMB16

LAMB16 is an optimizer proposed by Chunqing "Jyn" Shan, based on the LAMB[1] optimizer. It enables training with float8 optimizer state while maintaining float32 precision for weights and backpropagation gradients. This significantly reduces memory size and bandwidth requirements. Additionally, LAMB16 incorporates adaptive trust based on learning rate, which helps rescale the trust ratio to a reasonable range when the learning rate is not typical (e.g., 0.01).

These modifications make LAMB16 more stable and capable of converging faster than both Adam and LAMB, ultimately reaching the smallest final loss. By requiring only 16 bits for each parameter (hence the name LAMB16), it uses 1/4 the memory size and bandwidth of Adam or LAMB. With optimized kernels (e.g., performing float decompression in the kernel and keeping FP32 in L1/shared memory), it can be even faster.

LAMB16 works transparently like an FP32 optimizer with a 1/4 memory footprint, without the need for AMP or changes to the ML workflow. It stores per-element adaptive learning rates, avoiding the side effects of other memory-aware optimizers (e.g., Adafactor). It also enables much larger batch size training, similar to the LAMB optimizer.

Without data augmentation, using the original 60,000 MNIST training images, LAMB16 trains a naive 2-layer CNN to 99.2% test accuracy in 10 epochs at batch size 1024, and to 99.3% test accuracy in 5 epochs at batch size 128.

Algorithm

LAMB16 calculates the per-layer norm of the first moment estimate (m) and the second moment estimate (v) in addition to the norm of the weights and adam_delta. The m_norm and v_norm are stored as float32 scalar values in the optimizer's state. The normalized m and v are then calculated by dividing m and v by their respective norms and stored in float8_e4m3fn and float8_e5m2 formats in the optimizer's state. This results in a total state size of 50% of the weight size, with 16 bits for each parameter (compared to 64 bits per parameter or 200% of the weight size for Adam or LAMB).
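The normalize-then-quantize step can be sketched as follows, shown here only for the first moment m and using a pure-Python stand-in for float8_e4m3 rounding (an illustrative approximation; a real implementation would use native float8 tensor types):

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def quantize_e4m3(x: float) -> float:
    """Round x to the nearest representable float8_e4m3 value
    (sign bit, 4 exponent bits with bias 7, 3 mantissa bits)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    e = max(math.floor(math.log2(a)), -6)  # clamp to the minimum normal exponent
    frac = a / 2.0 ** e                    # significand; < 1.0 means subnormal
    frac = round(frac * 8) / 8             # keep 3 mantissa bits
    return sign * frac * 2.0 ** e

def pack_moment(m: list[float]) -> tuple[float, list[float]]:
    """Store the per-layer norm as an fp32 scalar and the normalized
    moment elements in (simulated) float8, as described above."""
    norm = math.sqrt(sum(v * v for v in m))
    quantized = [quantize_e4m3(v / norm) for v in m]
    return norm, quantized

def unpack_moment(norm: float, quantized: list[float]) -> list[float]:
    # De-quantization is a per-element multiply by the stored norm.
    return [norm * q for q in quantized]

norm, q = pack_moment([0.12, -0.03, 0.7, 0.005])
restored = unpack_moment(norm, q)
# Normal-range elements round-trip within the 3-bit mantissa's relative
# error; elements far below the norm fall into the subnormal range and
# lose more precision.
```

The same pattern would apply to the second moment v, with e5m2 rounding (2 mantissa bits, wider exponent range) in place of e4m3.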

LAMB16 enables training with 1/4 the memory requirement (and 1/4 the bandwidth overhead) for optimizer state compared to Adam or LAMB. It also allows for training with much larger batch sizes, a benefit inherited from the LAMB optimizer. With the same hyperparameters, LAMB16 converges faster than Adam and the original LAMB. Considering its significantly reduced memory bandwidth requirement, it should be much faster than both in practice.
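The memory claim can be checked with back-of-envelope arithmetic:

```python
# Optimizer-state bytes per parameter.
# Adam/LAMB: two fp32 moments (m, v) -> 2 * 4 bytes.
adam_state_bytes = 2 * 4
# LAMB16: two float8 moments -> 2 * 1 byte
# (plus two fp32 norm scalars per layer, negligible for large layers).
lamb16_state_bytes = 2 * 1

print(adam_state_bytes // lamb16_state_bytes)  # prints 4: a 4x reduction
```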

Implementation

This is a proof-of-concept implementation based on cybertronai's pytorch-lamb[2]. I wrote a new optimizer class for LAMB16 and reused test_lamb.py (the CLI) and the original LAMB implementation, so the performance of Adam, LAMB, and LAMB16 can be compared directly.

Results

The following results demonstrate the performance of Adam, LAMB, and LAMB16 optimizers when training the MNIST dataset with a batch size of 1024, a learning rate of 0.02, and a weight decay of 0.01.

Batch 1024 LR 0.02

The red line represents Adam, the green line represents LAMB, and the blue line represents LAMB16.

The following results demonstrate the performance of Adam, LAMB, and LAMB16 optimizers when training the MNIST dataset with a batch size of 128, a learning rate of 0.01, and a weight decay of 0.01.

Batch 128 LR 0.01

The red line represents Adam, the green line represents LAMB, and the blue line represents LAMB16.

Comparison

I was not aware of the existing work on 4-bit/8-bit AdamW[3] when developing LAMB16; its authors did some very interesting numerical analysis. Still, LAMB16 outperforms 4-bit/8-bit AdamW at large batch sizes, thanks to its per-layer adaptive trust ratio and its better moment resolution.

Another advantage of LAMB16 over low-bit AdamW[3] is the decode path: 4-bit/8-bit AdamW uses a dynamic-exponent mapping quantization strategy, which maps and de-maps values to INT4/INT8 through a per-tensor table. This needs considerably more memory bandwidth than LAMB16's float8 <-> float32 conversion, which is a per-element operation with no mapping table to build or fetch.
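To make the contrast concrete, here is a sketch of the two decode paths. `decode_mapped` is a generic dynamic-mapping decoder for illustration, not the exact scheme of [3]:

```python
# float8_e4m3 decode: a stateless, per-element function of the byte --
# no per-tensor table to build or fetch, so it can run entirely in
# registers inside a fused kernel (NaN handling omitted for brevity).
def decode_e4m3(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF        # 4 exponent bits, bias 7
    man = byte & 0x7               # 3 mantissa bits
    if exp == 0:                   # subnormal range
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

# Dynamic-mapping decode (INT8-style, illustrative): every element is a
# lookup into a per-tensor codebook that must first be built and kept
# in memory -- the extra traffic LAMB16 avoids.
def decode_mapped(indices: list[int], codebook: list[float]) -> list[float]:
    return [codebook[i] for i in indices]

x = decode_e4m3(0x38)  # exponent bits = 7 (the bias), mantissa = 0 -> 1.0
```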

Reference

  1. You et al., "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" (the LAMB paper).
  2. https://github.com/cybertronai/pytorch-lamb, the original LAMB optimizer implementation.
  3. https://arxiv.org/pdf/2309.01507, the low-bit (4-bit/8-bit) AdamW paper.
