
fsl-imprinted-weights's Introduction

PyTorch implementation of Low-Shot Learning with Imprinted Weights.

I started this repository as a clone of this repository. After realizing that I had changed almost everything, I decided to make it a separate repo of its own. Some code snippets may still be intact, for which I give credit to @YU1ut.

Important note

In the paper, Inception V1 is used as the feature extractor. However, since the torchvision.models package has no pre-trained Inception V1 model, this repo uses ResNet-50 as the feature extractor instead. Somewhat surprisingly, fine-tuning ResNet-50 with RMSProp (with exactly the same parameters as in the paper) results in poor generalization. Perhaps this is yet another case where SGD with momentum is superior to RMSProp with momentum in terms of generalization.

Development environment

  • ubuntu 18.04
  • cuda 9.0
  • conda 4.5.11
  • python 3.6.4
  • pytorch 1.0.0
  • torchvision 0.2.1
  • sklearn 0.19.1
  • matplotlib 3.0.1
  • numpy 1.15.4
  • tqdm

Dataset

Download the CUB_200_2011 dataset and unzip it into the data directory under the repository folder.

Usage

Currently, I don't have enough time to explain every step in detail. Please see the scripts to understand how things work.

Results

Scores are computed as follows.

  • Each n-shot setting is run 5 times, with a different seed set before the n samples are drawn from the novel classes.
  • In each imprinting experiment, every sample from a novel class is augmented 5 times.
  • For each score, a confusion matrix is computed and the per-class recall scores are then averaged.
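The last step can be sketched as follows. This is a minimal illustration in plain Python, not the code used in this repo: the function name average_per_class_recall and the toy labels are hypothetical, and skipping classes that have no test samples is an assumption.

```python
def average_per_class_recall(y_true, y_pred, classes):
    """Mean of the per-class recalls, computed from a confusion matrix."""
    index = {c: k for k, c in enumerate(classes)}
    n = len(classes)
    # confusion[i][j] counts samples of true class i predicted as class j.
    confusion = [[0] * n for _ in range(n)]
    for t, p in zip(y_true, y_pred):
        confusion[index[t]][index[p]] += 1

    recalls = []
    for i in range(n):
        support = sum(confusion[i])  # number of test samples of class i
        if support == 0:
            continue  # assumption: classes without test samples are skipped
        recalls.append(confusion[i][i] / support)
    return sum(recalls) / len(recalls)


# Toy example: class 0 has 6 samples (all correct), class 1 has 2 (one correct).
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0]
score = average_per_class_recall(y_true, y_pred, classes=[0, 1])
print(score)  # 0.75, whereas plain accuracy would be 7/8 = 0.875
```

Unlike plain accuracy, this score weights every class equally regardless of how many test samples it has. It corresponds to macro-averaged recall in scikit-learn terms.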

Average per-class recalls of novel classes in CUB-200-2011

w/o FT

n =                         1      2      5      10     20
Rand-noFT (paper)           0.17   0.17   0.17   0.17   0.17
Imprinting (paper)          21.26  28.69  39.52  45.77  49.32
Imprinting + Aug (paper)    21.40  30.03  39.35  46.35  49.80
Rand-noFT                   0.00   0.00   0.00   0.00   0.01
Imprinting + Aug            20.2   27.9   38.9   46.3   50.4

w/ FT

n =                                  1      2      5      10     20
Rand + FT (paper)                    5.25   13.41  34.95  54.33  65.60
Imprinting + FT (paper)              18.67  30.17  46.08  59.39  68.77
AllClassJoint (paper)                3.89   10.82  33.00  50.24  64.88
Rand + FT                            3.8    11.6   32.9   51.7   66.8
Imprinting + Aug + FT                19.3   31.4   50.4   61.7   66.9
AllClassJoint                        5.6    16.0   41.5   59.6   71.7
AllClassJoint - Cosine Similarity    6.6    19.5   47.8   65.6   76.7

Average per-class recalls of all classes in CUB-200-2011

w/o FT

n =                         1      2      5      10     20
Rand-noFT (paper)           37.36  37.36  37.36  37.36  37.36
Imprinting (paper)          44.75  48.21  52.95  55.99  57.47
Imprinting + Aug (paper)    44.60  48.48  52.78  56.51  57.84
Rand-noFT                   41.2   41.2   41.2   41.2   41.2
Imprinting + Aug            50.4   54.1   59.2   62.8   64.8

w/ FT

n =                                  1      2      5      10     20
Rand + FT (paper)                    39.26  43.36  53.69  63.17  68.75
Imprinting + FT (paper)              45.81  50.41  59.15  64.65  68.73
AllClassJoint (paper)                38.02  41.89  52.24  61.11  68.31
Rand + FT                            42.4   45.9   56.2   65.1   72.5
Imprinting + Aug + FT                50.3   56.0   65.1   70.6   72.6
AllClassJoint                        42.1   46.9   59.9   68.8   74.4
AllClassJoint - Cosine Similarity    44.3   50.6   64.5   73.7   78.6

fsl-imprinted-weights's People

Contributors

mbsariyildiz

fsl-imprinted-weights's Issues

magic number in imprint

Hi,
Is the magic number 100 in new_weight the number of classes?

new_weight = torch.zeros(100, d_emb)
for i in range(100):
    tmp = output_stack[target_stack == (i + 100)].mean(0) if not random else torch.randn(d_emb, device=device)
    new_weight[i] = tmp / tmp.norm(p=2)
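
One possible reading of the snippet, sketched with the two constants named. This is not the author's answer. It assumes that the first 100 classes of CUB-200-2011 are base classes and the remaining 100 are novel classes, so that 100 is both the number of novel classes and the label offset. The names NUM_BASE, NUM_NOVEL and imprint_weights are hypothetical, and the random-initialization branch is left out.

```python
import torch

NUM_BASE = 100   # assumption: labels 0..99 are base classes
NUM_NOVEL = 100  # assumption: labels 100..199 are novel classes


def imprint_weights(embeddings, targets, num_base=NUM_BASE, num_novel=NUM_NOVEL):
    """Return one L2-normalized mean embedding per novel class."""
    d_emb = embeddings.shape[1]
    new_weight = torch.zeros(num_novel, d_emb, device=embeddings.device)
    for i in range(num_novel):
        # Novel class i carries the label num_base + i.
        # Note: a class with no samples would give a NaN row here.
        class_mean = embeddings[targets == (num_base + i)].mean(0)
        new_weight[i] = class_mean / class_mean.norm(p=2)
    return new_weight


# Toy usage: 2 base classes, 3 novel classes, 2 samples per novel class.
torch.manual_seed(0)
emb = torch.randn(6, 8)
tgt = torch.tensor([2, 2, 3, 3, 4, 4])
w = imprint_weights(emb, tgt, num_base=2, num_novel=3)
print(w.shape)        # torch.Size([3, 8])
print(w.norm(dim=1))  # each row has unit L2 norm
```

Under this reading, the first 100 sizes the weight matrix (one row per novel class), while the 100 in `i + 100` shifts the loop index to the label of the corresponding novel class.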
