# supervised-reptile

Reptile training code for Omniglot and Mini-ImageNet.

Reptile is a meta-learning algorithm that finds a good initialization. It works by sampling a task, training on the sampled task, and then updating the initialization towards the new weights for the task.
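To make the update concrete, here is a minimal sketch of one Reptile meta-iteration in plain Python. The helper names (`sample_task`, `inner_train`) and the list-of-arrays weight representation are illustrative assumptions, not this repo's actual API:

```python
# Minimal sketch of one Reptile meta-iteration (illustrative, not this repo's API).
def reptile_step(init_weights, sample_task, inner_train, meta_step_size):
    """Sample a task, adapt to it, then interpolate the initialization
    toward the adapted weights: init <- init + eps * (adapted - init)."""
    task = sample_task()                       # e.g. a 1-shot 5-way episode
    adapted = inner_train(init_weights, task)  # a few SGD/Adam steps on the task
    return [w + meta_step_size * (a - w)       # move toward the adapted weights
            for w, a in zip(init_weights, adapted)]
```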

## Getting the data

The fetch_data.sh script creates a data/ directory and downloads Omniglot and Mini-ImageNet into it. The data is on the order of 5GB, so the download takes 10-20 minutes on a reasonably fast internet connection.

```
$ ./fetch_data.sh
Fetching omniglot/images_background ...
Extracting omniglot/images_background ...
Fetching omniglot/images_evaluation ...
Extracting omniglot/images_evaluation ...
Fetching Mini-ImageNet train set ...
Fetching wnid: n01532829
Fetching wnid: n01558993
Fetching wnid: n01704323
Fetching wnid: n01749939
...
```

If you want to download Omniglot but not Mini-ImageNet, you can simply kill the script after it starts downloading Mini-ImageNet. The script automatically deletes partially-downloaded data when it is killed early.

## Reproducing training runs

You can train models with the run_omniglot.py and run_miniimagenet.py scripts. Hyper-parameters are specified as flags (see --help for a detailed list). Here are the commands used for the paper:

```shell
# transductive 1-shot 5-way Omniglot.
python -u run_omniglot.py --shots 1 --inner-batch 25 --inner-iters 3 --meta-step 1 --meta-batch 10 --meta-iters 100000 --eval-batch 25 --eval-iters 5 --learning-rate 0.001 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_o15t --transductive

# transductive 1-shot 5-way Mini-ImageNet.
python -u run_miniimagenet.py --shots 1 --inner-batch 5 --inner-iters 15 --meta-step 1 --meta-batch 10 --meta-iters 100000 --eval-batch 5 --eval-iters 10 --learning-rate 0.001 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_m15t --transductive

# 5-shot 5-way Mini-ImageNet.
python -u run_miniimagenet.py --inner-batch 10 --inner-iters 8 --meta-step 1 --meta-batch 5 --meta-iters 200000 --eval-batch 15 --eval-iters 88 --learning-rate 0.00022 --meta-step-final 0 --train-shots 15 --checkpoint ckpt_m55

# 1-shot 5-way Mini-ImageNet.
python -u run_miniimagenet.py --shots 1 --inner-batch 3 --inner-iters 19 --meta-step 0.235 --meta-batch 2 --meta-iters 200000 --eval-batch 3 --eval-iters 55 --learning-rate 0.0012 --meta-step-final 0 --train-shots 12 --checkpoint ckpt_m15

# 5-shot 5-way Omniglot.
python -u run_omniglot.py --train-shots 10 --inner-batch 10 --inner-iters 5 --learning-rate 0.0015 --meta-step 0.7 --meta-step-final 0 --meta-batch 5 --meta-iters 100000 --eval-batch 6 --eval-iters 100 --checkpoint ckpt_o55

# 1-shot 5-way Omniglot.
python -u run_omniglot.py --shots 1 --inner-batch 5 --inner-iters 12 --meta-step 1 --meta-batch 3 --meta-iters 200000 --eval-batch 5 --eval-iters 86 --learning-rate 0.00044 --meta-step-final 0 --train-shots 12 --checkpoint ckpt_o15

# 1-shot 20-way Omniglot.
python -u run_omniglot.py --shots 1 --classes 20 --inner-batch 15 --inner-iters 12 --meta-step 1 --meta-batch 3 --meta-iters 200000 --eval-batch 10 --eval-iters 97 --learning-rate 0.00046 --meta-step-final 0 --train-shots 9 --checkpoint ckpt_o120

# 5-shot 20-way Omniglot.
python -u run_omniglot.py --classes 20 --inner-batch 20 --inner-iters 12 --meta-step 1 --meta-batch 3 --meta-iters 200000 --eval-batch 10 --eval-iters 97 --learning-rate 0.00046 --meta-step-final 0 --train-shots 12 --checkpoint ckpt_o520
```

Training creates checkpoints. Currently, you cannot resume training from a checkpoint, but you can re-run evaluation from a checkpoint by passing --pretrained. You can use TensorBoard on the checkpoint directories to see approximate learning curves during training and testing.
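For instance (an illustrative combination of the flags shown above, not a command from the paper), re-running evaluation for the transductive 1-shot 5-way Omniglot checkpoint might look like:

```shell
# Re-run evaluation only, loading weights from an existing checkpoint.
python -u run_omniglot.py --shots 1 --checkpoint ckpt_o15t --transductive --pretrained
```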

To evaluate with transduction, pass the --transductive flag. In this implementation, transductive evaluation is faster than non-transductive evaluation since it makes better use of batches.
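The intuition, sketched below with a dummy stand-in for the network (the `predict` helper is hypothetical, not this repo's code): non-transductive evaluation classifies each query example in its own forward pass alongside the support set, while transductive evaluation classifies all query examples in a single batched pass, which also lets batch statistics be shared across queries.

```python
# Toy illustration of the batching difference (hypothetical predict helper).
support = ["s%d" % i for i in range(5)]   # stand-in for 5 labeled support examples
queries = ["q%d" % i for i in range(15)]  # stand-in for 15 query examples

def predict(batch):
    """Dummy stand-in for one forward pass over a batch of examples."""
    return [hash(x) % 5 for x in batch]

# Non-transductive: one forward pass per query, each batched with the support set.
non_transductive = [predict(support + [q])[-1] for q in queries]  # 15 passes

# Transductive: all queries share a single forward pass.
transductive = predict(queries)                                   # 1 pass
```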

## Comparing different inner-loop gradient combinations

Here are the commands for comparing different gradient combinations. The --foml flag indicates that only the final gradient should be used (a toy sketch of this distinction follows the commands below).

```shell
# Shared hyper-parameters for all experiments.
shared="--sgd --seed 0 --inner-batch 25 --learning-rate 0.003 --meta-step-final 0 --meta-iters 40000 --eval-batch 25 --eval-iters 5 --eval-interval 1"

python run_omniglot.py --inner-iters 1 --train-shots 5 --meta-step 0.25 --checkpoint g1_ckpt $shared | tee g1.txt

python run_omniglot.py --inner-iters 2 --train-shots 10 --meta-step 0.25 --checkpoint g1_g2_ckpt $shared | tee g1_g2.txt
python run_omniglot.py --inner-iters 2 --train-shots 10 --meta-step 0.125 --checkpoint half_g1_g2_ckpt $shared | tee half_g1_g2.txt
python run_omniglot.py --foml --inner-iters 2 --train-shots 10 --meta-step 0.25 --checkpoint g2_ckpt $shared | tee g2.txt

python run_omniglot.py --inner-iters 3 --train-shots 15 --meta-step 0.25 --checkpoint g1_g2_g3_ckpt $shared | tee g1_g2_g3.txt
python run_omniglot.py --inner-iters 3 --train-shots 15 --meta-step 0.08325 --checkpoint third_g1_g2_g3_ckpt $shared | tee third_g1_g2_g3.txt
python run_omniglot.py --foml --inner-iters 3 --train-shots 15 --meta-step 0.25 --checkpoint g3_ckpt $shared | tee g3.txt

python run_omniglot.py --foml --inner-iters 4 --train-shots 20 --meta-step 0.25 --checkpoint g4_ckpt $shared | tee g4.txt
python run_omniglot.py --inner-iters 4 --train-shots 20 --meta-step 0.25 --checkpoint g1_g2_g3_g4_ckpt $shared | tee g1_g2_g3_g4.txt
python run_omniglot.py --inner-iters 4 --train-shots 20 --meta-step 0.0625 --checkpoint fourth_g1_g2_g3_g4_ckpt $shared | tee fourth_g1_g2_g3_g4.txt
```
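To see what --foml changes, note that after k inner SGD steps the Reptile direction (initial weights minus final weights) is proportional to the sum of all inner-loop gradients g1 + ... + gk, whereas first-order MAML keeps only the final gradient gk. A toy numpy check of that identity (illustrative, not this repo's code):

```python
import numpy as np

# Toy task: loss(w) = 0.5 * ||w - target||^2, so grad(w) = w - target.
def inner_sgd(init, target, lr, k):
    """Run k SGD steps; return the final weights and the gradients taken."""
    w, grads = init.copy(), []
    for _ in range(k):
        g = w - target
        grads.append(g)
        w = w - lr * g
    return w, grads

init, target, lr = np.zeros(2), np.array([1.0, -2.0]), 0.1
w_final, grads = inner_sgd(init, target, lr, k=3)

reptile_dir = init - w_final  # equals lr * (g1 + g2 + g3)
foml_dir = lr * grads[-1]     # FOML keeps only the final gradient g3
assert np.allclose(reptile_dir, lr * sum(grads))
```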
