Learning Local Descriptors from Weakly-Labeled Data

Current best local descriptors are learned on a large dataset of matching and non-matching keypoint pairs. However, data of this kind is not always available since detailed keypoint correspondences can be hard to establish (e.g., for non-image data). On the other hand, we can often obtain labels for pairs of keypoint bags. For example, keypoint bags extracted from two images of the same object under different views form a matching pair, and keypoint bags extracted from images of different objects form a non-matching pair. On average, matching pairs should contain more corresponding keypoints than non-matching pairs. We propose to learn local descriptors from such information where local correspondences are not known in advance.

[Teaser figure]

Each image in the dataset (first row) is processed with a keypoint detector (second row) and transformed into a bag of visual words (third row). Some bags form matching pairs (green arrow) and some form non-matching pairs (red arrows). On average, matching pairs should contain more corresponding local visual words than non-matching pairs. We propose to learn local descriptors by optimizing the mentioned local correspondence criterion on a given dataset. Note that prior work assumes local correspondences are known in advance.

The details of the method can be found in our technical report, available on arXiv. If you use our results and/or ideas, please cite the report as follows (BibTeX):

@misc{
	wlrn,
	author = {Nenad Marku\v{s} and Igor S. Pand\v{z}i\'c and J\"{o}rgen Ahlberg},
	title = {{Learning Local Descriptors by Optimizing the Keypoint-Correspondence Criterion}},
	year = {2016},
	eprint = {arXiv:1603.09095}
}

Some results (to be updated soon)

A network trained with our method (code in this repo) can be obtained from the folder models/. This network extracts 64-dimensional float descriptors of unit length from local patches of size 32x32. Here is its structure:

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> output]
  (1): nn.MulConstant
  (2): nn.View
  (3): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> output]
    (1): nn.SpatialConvolution(3 -> 32, 3x3)
    (2): nn.ReLU
    (3): nn.SpatialConvolution(32 -> 64, 4x4, 2,2)
    (4): nn.ReLU
    (5): nn.SpatialConvolution(64 -> 128, 3x3)
    (6): nn.SpatialMaxPooling(2,2,2,2)
    (7): nn.SpatialConvolution(128 -> 32, 1x1)
    (8): nn.SpatialConvolution(32 -> 64, 6x6)
  }
  (4): nn.View
  (5): nn.Normalize(2)
}

The structure is specified in models/3x32x32_to_64.lua. The net parameters are stored as a flat vector of floats in models/3x32x32_to_64.params; this keeps the storage requirements (i.e., the repo size) small. Use the following code to deploy and use the net.

-- the model script uses the nn package
require 'nn'

-- load the network parameters first
params = torch.load('models/3x32x32_to_64.params')

-- create the network and initialize its weights with loaded data
n = dofile('models/3x32x32_to_64.lua')(params):float()

-- generate a random batch of five 32x32 patches (each pixel is a float from [0, 255])
p = torch.rand(5, 3, 32, 32):float():mul(255)

-- propagate the batch through the net to obtain descriptors
-- (note that no patch preprocessing, such as mean subtraction, is required)
d = n:forward(p)

-- an appropriate similarity between descriptors is, for example, a dot product ...
print(d[1]*d[2])

-- ... or you can use the Euclidean distance
print(torch.norm(d[1] - d[2]))
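
Since the descriptors have unit length, whole bags can be compared at once: a single matrix product yields all pairwise dot-product similarities. A minimal sketch (the two random batches below are hypothetical stand-ins for patches extracted from two images):

-- two bags of patches (random stand-ins for real extracted patches)
p1 = torch.rand(5, 3, 32, 32):float():mul(255)
p2 = torch.rand(7, 3, 32, 32):float():mul(255)

-- clone the first output: a second forward pass reuses the module's output tensor
d1 = n:forward(p1):clone()
d2 = n:forward(p2)

-- 5x7 matrix of pairwise similarities between the two bags
S = d1 * d2:t()

-- for each descriptor in the first bag, its best match in the second bag
best, idx = S:max(2)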

Notice that although it was trained on 32x32 patches, the model can be applied in a fully-convolutional manner to images of any size (the third module of the architecture contains only convolutions, ReLUs and pooling operations).
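
For example, the convolutional part can be run directly on a whole image to obtain a dense grid of descriptors. A minimal sketch, assuming the image is first passed through module (1) for pixel scaling and that the per-location outputs are normalized manually (nn.Normalize in module (5) expects flat batches, not spatial maps):

require 'nn'
require 'image'

-- pick out the relevant parts of the net built above
scale = n:get(1) -- nn.MulConstant: pixel-value scaling
conv = n:get(3)  -- the purely convolutional submodule

-- a whole image instead of a 32x32 patch ('scene.jpg' is a hypothetical file;
-- pixels as floats in [0, 255], as in the patch-based example)
img = image.load('scene.jpg', 3, 'float'):mul(255)

-- produces a 64xH'xW' map: one raw descriptor per spatial position
maps = conv:forward(scale:forward(img))

-- normalize each descriptor to unit length along the feature dimension
maps:cdiv(maps:norm(2, 1):expandAs(maps))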

How to repeat the training

Follow these steps.

1. Prepare bags of keypoints

Download http://46.101.250.137/data/ukb.tar and extract the archive. It contains two folders with JPG images: ukb-trn/ and ukb-val/. Images from the first folder will be used for training and images from the second for computing the validation error.

Move to the folder utils/ and compile fast.cpp and extp.cpp with the provided makefile (run make). These are the keypoint-detection and patch-extraction programs. Use the script batch_extract.sh to transform the downloaded images into bags of keypoints:

bash batch_extract.sh ukb-trn/ ukb-trn-bags/ 128 32
bash batch_extract.sh ukb-val/ ukb-val-bags/ 128 32

Extracted patches should now be in ukb-trn-bags/ and ukb-val-bags/. As these are stored in the JPG format, you can inspect them with your favorite image viewer.

2. Prepare data-loading scripts

To keep a desirable level of abstraction and enable large-scale learning, this code requires the user to provide their own routines for generating triplets. An example can be found at utils/tripletgen.lua. The strings "--FOLDER--", "--NCHANNELS--" and "--PROBABILITY--" need to be replaced with appropriate values, depending on whether training or validation data is being loaded. The following shell commands will do this for you (escape each slash in the folder paths as \/, as required by sed).

cp utils/tripletgen.lua trn-tripletgen.lua
sed -i -e 's/--FOLDER--/"ukb-trn-bags"/g' trn-tripletgen.lua
sed -i -e 's/--NCHANNELS--/3/g' trn-tripletgen.lua
sed -i -e 's/--PROBABILITY--/0.33/g' trn-tripletgen.lua

cp utils/tripletgen.lua val-tripletgen.lua
sed -i -e 's/--FOLDER--/"ukb-val-bags"/g' val-tripletgen.lua
sed -i -e 's/--NCHANNELS--/3/g' val-tripletgen.lua
sed -i -e 's/--PROBABILITY--/1.0/g' val-tripletgen.lua

After executing them, you should find two Lua files, trn-tripletgen.lua and val-tripletgen.lua, next to wlrn.lua.

3. Specify the descriptor-extractor structure

The model is specified with a Lua script which returns a function for constructing the descriptor extraction network. See the default model in models/3x32x32_to_64.lua for an example.
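
The expected convention, as used by the deployment snippets in this document, is a script that returns a constructor function: called with an optional flat parameter vector, it builds the network and, if parameters were given, copies them in. A minimal sketch with a hypothetical toy architecture (the real one lives in models/3x32x32_to_64.lua):

require 'nn'

return function(params)
  -- toy architecture for illustration only
  local net = nn.Sequential()
  net:add(nn.MulConstant(1.0/255)) -- assumed pixel scaling; check the real model
  net:add(nn.View(3*32*32))
  net:add(nn.Linear(3*32*32, 64))
  net:add(nn.Normalize(2))
  -- optionally initialize the weights from a flat float vector
  if params then
    net:getParameters():copy(params)
  end
  return net
end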

Of course, you can try different architectures. However, to learn their parameters, some parts of wlrn.lua might need additional tweaking (such as learning rates).

4. Start the learning script

Finally, learn the parameters of the network by running the training script:

th wlrn.lua models/3x32x32_to_64.lua trn-tripletgen.lua -v val-tripletgen.lua -w params.t7

The training should finish in about a day on a GeForce GTX 970 with cuDNN. The file params.t7 contains the learned parameters of the descriptor extractor specified in models/3x32x32_to_64.lua. Use the following code to deploy them:

-- build the network with freshly initialized weights
n = dofile('models/3x32x32_to_64.lua')():float()
-- overwrite them with the learned parameters
p = n:getParameters()
p:copy(torch.load('params.t7'))
-- save the complete net (it can later be restored with torch.load('net.t7'))
torch.save('net.t7', n)

Contact

For any additional information contact me at [email protected].

Copyright (c) 2016, Nenad Markus. All rights reserved.
