
Deep Generative Models for Distribution-Preserving Lossy Compression

PyTorch implementation of Deep Generative Models for Distribution-Preserving Lossy Compression (NIPS 2018), a framework that unifies generative models and lossy compression. The resulting models behave like generative models at zero bitrate, almost perfectly reconstruct the training data at sufficiently high bitrates, and smoothly interpolate between generation and reconstruction at intermediate bitrates (cf. the figure above; the numbers indicate the rate in bits per pixel).
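
Informally, the distribution-preserving lossy compression problem underlying the framework can be stated as follows (a rough sketch in the spirit of the paper, with encoder E, decoder D, data distribution P_X, and distortion measure d; not a verbatim restatement):

    \min_{E, D} \; \mathbb{E}\bigl[ d\bigl(X, D(E(X))\bigr) \bigr] \quad \text{subject to} \quad D(E(X)) \sim P_X,

where the rate of the code E(X) is constrained. At zero rate the constraint forces D to behave like a generator of P_X; at high rates the problem reduces to ordinary distortion minimization.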

Prerequisites

  • Python 3 (tested with Python 3.6.4)
  • PyTorch (tested with version 0.4.1)
  • tensorboardX

Training

The training procedure consists of two steps:

  1. Learn a generative model of the data.
  2. Learn a rate-constrained encoder and a stochastic mapping into the latent space of the fixed generative model by minimizing distortion.

The train.py script can be used for both of these steps.
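
Schematically, with generator G, prior P_Z on the latent space, encoder E, and stochastic mapping B, the two steps amount to (an informal summary, not a verbatim restatement of the paper's objectives):

    Step 1:             \min_{G} \; W\bigl(P_X, P_{G(Z)}\bigr), \qquad Z \sim P_Z
    Step 2 (G fixed):   \min_{E, B} \; \mathbb{E}\bigl[ d\bigl(X, G(B(E(X)))\bigr) \bigr] \quad \text{s.t.} \quad B(E(X)) \sim P_Z

where W denotes the Wasserstein distance; in practice the constraint in step 2 is relaxed to the MMD penalty controlled by --lbd below.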

To learn the generative model we consider Wasserstein GAN with gradient penalty (WGAN-GP), Wasserstein Autoencoder (WAE), and a combination of the two termed Wasserstein++. The following examples show how to train these models on the CelebA data set, as in the experiments in the paper (see train.py for a description of the flags).

WGAN-GP:

python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
    --sigmasqz 1.0 --lr_eg 0.0001 --lr_di 0.0001 --beta1 0.5 --beta2 0.9 --niter 165 --check_every 100 \
    --workers 6 --outf /path/to/results/ --batchSize 64 --test_every 100 --addsamples 10000 --manualSeed 321 \
    --wganloss
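
For context, the gradient penalty that gives WGAN-GP its name can be sketched in PyTorch as follows (illustrative only, not the code used in train.py; the critic signature, tensor shapes, and the penalty weight of 10 are assumptions):

import torch

def gradient_penalty(critic, real, fake, gp_weight=10.0):
    """WGAN-GP regularizer: penalizes deviations of the critic's gradient
    norm from 1 on random interpolates of real and generated samples."""
    # per-sample mixing weights, broadcast over (N, C, H, W) image batches
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1.0 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores), create_graph=True)
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1.0) ** 2).mean()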

WAE:

python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
    --sigmasqz 1.0 --lr_eg 0.001 --niter 55 --decay_steps 30 50 --decay_gamma 0.4 --check_every 100 \
    --workers 8 --recloss --mmd --bnz --outf /path/to/results/ --lbd 100 --batchSize 256 --detenc --useenc \
    --test_every 20 --addsamples 10000 --manualSeed 321
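
The --mmd flag corresponds to the MMD regularizer of the WAE, which matches the distribution of encoded latents to the prior. A minimal sketch of such an estimator is shown below (illustrative only; the inverse multiquadratic kernel, its scale c, and equal batch sizes for both sample sets are assumptions, not read off train.py):

import torch

def imq_kernel(x, y, c=1.0):
    """Inverse multiquadratic kernel k(x, y) = c / (c + ||x - y||^2)."""
    sq_dists = (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(-1)
    return c / (c + sq_dists)

def mmd_penalty(z_enc, z_prior, c=1.0):
    """MMD^2 estimate between encoded latents and prior samples (assumes
    both batches have the same size; diagonal terms are excluded from the
    within-sample averages)."""
    n = z_enc.size(0)
    k_xx = imq_kernel(z_enc, z_enc, c)
    k_yy = imq_kernel(z_prior, z_prior, c)
    k_xy = imq_kernel(z_enc, z_prior, c)
    off_diag = lambda k: (k.sum() - k.diag().sum()) / (n * (n - 1))
    return off_diag(k_xx) + off_diag(k_yy) - 2.0 * k_xy.mean()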

Wasserstein++:

python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
    --sigmasqz 1.0 --lr_eg 0.0003 --niter 165 --decay_steps 100 140 --decay_gamma 0.4 --check_every 100 \
    --workers 6 --recloss --mmd --bnz --outf /path/to/results/ --lbd 100 --batchSize 256 --detenc --useenc \
    --test_every 20 --addsamples 10000 --manualSeed 321 --wganloss --useencdist --lbd_di 0.000025 --intencprior

To learn the rate-constrained encoder and the stochastic mapping, run the following (parameters are again for the experiment on the CelebA data set):

python train.py --dataset celeba --dataroot /path/to/traindata/ --testroot /path/to/testdata/ --cuda --nz 128 \
    --sigmasqz 1.0 --lr_eg 0.001 --niter 55 --decay_steps 30 50 --decay_gamma 0.4 --check_every 100 \
    --workers 6 --recloss --mmd --bnz --batchSize 256 --useenc --comp --freezedec --test_every 100 \
    --addsamples 10000 --manualSeed 321 --outf /path/to/results/ --netG /path/to/trained/generator \
    --nresenc 2 --lbd 300 --ncenc 8

Here, --ncenc determines the number of channels at the encoder output (and hence the bitrate), and --lbd sets the regularization strength of the MMD penalty on the latent space (which has to be adapted as a function of the bitrate).
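
To get a feel for how --ncenc maps to a bitrate, the following back-of-the-envelope calculation may help; the encoder output resolution and the number of quantization levels are purely hypothetical placeholders (they depend on the architecture and are not specified here):

import math

def bits_per_pixel(ncenc, enc_h, enc_w, levels, img_h, img_w):
    """Rate of a quantized latent with ncenc channels of size enc_h x enc_w,
    each entry taking one of `levels` values, relative to an img_h x img_w image."""
    return ncenc * enc_h * enc_w * math.log2(levels) / (img_h * img_w)

# Example with assumed values: 8 channels on an 8x8 grid, 5 quantization
# levels, 64x64 crops -> roughly 0.29 bpp.
print(bits_per_pixel(ncenc=8, enc_h=8, enc_w=8, levels=5, img_h=64, img_w=64))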

In the paper we also consider the LSUN bedrooms data set. We provide the flag --lsun_custom_split, which splits off 10k samples from the LSUN training set (the LSUN test set is too small to compute the FID score used to assess sample quality). Otherwise, training on the LSUN data set proceeds as outlined above (with different parameters).

Citation

If you use this code for your research, please cite this paper:

@inproceedings{tschannen2018deep,
    Author = {Tschannen, Michael and Agustsson, Eirikur and Lucic, Mario},
    Title = {Deep Generative Models for Distribution-Preserving Lossy Compression},
    Booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    Year = {2018}}


