Amos and JEstimator

This is not an officially supported Google product.

This is the source code for the paper "Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale".

It implements Amos, an optimizer compatible with the optax library, and JEstimator, a lightweight library with a tf.Estimator-like interface for managing T5X-compatible checkpoints in JAX machine-learning programs. We use JEstimator to run the experiments in the paper.
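"Compatible with the optax library" means the optimizer follows optax's gradient-transformation interface. That interface is a pair of pure functions: `init(params)` returns an optimizer state, and `update(grads, state, params)` returns `(updates, new_state)`. The updates are then added to the parameters.

The sketch below illustrates that interface without depending on optax, under two assumptions:

- `GradientTransformation` is a local stand-in that mirrors the shape of `optax.GradientTransformation`.
- `apply_updates` plays the role of `optax.apply_updates`.

The optimizer shown is plain momentum SGD, not Amos. Amos's own constructor and arguments are not reproduced here.

```python
from typing import Any, Callable, NamedTuple


class GradientTransformation(NamedTuple):
    """Local stand-in with the same shape as optax.GradientTransformation."""
    init: Callable[[Any], Any]
    update: Callable[..., Any]


def momentum_sgd(learning_rate: float, momentum: float = 0.9) -> GradientTransformation:
    """Illustrative optimizer (NOT Amos) over dicts of floats."""

    def init(params):
        # One zero-initialized momentum buffer per parameter.
        return {name: 0.0 for name in params}

    def update(grads, state, params=None):
        # Heavy-ball momentum: v <- momentum * v + g; update = -lr * v.
        new_state = {name: momentum * state[name] + grads[name] for name in grads}
        updates = {name: -learning_rate * new_state[name] for name in grads}
        return updates, new_state

    return GradientTransformation(init, update)


def apply_updates(params, updates):
    # Same role as optax.apply_updates: add the updates to the parameters.
    return {name: params[name] + updates[name] for name in params}


# Usage: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
opt = momentum_sgd(learning_rate=0.1)
params = {"w": 0.0}
state = opt.init(params)
for _ in range(300):
    grads = {"w": 2.0 * (params["w"] - 3.0)}
    updates, state = opt.update(grads, state, params)
    params = apply_updates(params, updates)
print(params)
```

Because the optimizer is just an `(init, update)` pair, any optimizer with this shape can be swapped into the same training loop. This is presumably what makes an optax-compatible optimizer such as Amos a drop-in choice.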

Installation and test

To run the Amos test, we need to install Abseil, JAX, Flax and Chex:

pip install --upgrade pip
pip install absl-py  # Install Abseil
pip install --upgrade "jax[cpu]"  # Install JAX (CPU-only build)
pip install flax  # Install Flax
pip install chex  # Install Chex
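Before running the test, it can help to confirm that the packages are importable from the active interpreter. This is a minimal sketch using only the standard library.

It assumes each pip package maps to an import name as follows: `absl-py` to `absl`, and `jax`, `flax`, `chex` to themselves. `missing_modules` is a hypothetical helper, not part of JEstimator.

```python
import importlib.util

# Import names for the packages installed above (pip name -> import name):
# absl-py -> absl, jax -> jax, flax -> flax, chex -> chex.
REQUIRED_MODULES = ("absl", "jax", "flax", "chex")


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All test dependencies found.")
```

`find_spec` locates a module without importing it, so the check stays cheap even for large packages like JAX.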

Then, check out the repository and run the test from its root directory:

git clone --branch=main https://github.com/google-research/jestimator
cd jestimator
PYTHONPATH=. python3 jestimator/amos_test.py
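The `PYTHONPATH=.` prefix puts the repository root on Python's import path for that one command. Presumably this lets the test's imports of the `jestimator` package resolve to the checked-out sources without installing anything.

Below is a sketch of the same invocation driven from Python, assuming the repository root is passed as `repo_root`. The helper names are hypothetical. `sys.executable` is used instead of a literal `python3` so the current interpreter runs the test.

```python
import os
import subprocess
import sys


def amos_test_command(repo_root="."):
    """Build argv and environment equivalent to
    `PYTHONPATH=. python3 jestimator/amos_test.py` run from repo_root."""
    root = os.path.abspath(repo_root)
    # Like the shell prefix, this replaces (not extends) PYTHONPATH for the child.
    env = dict(os.environ, PYTHONPATH=root)
    argv = [sys.executable, os.path.join(root, "jestimator", "amos_test.py")]
    return argv, env


def run_amos_test(repo_root="."):
    """Run the test and return its exit code (0 means success)."""
    argv, env = amos_test_command(repo_root)
    # check=False: report failures through the return code instead of raising.
    return subprocess.run(argv, env=env, cwd=os.path.abspath(repo_root), check=False).returncode
```

Using an absolute path for `PYTHONPATH` avoids surprises when the child process runs with a different working directory.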

Run models with JEstimator

JEstimator's data pipeline is built on TensorFlow, although in principle it could be replaced with a PyTorch DataLoader. We also need the T5X and FlaxFormer libraries.

pip install tensorflow-cpu  # Install TensorFlow

git clone --branch=main https://github.com/google-research/t5x
cd t5x  # Install T5X with TPU support, so we can pre-train on Google Cloud:
python3 -m pip install -e '.[tpu]' -f \
  https://storage.googleapis.com/jax-releases/libtpu_releases.html
cd ..

git clone --branch=main https://github.com/google/flaxformer
cd flaxformer  # Install FlaxFormer:
pip3 install '.[testing]'
cd ..

Then, we can test a toy linear regression model with JEstimator:

JAX_PLATFORMS=cpu PYTHONPATH=. python3 jestimator/models/linear_regression/linear_regression_test.py
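JEstimator's linear regression model code is not reproduced here. As a hypothetical illustration of the kind of problem such a toy model solves, this pure-Python sketch fits `y ≈ w * x + b` by full-batch gradient descent on mean squared error.

The data come from a made-up line, y = 2x + 1, with no noise. `fit_linear_regression` is an illustrative name, not a JEstimator API.

```python
def fit_linear_regression(xs, ys, learning_rate=0.1, steps=2000):
    """Fit y ~ w * x + b by full-batch gradient descent on mean squared error."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        # Residuals of the current fit.
        errs = [w * x + b - y for x, y in zip(xs, ys)]
        # Gradients of mean((w * x + b - y)^2) with respect to w and b.
        grad_w = 2.0 * sum(e * x for e, x in zip(errs, xs)) / n
        grad_b = 2.0 * sum(errs) / n
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


# Noise-free synthetic data from y = 2x + 1, with x in [-1, 1].
xs = [i / 10 for i in range(-10, 11)]
ys = [2.0 * x + 1.0 for x in xs]
w, b = fit_linear_regression(xs, ys)
print(w, b)
```

With noise-free data the recovered `w` and `b` approach 2 and 1. This makes linear regression a convenient smoke test for a training loop.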

We can also train a single-layer LSTM language model on the Penn Treebank (PTB) dataset:

JAX_PLATFORMS=cpu PYTHONPATH=. python3 jestimator/estimator.py \
  --module_imp="jestimator.models.lstm.lm" \
  --module_config="jestimator/models/lstm/lm.py" \
  --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
  --train_pattern="jestimator/models/lstm/ptb/ptb.train.txt" \
  --model_dir="$HOME/models/ptb_lstm" \
  --train_batch_size=64 \
  --train_consecutive=113 \
  --train_shuffle_buf=4096 \
  --max_train_steps=200000 \
  --check_every_steps=1000 \
  --max_ckpt=20 \
  --module_config.opt_config.optimizer="amos" \
  --module_config.opt_config.learning_rate=0.01 \
  --module_config.opt_config.beta=0.98 \
  --module_config.opt_config.momentum=0.0 \
  --logtostderr
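The training command mixes environment variables, general estimator flags and `--module_config.opt_config.*` overrides that select Amos and its hyperparameters. Collecting the flags in one Python dict can make it easier to vary a single setting, such as the learning rate, across runs.

This sketch copies the flag values verbatim from the command above. A few details are assumptions:

- `build_train_command` is a hypothetical helper.
- `os.path.expanduser("~/...")` stands in for the shell's `$HOME/...`.
- The command is assumed to be launched from the repository root.

```python
import os
import shlex

# Flag values copied from the training command above.
TRAIN_FLAGS = {
    "module_imp": "jestimator.models.lstm.lm",
    "module_config": "jestimator/models/lstm/lm.py",
    "module_config.vocab_path": "jestimator/models/lstm/ptb/vocab.txt",
    "train_pattern": "jestimator/models/lstm/ptb/ptb.train.txt",
    "model_dir": os.path.expanduser("~/models/ptb_lstm"),
    "train_batch_size": 64,
    "train_consecutive": 113,
    "train_shuffle_buf": 4096,
    "max_train_steps": 200000,
    "check_every_steps": 1000,
    "max_ckpt": 20,
    "module_config.opt_config.optimizer": "amos",
    "module_config.opt_config.learning_rate": 0.01,
    "module_config.opt_config.beta": 0.98,
    "module_config.opt_config.momentum": 0.0,
}


def build_train_command(flags=TRAIN_FLAGS):
    """Return (argv, env) for the training run; pass both to subprocess.run."""
    argv = ["python3", "jestimator/estimator.py"]
    argv += [f"--{name}={value}" for name, value in flags.items()]
    argv.append("--logtostderr")
    # Equivalent of the `JAX_PLATFORMS=cpu PYTHONPATH=.` shell prefix.
    env = dict(os.environ, JAX_PLATFORMS="cpu", PYTHONPATH=".")
    return argv, env


argv, env = build_train_command()
print(shlex.join(argv))
```

To try a different learning rate, copy the dict with one value changed, for example `dict(TRAIN_FLAGS, **{"module_config.opt_config.learning_rate": 0.005})`, and pass it to `build_train_command`.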

After training completes, we can evaluate the model on the validation set:

JAX_PLATFORMS=cpu PYTHONPATH=. python3 jestimator/estimator.py \
  --module_imp="jestimator.models.lstm.lm" \
  --module_config="jestimator/models/lstm/lm.py" \
  --module_config.vocab_path="jestimator/models/lstm/ptb/vocab.txt" \
  --eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt" \
  --model_dir="$HOME/models/ptb_lstm" \
  --eval_batch_size=1 \
  --logtostderr
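The evaluation command keeps the module flags and `--model_dir` from training. Presumably this is so the estimator loads the checkpoints that training wrote there. It then replaces all `train_*`, step and checkpoint flags with `--eval_pattern` and `--eval_batch_size`.

This sketch makes that split explicit. `SHARED_FLAGS` and `build_eval_argv` are hypothetical names, the flag values are copied from the commands above, and `os.path.expanduser` stands in for `$HOME`.

```python
import os

# Flags that appear in both the training and the evaluation commands above.
SHARED_FLAGS = {
    "module_imp": "jestimator.models.lstm.lm",
    "module_config": "jestimator/models/lstm/lm.py",
    "module_config.vocab_path": "jestimator/models/lstm/ptb/vocab.txt",
    "model_dir": os.path.expanduser("~/models/ptb_lstm"),
}


def build_eval_argv(shared=SHARED_FLAGS,
                    eval_pattern="jestimator/models/lstm/ptb/ptb.valid.txt",
                    eval_batch_size=1):
    """Return argv for evaluation; reuses model_dir so trained checkpoints are found."""
    flags = dict(shared)
    flags["eval_pattern"] = eval_pattern
    flags["eval_batch_size"] = eval_batch_size
    argv = ["python3", "jestimator/estimator.py"]
    argv += [f"--{name}={value}" for name, value in flags.items()]
    argv.append("--logtostderr")
    return argv


print(build_eval_argv())
```

Keeping the shared flags in one place guards against a common mistake: evaluating with a different `--model_dir` or module config than the one used for training.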

Contributors

tianran