Optimization as a Model for Few-shot Learning

PyTorch implementation of Optimization as a Model for Few-shot Learning (ICLR 2017, Oral)

Model Architecture

Prerequisites

  • Python 3+
  • PyTorch 0.4+ (developed on 1.0.1 with CUDA 9.0)
  • Pillow
  • tqdm (a nice progress bar)

Data

  • Mini-Imagenet, as described here
    • You can download it from here (~2.7 GB, Google Drive link)

Preparation

  • Make sure Mini-Imagenet is split properly. For example:
    - data/
      - miniImagenet/
        - train/
          - n01532829/
            - n0153282900000005.jpg
            - ...
          - n01558993/
          - ...
        - val/
          - n01855672/
          - ...
        - test/
          - ...
    - main.py
    - ...
    
    • It will already be laid out this way if you download and extract Mini-Imagenet from the link above
  • Check out scripts/train_5s_5c.sh and make sure --data-root is set properly (a quick layout check is sketched below)
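
If you want to check the layout programmatically, here is a minimal sketch (my own helper, not part of the repo; the data_root default is an assumption matching the tree above):

```python
from pathlib import Path

# Hypothetical sanity check (not part of this repo): confirm each Mini-Imagenet
# split contains class directories of .jpg images, as in the tree above.
def check_split(data_root="data/miniImagenet"):
    root = Path(data_root)
    for split in ("train", "val", "test"):
        classes = [d for d in (root / split).iterdir() if d.is_dir()]
        n_images = sum(len(list(c.glob("*.jpg"))) for c in classes)
        print(f"{split}: {len(classes)} classes, {n_images} images")

check_split()
```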

Run

For 5-shot, 5-class training, run

bash scripts/train_5s_5c.sh

Hyper-parameters follow the author's repo.

For 5-shot, 5-class evaluation, run (remember to change --resume and --seed arguments)

bash scripts/eval_5s_5c.sh

Notes

| seed | train episodes | val episodes | val acc mean (%) | val acc std (%) | test episodes | test acc mean (%) | test acc std (%) |
|------|----------------|--------------|------------------|-----------------|---------------|-------------------|------------------|
| 719  | 41000          | 100          | 59.08            | 9.9             | 100           | 56.59             | 8.4              |
| -    | -              | -            | -                | -               | 250           | 57.85             | 8.6              |
| -    | -              | -            | -                | -               | 600           | 57.76             | 8.6              |
| 53   | 44000          | 100          | 58.04            | 9.1             | 100           | 57.85             | 7.7              |
| -    | -              | -            | -                | -               | 250           | 57.83             | 8.3              |
| -    | -              | -            | -                | -               | 600           | 58.14             | 8.5              |
  • The results I get from directly running the author's repo can be found here; I get slightly better performance (~5%), but neither result matches the number in the paper (60%). (Discussion and help are welcome!)
  • Training with the default settings takes ~2.5 hours on a single Titan Xp while occupying ~2GB GPU memory.
  • The implementation replicates two learners, similar to the author's repo (a sketch of the wiring follows this list):
    • learner_w_grad functions as a regular model; its gradients and loss are fed as inputs to the meta learner.
    • learner_wo_grad constructs the graph for the meta learner:
      • All the parameters in learner_wo_grad are replaced by cI, the output of the meta learner.
      • nn.Parameters in this model are cast to torch.Tensor to connect the graph to the meta learner.
  • There are several ways to copy parameters from the meta learner to a learner, depending on the scenario (a toy comparison also follows this list):
    • copy_flat_params: we only need the parameter values and keep the original grad_fn.
    • transfer_params: we want the values as well as the grad_fn (from cI to learner_wo_grad).
      • .data.copy_ vs. clone(): the latter retains all the properties of a tensor, including grad_fn.
      • To maintain the batch statistics, load_state_dict is used (from learner_w_grad to learner_wo_grad).
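
As a rough illustration of the learner_wo_grad wiring, here is a minimal sketch (hypothetical helper name and simplified logic, not the repo's exact transfer_params) of swapping a module's nn.Parameters for slices of cI so that autograd connects the learner's loss back to the meta learner:

```python
import torch
import torch.nn as nn

# Sketch only: replace every nn.Parameter of learner_wo_grad with a slice of
# cI, the flat parameter vector produced by the meta learner.
def transfer_params_sketch(learner_wo_grad: nn.Module, cI: torch.Tensor) -> None:
    offset = 0
    for module in learner_wo_grad.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            n = param.numel()
            # clone() keeps cI's grad_fn, so backprop through the learner
            # reaches the meta learner that produced cI.
            chunk = cI[offset:offset + n].clone().view_as(param)
            delattr(module, name)         # drop the nn.Parameter entry
            setattr(module, name, chunk)  # plain torch.Tensor, still in the graph
            offset += n
```

After this, a forward pass of learner_wo_grad produces a loss whose backward() flows through cI into the meta learner's parameters, which is why the nn.Parameter-to-torch.Tensor cast is needed.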
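
And a standalone toy (not from the repo) showing the .data.copy_ vs. clone() distinction the list above refers to:

```python
import torch

cI = torch.randn(4, requires_grad=True)  # stand-in for the meta learner's output
flat = cI * 2                            # a differentiable function of cI

# Value-only copy (copy_flat_params style): p stays a graph leaf.
p = torch.zeros(4, requires_grad=True)
with torch.no_grad():
    p.copy_(flat)
print(p.grad_fn)  # None -> no link back to cI

# clone() (transfer_params style): q carries a grad_fn, so the graph reaches cI.
q = flat.clone()
q.sum().backward()
print(cI.grad)    # tensor([2., 2., 2., 2.]) -> graph is connected
```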

References

  • Sachin Ravi and Hugo Larochelle. Optimization as a Model for Few-shot Learning. ICLR 2017.
