agadetsky / pytorch-pl-variance-reduction

[AAAI 2020 Oral] Low-variance Black-box Gradient Estimates for the Plackett-Luce Distribution

License: Apache License 2.0

Python 14.72% Jupyter Notebook 85.28%
aaai2020 pytorch variance-reduction plackett-luce relax rebar structure-learning directed-acyclic-graph control-variates neurips-2019

pytorch-pl-variance-reduction's Introduction

Low-variance Black-box Gradient Estimates for the Plackett-Luce Distribution

This repo contains code for our paper Low-variance Black-box Gradient Estimates for the Plackett-Luce Distribution.

Abstract

Learning models with discrete latent variables using stochastic gradient descent remains a challenge due to the high variance of gradients. Modern variance reduction techniques mostly consider categorical distributions and have limited applicability when the number of possible outcomes becomes large. In this work, we consider models with latent permutations and propose control variates for the Plackett-Luce distribution. In particular, the control variates allow us to optimize black-box functions over permutations using stochastic gradient descent. To illustrate the approach, we consider a variety of causal structure learning tasks for continuous and discrete data. We show that for differentiable functions, our method outperforms competitive relaxation-based optimization methods and is also applicable to non-differentiable score functions.
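
As background on how such models are sampled: a permutation from the Plackett-Luce distribution with scores theta can be drawn by perturbing the log-scores with i.i.d. Gumbel noise and sorting. A minimal PyTorch sketch of this standard Gumbel trick (an illustration, not code from this repo):

import torch

def sample_plackett_luce(log_theta):
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)
    gumbel = -torch.log(-torch.log(torch.rand_like(log_theta)))
    # a descending argsort of the perturbed log-scores is distributed
    # according to Plackett-Luce with scores exp(log_theta)
    return torch.argsort(log_theta + gumbel, descending=True)

log_theta = torch.randn(5)  # unnormalized log-scores for 5 items
print(sample_plackett_luce(log_theta))  # e.g. tensor([2, 0, 4, 1, 3])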

Citation

@inproceedings{gadetsky2020lowvariance,
  author    = {Artyom Gadetsky and
               Kirill Struminsky and
               Christopher Robinson and
               Novi Quadrianto and
               Dmitry P. Vetrov},
  title     = {Low-Variance Black-Box Gradient Estimates for the Plackett-Luce Distribution},
  booktitle = {The Thirty-Fourth {AAAI} Conference on Artificial Intelligence, {AAAI}
               2020, New York, NY, USA, February 7-12, 2020},
  pages     = {10126--10135},
  publisher = {{AAAI} Press},
  year      = {2020},
  url       = {https://aaai.org/ojs/index.php/AAAI/article/view/6572}
}

Toy Experiment

Prepare the environment (you may need to change the cudatoolkit version in toy_env.yml, or use the CPU-only build of PyTorch):

conda env create -f toy_env.yml
conda activate toy_env

Run toy_experiment.py with one of the estimators (a sketch of the reinforce estimator follows the commands):

python toy_experiment.py --estimator exact
python toy_experiment.py --estimator reinforce
python toy_experiment.py --estimator rebar
python toy_experiment.py --estimator relax
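
For reference, the reinforce option corresponds to the generic score-function estimator. A self-contained sketch of REINFORCE for the Plackett-Luce distribution (standard definitions only, not the repo's implementation; f below is a toy stand-in for the black-box score):

import torch

def pl_log_prob(log_theta, perm):
    # log p(perm) = sum_k [ s_k - logsumexp(s_k, ..., s_{n-1}) ],
    # where s = log_theta[perm] lists the scores in sampled order
    s = log_theta[perm]
    denom = torch.flip(torch.logcumsumexp(torch.flip(s, [0]), 0), [0])
    return (s - denom).sum()

def reinforce(log_theta, f, n_samples=1000):
    # REINFORCE: grad E[f(pi)] = E[ f(pi) * grad log p(pi) ]
    total = torch.zeros_like(log_theta)
    for _ in range(n_samples):
        gumbel = -torch.log(-torch.log(torch.rand_like(log_theta)))
        perm = torch.argsort(log_theta.detach() + gumbel, descending=True)
        g, = torch.autograd.grad(pl_log_prob(log_theta, perm), log_theta)
        total += f(perm) * g
    return total / n_samples

log_theta = torch.randn(4, requires_grad=True)
f = lambda perm: (perm == torch.arange(4)).float().sum()  # toy black-box score
print(reinforce(log_theta, f))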

Plot the figure using plot_toy.ipynb.

[Figure: toy experiment results produced by plot_toy.ipynb]

Results were obtained on CPU. Quantitative results on CUDA may vary slightly due to randomness in CUDA kernels, but the qualitative results remain the same.
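
To make runs as comparable as possible, you can pin the RNGs with the standard PyTorch knobs (generic advice, not part of this repo):

import torch

torch.manual_seed(0)  # seeds the CPU RNG and all CUDA devices
torch.use_deterministic_algorithms(True)  # raise an error on nondeterministic CUDA ops (PyTorch >= 1.8)
torch.backends.cudnn.benchmark = False  # disable autotuning, a source of run-to-run variance

Even with these settings, bitwise agreement between CPU and GPU runs is not guaranteed, since floating-point reductions can be ordered differently on different devices.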


pytorch-pl-variance-reduction's Issues

License

Can you add license information?

What if log_theta is parametrized by another network?

I am confused about the comments in estimators.py:
"in all computations (including those made by the user later), we don't
"want to backpropagate past "logits" into the model. We make a detached
"copy of logits and rebuild the graph from the detached copy to z

logits = logits.detach().requires_grad_(True)

If the logits are parametrized by a network theta, and you detach here, then the network won't get updated, right?
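
For context (an illustration of the usual pattern in RELAX-style estimators, not a maintainer reply): the detach is local to the estimator. The estimator differentiates its surrogate with respect to the detached copy and then reattaches the resulting gradient to the original logits, so gradients still flow into the upstream network:

import torch
import torch.nn as nn

net = nn.Linear(10, 5)  # hypothetical network producing the logits
logits = net(torch.randn(3, 10))

# build the estimator's internal graph from a detached copy
logits_d = logits.detach().requires_grad_(True)
surrogate = (logits_d ** 2).sum()  # stand-in for the estimator's surrogate loss
d_logits, = torch.autograd.grad(surrogate, logits_d)

# reattach: push the estimated gradient through the original logits
# into the network parameters, so the network does get updated
logits.backward(gradient=d_logits)
print(net.weight.grad is not None)  # True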

Code for DAG experiment

Hello,

Thanks a lot for your code; it is really interesting work. Will the code for the DAG experiments be released? It would be very nice to see the method applied to real problems like the causal inference tasks described in the paper.

I tried to reproduce the results for 10, 20, and 50 nodes manually, but unfortunately failed to do so. If the code is not going to be released, could you kindly share some details about the implementation of the accelerated proximal method (hyperparameters, optimizer used, number of iterations, learning rate, etc.)?

Thanks a lot for your help.
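
For readers landing here with the same question: the details requested above are not in this repo. As a generic reference only, an accelerated proximal gradient (FISTA-style) loop with an L1 proximal step looks like the sketch below; the paper's actual objective and hyperparameters are exactly what the issue asks for and are not reproduced here.

import torch

def soft_threshold(x, lam):
    # proximal operator of lam * ||x||_1
    return torch.sign(x) * torch.clamp(x.abs() - lam, min=0.0)

def fista(grad_f, x0, lr=1e-3, lam=1e-3, n_iters=1000):
    # accelerated proximal gradient (FISTA) for min_x f(x) + lam * ||x||_1
    x, y, t = x0.clone(), x0.clone(), 1.0
    for _ in range(n_iters):
        x_next = soft_threshold(y - lr * grad_f(y), lr * lam)
        t_next = (1.0 + (1.0 + 4.0 * t * t) ** 0.5) / 2.0
        y = x_next + ((t - 1.0) / t_next) * (x_next - x)
        x, t = x_next, t_next
    return x

# example: sparse least squares, f(x) = 0.5 * ||A x - b||^2
A, b = torch.randn(20, 10), torch.randn(20)
x_hat = fista(lambda x: A.T @ (A @ x - b), torch.zeros(10))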
