GithubHelp home page GithubHelp logo

haowencs / deep-kernel-transfer Goto Github PK

View Code? Open in Web Editor NEW

This project forked from bayeswatch/deep-kernel-transfer

0.0 0.0 0.0 656 KB

Official pytorch implementation of the paper "Bayesian Meta-Learning for the Few-Shot Setting via Deep Kernels" (NeurIPS 2020)

Home Page: https://arxiv.org/abs/1910.05199

Shell 1.09% Python 98.91%

deep-kernel-transfer's Introduction

Official pytorch implementation of the paper:

"Bayesian Meta-Learning for the Few-Shot Setting via Deep Kernels" (2020) Patacchiola, M., Turner, J., Crowley, E. J., O'Boyle, M., & Storkey, A., Advances in Neural Information Processing (NeurIPS, Spotlight) [arXiv]

@inproceedings{patacchiola2020bayesian,
  title={Bayesian Meta-Learning for the Few-Shot Setting via Deep Kernels},
  author={Patacchiola, Massimiliano and Turner, Jack and Crowley, Elliot J. and Storkey, Amos},
  booktitle={Advances in Neural Information Processing Systems},
  year={2020}
}

Overview. We introduce a Bayesian meta-learning method based on Gaussian Processes (GPs) to tackle the problem of few-shot learning. We propose a simple, yet effective variant of deep kernel learning in which the kernel is transferred across tasks, which we call Deep Kernel Transfer (DKT). This approach is straightforward to implement, provides uncertainty quantification, and does not require estimation of task-specific parameters. We empirically demonstrate that DKT outperforms several state-of-the-art algorithms in few-shot regression, classification, and cross-domain adaptation.

NOTE: previous pre-prints of this paper have used the names "GPNet" and "GPShot". In the published version we are using the name "DKT".

Requirements

  1. Python >= 3.x
  2. Numpy >= 1.17
  3. pyTorch >= 1.2.0
  4. GPyTorch >= 0.3.5
  5. (optional) TensorboardX

WARNING: some users have experienced issues in running some of the scripts due to the error: "Matrix not positive definite". This is likely caused by the latest versions of GPyTorch. The configuration that is working on our system uses: gpytorch 1.0.1, python 3.6.9, torch 1.8.1. We suggest to replicate this configuration in a conda environment in case you experience the same issue.

Installation

pip install numpy torch torchvision gpytorch h5py pillow

We confirm that the following configuration worked for us: numpy 1.18.1, torch 1.4.0, torchvision 0.5.0, gpytorch 1.0.1, h5py 5.10.0, pillow 7.0.0

DKT: code of our method

Regression. The implementation of our method is based on the gpyTorch library. The code for the regression case is available in DKT_regression.py.

Classification. The code for the classification case is accessible in DKT.py, with most of the important pieces contained in the train_loop() method (training), and in the correct() method (testing).

Note: there is the possibility of using the scikit Laplace approximation at test time (classification only), setting laplace=True in correct(). However, this has not been investigated enough and it is not the method used in the paper.

Experiments

These are the instructions to train and test the methods reported in the paper in the various conditions.

Download and prepare a dataset. This is an example of how to download and prepare a dataset for training/testing. Here we assume the current directory is the project root folder:

cd filelists/DATASET_NAME/
sh download_DATASET_NAME.sh

Replace DATASET_NAME with one of the following: omniglot, CUB, miniImagenet, emnist, QMUL. Notice that mini-ImageNet is a large dataset that requires substantial storage, therefore you can save the dataset in another location and then change the entry in configs.py in accordance.

Methods. There are a few available methods that you can use: DKT, maml, maml_approx, protonet, relationnet, matchingnet, baseline, baseline++. You must use those exact strings at training and test time when you call the script (see below). Note that our method is DKT, and that baseline corresponds to feature transfer in our paper. By default DKT has a BNCosSim kernel, to change this please edit the entry in configs.py.

Backbone. The script allows training and testing on different backbone networks. By default the script will use the same backbone used in our experiments (Conv4). Check the file backbone.py for the available architectures, and use the parameter --model=BACKBONE_STRING where BACKBONE_STRING is one of the following: Conv4, Conv6, ResNet10|18|34|50|101.

Regression

QMUL Head Pose Trajectory Regression. In order to run this experiment you first have to download and setup the QMUL dataset, this can be done automatically running the file download_QMUL.sh from the folder filelists/QMUL. Moreover, you have to change the kernel type, editing the entry in configs.py (default BNCosSim) to rbf or spectral. Please note that other kernels are not supported for this experiment and their use will raise an error. The methods that can be used for regression are DKT and transfer (feature transfer). In order to train these methods, use:

python train_regression.py --method="DKT" --seed=1

The number of training epochs can be set with --stop_epoch. The above command will save a checkpoint to save/checkpoints/QMUL/Conv3_DKT, which you can test on the test set with:

python test_regression.py --method="DKT" --seed=1

You can additionally specify the size of the support set with --n_support (which defaults to 5), and the number of test epochs with --n_test_epochs (which defaults to 10).

Periodic functions. The code for the periodic functions experiments is available in the sines folder. This needs some adjustment of the parameters at the code level to reproduce the in-range and out-of-range conditions (see the associated README).

Classification

Train classification. The various methods can be trained using the following syntax:

python train.py --dataset="miniImagenet" --method="DKT" --train_n_way=5 --test_n_way=5 --n_shot=1 --seed=1 --train_aug

This will train DKT 5-way 1-shot on the mini-ImageNet dataset with seed 1. The dataset string can be one of the following: CUB, miniImagenet. At training time the best model is evaluated on the validation set and stored as best_model.tar in the folder ./save/checkpoints/DATASET_NAME. The parameter --train_aug enables data augmentation. The parameter seed set the seed for pytorch, numpy, and random. Set --seed=0 or remove the parameter for a random seed. Additional parameters are reported in the file io_utils.py.

Test classification. For testing DKT, maml and maml_approx it is enough to repeat the train command replacing the call to train.py with the call to test.py as follows:

python test.py --dataset="miniImagenet" --method="DKT" --train_n_way=5 --test_n_way=5 --n_shot=1 --seed=1 --train_aug

Other methods require to store the features (for efficiency) before testing, this can be done running the script save_features.py before calling test.py. For instance, if you trained a protonet, you should call:

python save_features.py --dataset="miniImagenet" --method="protonet" --train_n_way=5 --test_n_way=5 --n_shot=1 --seed=1 --train_aug
python test.py --dataset="miniImagenet" --method="protonet" --train_n_way=5 --test_n_way=5 --n_shot=1 --seed=1 --repeat=5 --train_aug

We noticed that the original code has a large variance on test tasks. To reduce this variance we add the parameter repeat=N. It iterates N times with different seeds and take an average over the N tests, we used N=5 (3000 tasks) in our experiments.

Cross-domain classification

For the cross-domain classification experiments the procedure is the same described previously. The only difference is that the available datasets are: cross_char, and cross. The former being omniglot -> EMNIST, and the latter miniImagenet -> CUB. Here an example of training procedure:

python train.py --dataset="cross_char" --method="DKT" --train_n_way=5 --test_n_way=5 --n_shot=1 --seed=1

Note that the parameter --train_aug (data augmentation) is not used for cross_char but only for cross.

Acknowledgements

This repository is a fork of: https://github.com/wyharveychen/CloserLookFewShot

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.