
leo's Introduction

Meta-Learning with Latent Embedding Optimization

Overview

This repository contains the implementation of the meta-learning model described in the paper "Meta-Learning with Latent Embedding Optimization" by Rusu et al. The paper was posted on arXiv in July 2018 and will be presented at ICLR 2019.

The paper learns a data-dependent latent representation of model parameters and performs gradient-based meta-learning in this low-dimensional space.
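As a rough illustration of that idea (not the repository's code), the inner loop can be sketched with a hypothetical linear decoder mapping per-class latent codes to classifier weights; the dimensions and the `decode` function here are illustrative stand-ins for LEO's learned networks:

```python
import numpy as np

rng = np.random.default_rng(0)

n_way, d_feat, d_latent = 5, 640, 64                     # illustrative sizes
W_dec = rng.normal(scale=0.01, size=(d_latent, d_feat))  # hypothetical linear "decoder"

def decode(z):
    """Map per-class latent codes z_n to classifier weights w_n."""
    return z @ W_dec                                     # (n_way, d_feat)

def support_loss_and_grad(z, x, y):
    """Softmax cross-entropy of the decoded linear classifier on the
    support set, plus its gradient with respect to the latent codes."""
    w = decode(z)
    logits = x @ w.T                                     # (n_examples, n_way)
    logits = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    p = p / p.sum(axis=1, keepdims=True)
    onehot = np.eye(n_way)[y]
    loss = -np.mean(np.log(p[np.arange(len(y)), y]))
    grad_w = (p - onehot).T @ x / len(y)                 # dL/dw
    return loss, grad_w @ W_dec.T                        # chain rule: dL/dz

# A 5-way 1-shot support set: one embedding per class.
x = rng.normal(size=(n_way, d_feat))
y = np.arange(n_way)
z = rng.normal(size=(n_way, d_latent))

lr = 1.0
losses = []
for _ in range(5):                  # a few inner-loop gradient steps in latent space
    loss, grad_z = support_loss_and_grad(z, x, y)
    losses.append(loss)
    z = z - lr * grad_z             # adapt the low-dimensional z, not w directly

print(losses[0] > losses[-1])
```

The point of the sketch is only that the adaptation variable has dimension `d_latent`, not `d_feat`: gradients are taken in the latent space and decoded back into high-dimensional classifier weights.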

The code here doesn't include the (standard) method for pre-training the data embeddings. Instead, the trained embeddings are provided.

Disclaimer: This is not an official Google product.

Running the code

Setup

To run the code, you first need to install:
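The exact dependency versions appear in the issue reports below; based on one user's working setup, an environment might look like the following (treat these versions as a starting point, not an official requirements list):

```shell
$ pip install tensorflow==1.13.1 dm-sonnet==1.29 tensorflow-probability==0.5.0
```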

Getting the data

You need to download the embeddings and extract them on disk:

$ wget http://storage.googleapis.com/leo-embeddings/embeddings.zip
$ unzip embeddings.zip
$ EMBEDDINGS=`pwd`/embeddings

Running the code

Then, clone this repository using:

$ git clone https://github.com/deepmind/leo

and run the code as:

$ python runner.py --data_path=$EMBEDDINGS

This will train the model for solving 5-way 1-shot miniImageNet classification.

Hyperparameters

To train the model on the tieredImageNet dataset or with a different number of training examples per class (K-shot), you can pass these parameters on the command line or set them in config.py, e.g.:

$ python runner.py --data_path=$EMBEDDINGS --dataset_name=tieredImageNet --num_tr_examples_per_class=5 --outer_lr=1e-4

See config.py for the list of options to set.

Comparison of paper and open-source implementations in terms of test set accuracy:

| Implementation | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
| --- | --- | --- | --- | --- |
| LEO Paper | 61.76 ± 0.08% | 77.59 ± 0.12% | 66.33 ± 0.05% | 81.44 ± 0.09% |
| This code | 61.89 ± 0.16% | 77.65 ± 0.09% | 66.25 ± 0.14% | 81.77 ± 0.09% |

The hyperparameters we found to work best for the different setups are as follows:

| Hyperparameter | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
| --- | --- | --- | --- | --- |
| outer_lr | 2.739071e-4 | 4.102361e-4 | 8.659053e-4 | 6.110314e-4 |
| l2_penalty_weight | 3.623413e-10 | 8.540338e-9 | 4.148858e-10 | 1.690399e-10 |
| orthogonality_penalty_weight | 0.188103 | 1.523998e-3 | 5.451078e-3 | 2.481216e-2 |
| dropout_rate | 0.307651 | 0.300299 | 0.475126 | 0.415158 |
| kl_weight | 0.756143 | 0.466387 | 2.034189e-3 | 1.622811 |
| encoder_penalty_weight | 5.756821e-6 | 2.661608e-7 | 8.302962e-5 | 2.672450e-5 |

leo's People

Contributors

sygi

leo's Issues

Questions on implementation

First of all, thank you for the wonderful work! I really enjoyed reading it.

I am currently trying to reimplement your work and got some questions.

  1. Is there a reference for the orthogonality regularization? Also, it seems to be regularized to induce better expressibility for the row vectors rather than the column vectors, which is unorthodox. Is this because the regularization actually targets the latent code rather than its output, whose value eventually becomes some type of "prototype vector" for each class?

https://github.com/deepmind/leo/blob/de9a0c2a77dd7a42c1986b1eef18d184a86e294a/model.py#L40-L41
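For context, a common form of soft orthogonality regularization on the rows of a weight matrix pushes the pairwise cosine similarities of the rows toward the identity matrix. The sketch below is a simplified illustration of that general idea, not the repository's exact code:

```python
import numpy as np

def orthogonality_penalty(w):
    """Penalize correlation between the ROWS of w: the cosine-similarity
    matrix of the rows is pushed toward the identity."""
    norms = np.linalg.norm(w, axis=1, keepdims=True)
    cosine = (w @ w.T) / (norms @ norms.T)       # (rows, rows)
    return np.mean((cosine - np.eye(w.shape[0])) ** 2)

# Exactly orthogonal rows incur zero penalty...
assert np.isclose(orthogonality_penalty(np.eye(3)), 0.0)
# ...while duplicated rows are penalized.
w = np.array([[1.0, 0.0], [1.0, 0.0]])
print(orthogonality_penalty(w) > 0)
```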

  2. Is the KL-divergence implementation correct? Since the prior and posterior are both Gaussian, there is an analytic formula to calculate it, but the current implementation uses a sample-based approach that doesn't seem to represent the KL divergence: it should be \sum q * (log q - log p), but it is currently \mean (log q - log p). Am I missing something?

https://github.com/deepmind/leo/blob/de9a0c2a77dd7a42c1986b1eef18d184a86e294a/model.py#L269-L274
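For what it's worth, averaging log q − log p over samples drawn from q is a standard unbiased Monte Carlo estimator of KL(q‖p), since KL(q‖p) = E_{z∼q}[log q(z) − log p(z)]. A quick numerical check against the closed form for diagonal Gaussians (illustrative only, not the repository's code):

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_diag_gaussians(mu_q, sig_q, mu_p, sig_p):
    """Closed-form KL(q || p) for diagonal Gaussians."""
    return np.sum(np.log(sig_p / sig_q)
                  + (sig_q**2 + (mu_q - mu_p)**2) / (2 * sig_p**2) - 0.5)

def log_prob(z, mu, sig):
    """Diagonal-Gaussian log density, summed over dimensions."""
    return np.sum(-0.5 * np.log(2 * np.pi * sig**2)
                  - (z - mu)**2 / (2 * sig**2), axis=-1)

mu_q, sig_q = np.array([1.0, -0.5]), np.array([0.8, 1.2])
mu_p, sig_p = np.zeros(2), np.ones(2)

analytic = kl_diag_gaussians(mu_q, sig_q, mu_p, sig_p)

# Monte Carlo estimate: mean of (log q - log p) over samples FROM q.
z = mu_q + sig_q * rng.normal(size=(200_000, 2))
mc = np.mean(log_prob(z, mu_q, sig_q) - log_prob(z, mu_p, sig_p))

print(abs(analytic - mc) < 0.02)
```

The estimator is noisier than the analytic formula but converges to the same quantity.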

  3. Currently, $z_n$ is sampled K times from q, resulting in K $w_n$ vectors whose logits are averaged. Is this to stabilize the training noise introduced by the reparameterization trick, or is there more to it?

https://github.com/deepmind/leo/blob/de9a0c2a77dd7a42c1986b1eef18d184a86e294a/model.py#L250-L254
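Averaging over K reparameterized samples is a standard variance-reduction device; a toy sketch of the effect (illustrative only, with an identity map standing in for the decoder):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, K = 0.5, 1.0, 10

def sample_logit(n):
    """Reparameterized samples: logit = f(z), z = mu + sigma * eps.
    Here f is the identity to keep the sketch minimal."""
    eps = rng.normal(size=n)
    return mu + sigma * eps

single = sample_logit(5000)                                # one sample each
averaged = sample_logit(5000 * K).reshape(5000, K).mean(axis=1)  # K-sample mean

print(averaged.var() < single.var())   # averaging shrinks variance (~1/K)
```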

Thank you!

Test and accuracy problem

Thanks for your great work. When I run in test mode, an AssertionError occurs:

Traceback (most recent call last):
  File "runner.py", line 213, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "runner.py", line 208, in main
    run_training_loop(FLAGS.checkpoint_path)
  File "runner.py", line 193, in run_training_loop
    assert not FLAGS.checkpoint_steps
AssertionError

If I comment out 'assert not FLAGS.checkpoint_steps', I can only get 75% accuracy on the test dataset.

Incompatibility with Python 3

Hi,

I could not find any reference to the required Python version. However, it appears the code assumes Python 2.7. For Python 3, at least the following changes are required:

  • In pickle.load, add encoding='latin1'; otherwise ASCII is assumed.
  • Change iteritems to items.
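The pickle fix described above can be demonstrated end to end (self-contained demo; the repository's actual file paths are not assumed):

```python
import os
import pickle
import tempfile

payload = {"embeddings": [1.0, 2.0]}
path = os.path.join(tempfile.mkdtemp(), "demo.pkl")

# Write with protocol 2, the highest protocol Python 2 understands.
with open(path, "wb") as f:
    pickle.dump(payload, f, protocol=2)

# The Python 3 fix from this issue: pass encoding='latin1' when loading,
# since Python-2-era pickles otherwise default to ASCII and can fail.
with open(path, "rb") as f:
    data = pickle.load(f, encoding="latin1")

print(data == payload)
```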

meta dataset used in training

Thank you for this great work!
I noticed that a PyTorch implementation [1] states that in the LEO paper the model is trained on both the training set and the validation set.
I would like to know whether this TensorFlow implementation is also trained on both the training set and the validation set.

Thank you.

[1] https://github.com/timchen0618/pytorch-leo

EMBEDDINGS Proper export

In the README.md, for storing the embeddings path you have:

$ EMBEDDINGS=`pwd`/embeddings

Change to:

$ export EMBEDDINGS=`pwd`/embeddings

Killed!

When I tried to run runner.py, it shows "Allocation of 1533440000 exceeds 10% of system memory." and is then killed. I tried running it on Colab as well as on my laptop, with the same result.

I have installed "tensorflow==1.13.1", "dm-sonnet==1.29" and "tensorflow-probability==0.5.0"

So many WARNINGS while running

I have installed the required libs; my tensorflow version is 0.13 and my tensorflow-probability version is 0.6. However, when I ran runner.py, many WARNINGs appeared saying that some TensorFlow usages are deprecated and will be removed in a future version, and there was no output about the expected results. How can I solve this problem? Thanks a lot.

How to test model?

Thank you for your amazing work.

Does this version include code to test the model?


Can't reproduce the results reported in the paper

Firstly, thanks a lot for the great work! But I still have several problems, as follows.
I ran this code on a Titan Xp (with tensorflow==1.13.1), but I only obtained 60.3 test accuracy in the 5-way 1-shot setting. I then switched to the best hyperparameters provided, but only got 60.9 test accuracy with 68.38 validation accuracy, which is still somewhat below the results reported in the paper.
I tested by setting 'checkpoint_steps' to 0 and 'evaluation_mode' to True in runner.py.
Is something wrong? Could you please give me some hints for reproducing the reported results? @sygi @andreirusu Thanks a lot!
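For reference, the evaluation invocation this issue describes might look like the following (flag names taken from the issue text, not verified against runner.py):

```shell
$ python runner.py --data_path=$EMBEDDINGS --checkpoint_steps=0 --evaluation_mode=True
```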

Where is the 'outer loop' mentioned in the paper?

Hi! I have a small problem.
In the code, I only see the 'inner loop', and did not find the implementation of the N-way K-shot classification tasks themselves, such as using WRN-28-10 to do the image classification.
Is that perhaps already handled in the raw_data["embeddings"] you provide, i.e. already processed?
