GithubHelp home page GithubHelp logo

xjtueducation / dsacline Goto Github PK

View Code? Open in Web Editor NEW

This project forked from vislearn/dsacline

0.0 2.0 0.0 239 KB

Differentiable RANSAC: Learning Robust Line Fitting

License: BSD 3-Clause "New" or "Revised" License

Python 100.00%

dsacline's Introduction

Differentiable RANSAC: Learning Robust Line Fitting

Introduction

This code illustrates the principles of differentiable RANSAC (DSAC) on a simple toy problem of fitting lines to noisy, synthetic images.

Input and desired output.

Left: Input image. Right: Ground truth line.

We solve this task by training a CNN which predicts a set of 2D points within the image. We fit our desired line to these points using RANSAC.

DSAC line fitting.

Left: Input image. Center: Points predicted by a CNN. Right: Line (blue) fitted to the predictions.

Ideally, the CNN would place all its point predictions on the image line segment. But because RANSAC is robust to outlier points, the CNN may choose to allow some erroneous predictions in favor of overall accuracy.

We train the CNN end-to-end from scratch to minimize the deviation between the (robustly) fitted line and the ground truth line.

Running the Code

Just execute python main.py to start a training run with the standard settings. Running python main.py -h will list all parameter options for playing around.

The code generates training data on the fly, and trains two CNNs in parallel. The first CNN predicts a set of 2D points to which the output line is fitted using DSAC. The second CNN is a baseline where the line parameters are predicted directly, i.e. without DSAC.

In a specified interval during training, both CNNs are tested on a fixed validation set, and the visualization of the predictions is stored in an image such as the following:

Training output.

Left: Validation inputs. Center: DSAC estimates with dots marking the CNN prediction, blue the fitted line and green the ground truth line. Green borders mark accurate line predictions, red boxes mark inaccurate line predictions (there is a threshold parameter). Right: Predictions of the baseline CNN (direct prediction of line parameters).

Dependencies

This code requires the following packages, and was tested with the package version in brackets.

pytorch (0.5.0), torchvision (0.2.1), scikit-image (0.14.0)

Training Speed

Depending on your system specification, one training iteration (with the standard batch size of 32) can take more than one second. This might seem excessive for a simple toy problem. Note that this code is designed for educative clarity rather than speed. The whole DSAC portion of training runs in native Python on a single CPU core, and backpropagation relies soley on standard PyTorch autograd. In any production setting, one would write a C++/CUDA extension encapsulating DSAC for a huge runtime boost. See for example our camera localization pipelines which utilize DSAC here and here.

How Does It Work?

Vanilla RANSAC works by creating a set of model hypotheses (line hypotheses in our case), scoring them e.g. by inlier counting, and selecting the best one.

DSAC is based on the idea of making hypothesis selection a probabilistic action. The probability of selecting a hypothesis increases with its score (e.g. inlier count). Training the CNN aims at minimizing the expected loss of the selected hypothesis.

Training output.

More details and a formal description can be found in the papers referenced at the end of this document.

In a nutshell, the training process works like this:

  1. CNN predicts 2D points
  2. sample line hypotheses by choosing random pairs of points
  3. score hypotheses by soft inlier counting, and calculate selection probabilities
  4. refine hypotheses by re-fitting them to their soft inliers
  5. calculate expected loss of refined hypotheses w.r.t. selection probabilities
  6. backprob, update CNN, repeat

Loss Function

For this toy problem, we are interested in observing visually nicely aligned lines rather then the nominal error in line parameters. We thus measure the maximum distance between the predicted line and ground truth within the image, and aim at minimizing this distance as our loss function.

Loss function.

Red arrows mark the error between ground truth line (green) and estimated line (blue) that we try to minimize.

Code Structure

main.py Main script that handles the training loop.

dsac.py Encapsulates robust, differentiable line fitting with DSAC (sampling hypotheses, scoring, refinement, expected loss).

line_dataset.py Generates random, noisy input images with associated ground truth parameters. Also includes functions for visualizing predictions.

line_loss.py Loss function used to compare predicted and ground truth lines.

line_nn.py Definition of the CNN architecture which supports prediction of 2D points or direct regression of line parameters.

Publications

The following paper introduced DSAC for camera localization (paper link).

@inproceedings{brachmann2017dsac,
  title={{DSAC}-{Differentiable RANSAC} for Camera Localization},
  author={Brachmann, Eric and Krull, Alexander and Nowozin, Sebastian and Shotton, Jamie and Michel, Frank and Gumhold, Stefan and Rother, Carsten},
  booktitle={CVPR},
  year={2017}
}

This code uses a soft inlier count instead of a learned scoring function, as suggested in the following paper (paper link).

@inproceedings{brachmann2018lessmore,
  title={Learning Less is More-{6D} Camera Localization via {3D} Surface Regression},
  author={Brachmann, Eric and Rother, Carsten},
  booktitle={CVPR},
  year={2018}
}

Please cite one of these papers if you use DSAC or parts of this code in your own work.

dsacline's People

Contributors

ebrach avatar

Watchers

 avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.