
Automatic Gradient Descent

Jeremy Bernstein* · Chris Mingard* · Kevin Huang · Navid Azizan · Yisong Yue

Getting started

Install PyTorch and a GPU, and run:

python main.py

Command line arguments are:

--arch           # options: fcn, vgg, resnet18, resnet50
--dataset        # options: cifar10, cifar100, mnist, imagenet
--train_bs       # training batch size
--test_bs        # testing batch size
--epochs         # number of training epochs
--depth          # number of layers for fcn
--width          # hidden layer width for fcn
--distribute     # train over multiple gpus (for imagenet)
--gain           # experimental acceleration of training

No training hyperparameters are necessary. Optionally, you can try --gain 10.0, which we have found can accelerate training. Chris is maintaining a separate repository with more experimental features.
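For example, a typical invocation to train a ResNet-18 on CIFAR-10 might look like the following (the batch sizes and epoch count here are purely illustrative, not tuned defaults):

python main.py --arch resnet18 --dataset cifar10 --train_bs 128 --test_bs 128 --epochs 200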

Repository structure

.
├── architecture/           # network architectures
├── data/                   # datasets and preprocessing
├── latex/                  # source code for the paper
├── supercloud/             # mit supercloud run files
├── agd.py                  # automatic gradient descent
├── main.py                 # entrypoint to training

Description of the method

For the $k\text{th}$ weight matrix $W_k$ in $\mathbb{R}^{d_k \times d_{k-1}}$ and square or cross-entropy loss $\mathcal{L}$:

  • initial weights are drawn from the uniform measure over orthogonal matrices, and then scaled by $\sqrt{d_k / d_{k-1}}$.
  • weights are updated according to:
$$W_k \gets W_k - \frac{\eta}{L} \cdot \sqrt{\tfrac{d_k}{d_{k-1}}} \cdot \frac{ \nabla_{W_k} \mathcal{L}}{\Vert{ \nabla_{W_k}\mathcal{L}}\Vert _F}.$$

Here $L$ is the depth of the network, and the learning rate $\eta$ is set automatically via:

  • $G \gets \frac{1}{L} \sum_{k=1}^{L} \sqrt{\tfrac{d_k}{d_{k-1}}}\cdot \Vert\nabla_{W_k} \mathcal{L}\Vert_F$;
  • $\eta \gets \log\Big( \tfrac{1+\sqrt{1+4G}}{2}\Big)$.

This procedure is slightly modified for convolutional layers.
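For concreteness, here is a minimal PyTorch sketch of the fully-connected case. The helper names agd_init and agd_step are illustrative only; the actual implementation, including the convolutional modification, lives in agd.py.

import math
import torch

def agd_init(W: torch.Tensor) -> None:
    # Draw a d_k x d_{k-1} matrix from the uniform measure over
    # orthogonal matrices, then scale by sqrt(d_k / d_{k-1}).
    torch.nn.init.orthogonal_(W)
    with torch.no_grad():
        W *= math.sqrt(W.shape[0] / W.shape[1])

def agd_step(weights: list) -> None:
    # One AGD update on a list of weight matrices; assumes
    # loss.backward() has already populated each W.grad.
    L = len(weights)

    # Gradient summary G: depth-averaged, dimension-rescaled gradient norms.
    G = sum(math.sqrt(W.shape[0] / W.shape[1]) * W.grad.norm().item()
            for W in weights) / L

    # Automatic learning rate.
    eta = math.log((1 + math.sqrt(1 + 4 * G)) / 2)

    # Normalised update, layer by layer.
    with torch.no_grad():
        for W in weights:
            scale = math.sqrt(W.shape[0] / W.shape[1])
            W -= (eta / L) * scale * W.grad / W.grad.norm()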

Citation

If you find AGD helpful and you'd like to cite the paper, we'd appreciate it:

@article{agd-2023,
  author  = {Jeremy Bernstein and Chris Mingard and Kevin Huang and Navid Azizan and Yisong Yue},
  title   = {{A}utomatic {G}radient {D}escent: {D}eep {L}earning without {H}yperparameters},
  journal = {arXiv:2304.05187},
  year    = 2023
}

References

Our paper, Automatic Gradient Descent: Deep Learning without Hyperparameters, is available on arXiv (arXiv:2304.05187). The derivation of AGD is a refined version of the majorise-minimise analysis given in my PhD thesis, Optimisation & Generalisation in Networks of Neurons, and was worked out in close collaboration with Chris and Kevin. In turn, this develops the perturbation analysis from our earlier paper On the Distance between two Neural Networks and the Stability of Learning, with a couple of insights from Greg Yang and Edward Hu's Feature Learning in Infinite-Width Neural Networks thrown in for good measure.

Acknowledgements

Some architecture definitions were adapted from kuangliu/pytorch-cifar.

License

We are making AGD available under the CC BY-NC-SA 4.0 license.
