junjieyang97 / stocbio

Example code for paper "Bilevel Optimization: Nonasymptotic Analysis and Faster Algorithms"

License: MIT License

Python 100.00%
bilevel-optimization hyperparameter algorithm

stocbio's Introduction

Efficient bilevel optimizers: stocBiO, ITD-BiO, and FO-ITD-BiO.

Code for the ICML 2021 paper Bilevel Optimization: Nonasymptotic Analysis and Faster Algorithms by Kaiyi Ji, Junjie Yang, and Yingbin Liang from The Ohio State University.

stocBiO for hyperparameter optimization

Our hyperparameter optimization implementation is built on HyperTorch. We propose the stocBiO algorithm, which performs better than the other bilevel algorithms we compare against. Our code is tested with Python 3 and PyTorch 1.8.

Note that the hypergrad package is built on HyperTorch.

The experiments on the 20 Newsgroup and MNIST datasets are in l2reg_on_twentynews.py and mnist_exp.py, respectively.
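For context, the bilevel problem behind these experiments can be written in the standard form below. The notation is an assumption chosen for illustration and is not taken from this README: $x$ stands for the hyperparameters, $y$ for the model weights, $f$ for the validation (outer) loss, and $g$ for the training (inner) loss, with $g$ assumed strongly convex in $y$ so that the inner minimizer is unique.

```latex
\min_{x} \; \Phi(x) := f\bigl(x, y^*(x)\bigr)
\quad \text{s.t.} \quad
y^*(x) = \arg\min_{y} \; g(x, y)

% Hypergradient via the implicit function theorem (under the assumption above):
\nabla \Phi(x) = \nabla_x f\bigl(x, y^*(x)\bigr)
  - \nabla_x \nabla_y g\bigl(x, y^*(x)\bigr)
    \bigl[\nabla_y^2 g\bigl(x, y^*(x)\bigr)\bigr]^{-1}
    \nabla_y f\bigl(x, y^*(x)\bigr)
```

The Hessian inverse in the second term is the expensive part, which is why bilevel algorithms approximate the Hessian-inverse-vector product rather than forming the inverse explicitly.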

How to run our code

The main command-line arguments are described below.

Argument meanings

  • --alg: Different algorithms we support.
  • --hessian_q: The number of Hessian-vector products used to estimate the Hessian-inverse-vector product.
  • --training_size: The number of samples used in training.
  • --validation_size: The number of samples used for validation.
  • --batch_size: Batch size for training data.
  • --epochs: Outer epoch number for training.
  • --iterations or --T: Inner iteration number for training.
  • --eta: Hyperparameter $\eta$ for Hessian inverse approximation.
  • --noise_rate: The corruption rate for MNIST data.
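To make the roles of --eta and --hessian_q more concrete, here is a minimal deterministic sketch of a Neumann-series approximation of a Hessian-inverse-vector product, $H^{-1} g \approx \eta \sum_{q=0}^{Q} (I - \eta H)^q g$. It is an illustration only, not the repository's implementation: the function names are hypothetical, the mapping of `Q` to --hessian_q is an assumption, and a stochastic method would typically use a freshly sampled mini-batch Hessian for each product rather than one fixed matrix.

```python
def mat_vec(H, v):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]


def neumann_inverse_hvp(H, g, eta, Q):
    """Approximate H^{-1} g by eta * sum_{q=0}^{Q} (I - eta*H)^q g.

    Hypothetical helper: each loop iteration costs one Hessian-vector product.
    For a symmetric positive definite H the series converges when
    0 < eta < 2 / lambda_max(H).
    """
    term = list(g)   # current term (I - eta*H)^q g, starting at q = 0
    total = list(g)  # running sum of the terms
    for _ in range(Q):
        Hv = mat_vec(H, term)
        term = [t - eta * hv for t, hv in zip(term, Hv)]  # apply (I - eta*H)
        total = [s + t for s, t in zip(total, term)]
    return [eta * s for s in total]


# Small symmetric positive definite example (values chosen for illustration).
H = [[2.0, 0.5], [0.5, 1.0]]
g = [1.0, 1.0]
coarse = neumann_inverse_hvp(H, g, eta=0.4, Q=5)
fine = neumann_inverse_hvp(H, g, eta=0.4, Q=50)
```

A larger `Q` gives a more accurate estimate at the cost of more Hessian-vector products, while `eta` must be small enough for the series to converge.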

To replicate empirical results under different datasets in our paper, please run the following commands:

stocBiO on MNIST with noise rate p=0.1

python3 mnist_exp.py --alg stocBiO --batch_size 50 --noise_rate 0.1

stocBiO on MNIST with noise rate p=0.4

python3 mnist_exp.py --alg stocBiO --batch_size 50 --noise_rate 0.4

stocBiO on 20 Newsgroup

python3 l2reg_on_twentynews.py --alg stocBiO

AID-FP on MNIST with noise rate p=0.4

python3 mnist_exp.py --alg AID-FP --batch_size 50 --noise_rate 0.4

AID-FP on 20 Newsgroup

python3 l2reg_on_twentynews.py --alg AID-FP
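The commands above all pass the same small set of flags. As a rough sketch of how such flags could be declared with argparse, consider the following. The flag names come from the argument list in this README, but every default value and help string here is a placeholder assumption, and the actual scripts may declare their arguments differently.

```python
import argparse


def build_parser():
    """Hypothetical parser for the documented flags; defaults are placeholders."""
    parser = argparse.ArgumentParser(description="Sketch of the experiment flags")
    parser.add_argument("--alg", type=str, default="stocBiO",
                        help="algorithm name, e.g. stocBiO or AID-FP")
    parser.add_argument("--hessian_q", type=int, default=3,
                        help="number of Hessian-vector products in the estimate")
    parser.add_argument("--training_size", type=int, default=20000)
    parser.add_argument("--validation_size", type=int, default=5000)
    parser.add_argument("--batch_size", type=int, default=50)
    parser.add_argument("--epochs", type=int, default=50,
                        help="number of outer epochs")
    # Two spellings of the same option; both are stored under args.iterations.
    parser.add_argument("--iterations", "--T", dest="iterations", type=int,
                        default=200, help="number of inner iterations")
    parser.add_argument("--eta", type=float, default=0.5,
                        help="step size for the Hessian inverse approximation")
    parser.add_argument("--noise_rate", type=float, default=0.0,
                        help="label corruption rate (MNIST experiments)")
    return parser


# Parse the flags from one of the commands above instead of sys.argv.
args = build_parser().parse_args(
    ["--alg", "stocBiO", "--batch_size", "50", "--noise_rate", "0.1"]
)
```

Flags that a command omits simply fall back to their defaults, which is why the 20 Newsgroup commands only need to pass --alg.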

ITD-BiO and FO-ITD-BiO for meta-learning

Our meta-learning part is built on learn2learn, where we implement the bilevel optimizer ITD-BiO and show that it converges faster than MAML and ANIL. Note that we also implement first-order ITD-BiO (FO-ITD-BiO) without computing the derivative of the inner-loop output with respect to feature parameters, i.e., removing all Jacobian and Hessian-vector calculations. It turns out that FO-ITD-BiO is even faster without sacrificing overall prediction accuracy.
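The difference between the two variants can be illustrated on a scalar toy problem. The sketch below is not the repository's code (which relies on learn2learn and PyTorch): it assumes a hypothetical model with prediction `w * phi * x`, where `phi` plays the role of the feature parameter and `w` the task-specific head adapted in the inner loop, and it derives all gradients by hand. The ITD-style hypergradient differentiates through the inner loop, whereas the first-order variant treats the inner-loop output as a constant with respect to `phi`.

```python
def inner_loop(phi, x, y, w0, alpha, T):
    """Run T gradient steps on w for the loss 0.5*(w*phi*x - y)**2.

    Also tracks dw/dphi, the derivative of the inner-loop output with
    respect to the feature parameter phi.
    """
    w, dw_dphi = w0, 0.0
    for _ in range(T):
        residual = w * phi * x - y
        grad_w = residual * phi * x
        # Total derivative of grad_w w.r.t. phi, including w's dependence on phi.
        dgrad_dphi = (dw_dphi * phi * x + w * x) * phi * x + residual * x
        w, dw_dphi = w - alpha * grad_w, dw_dphi - alpha * dgrad_dphi
    return w, dw_dphi


def outer_loss(phi, w, x, y):
    """Validation loss for the same toy model."""
    return 0.5 * (w * phi * x - y) ** 2


def hypergradients(phi, train, val, w0=0.0, alpha=0.1, T=5):
    """Return (ITD-style gradient, first-order gradient) w.r.t. phi."""
    w_T, dw_dphi = inner_loop(phi, train[0], train[1], w0, alpha, T)
    residual = w_T * phi * val[0] - val[1]
    direct = residual * w_T * val[0]                 # w_T held fixed
    indirect = residual * phi * val[0] * dw_dphi     # chain rule through the inner loop
    return direct + indirect, direct


# Illustrative data points (x, y) for the inner and outer losses.
train, val = (1.0, 2.0), (1.5, 2.5)
itd_grad, fo_grad = hypergradients(0.8, train, val)
```

The first-order value needs no `dw_dphi` bookkeeping, which corresponds to skipping the Jacobian and Hessian-vector calculations; the price is that it drops the indirect term.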

Environments for meta-learning experiments

For Windows OS,

  • PyTorch=1.7.1
  • l2l=0.1.5
  • python=3.8
  • cuda=11.3

For Linux OS,

  • PyTorch=1.7.0
  • l2l=0.1.5
  • python=3.6.9
  • cuda=10.2

On both operating systems, we strongly recommend the older version of l2l listed above. Newer versions of l2l require some changes to the code.

Some experiment examples

In the following, we provide some experiments to demonstrate the better performance of the proposed stocBiO algorithm.

[Figure: comparison of stocBiO with various hyperparameter optimization baselines on the 20 Newsgroup dataset.]

[Figure: performance of stocBiO with respect to different batch sizes.]

[Figure: comparison results on the MNIST dataset.]

This repo is still under construction, and any comments are welcome!

Citation

If this repo is useful for your research, please cite our paper:

@inproceedings{ji2021bilevel,
	author = {Ji, Kaiyi and Yang, Junjie and Liang, Yingbin},
	title = {Bilevel Optimization: Nonasymptotic Analysis and Faster Algorithms},
	booktitle={International Conference on Machine Learning (ICML)},
	year = {2021}}

stocbio's People

Contributors

jikaiyi, junjieyang97
