akanimax / variational_discriminator_bottleneck

Implementation (with some experimentation) of the paper titled "VARIATIONAL DISCRIMINATOR BOTTLENECK: IMPROVING IMITATION LEARNING, INVERSE RL, AND GANS BY CONSTRAINING INFORMATION FLOW" (arxiv -> https://arxiv.org/pdf/1810.00821.pdf)

License: MIT License

Python 100.00%

variational_discriminator_bottleneck's Introduction

Variational_Discriminator_Bottleneck

The implementation uses the PyTorch framework.

VGAN architecture:

[Figure: detailed VGAN architecture]


The core concept proposed by the paper is to enforce an information bottleneck between the input images and the Discriminator’s internal representation of them.
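In the paper's notation (summarized here as a sketch), E(z|x) is the Encoder's distribution over internal codes, r(z) = N(0, I) is a fixed prior, p* is the data distribution, G is the generator's distribution, and p̃ is their equal mixture. The bottleneck is imposed as a constraint on the expected KL divergence and then relaxed with a Lagrange multiplier β ≥ 0. Here L_adv stands for the usual adversarial loss computed on sampled codes; in the VGAN-GP experiment below it is replaced by the WGAN-GP loss.

```latex
\min_{D,E}\; \mathcal{L}_{\mathrm{adv}}(D, E)
\quad \text{s.t.} \quad
\mathbb{E}_{x \sim \tilde{p}(x)}\!\left[ \mathrm{KL}\!\left( E(z \mid x) \,\|\, r(z) \right) \right] \le I_c,
\qquad r(z) = \mathcal{N}(0, I), \quad \tilde{p} = \tfrac{1}{2} p^{*} + \tfrac{1}{2} G

\min_{D,E}\; \max_{\beta \ge 0}\; \mathcal{L}_{\mathrm{adv}}(D, E)
+ \beta \left( \mathbb{E}_{x \sim \tilde{p}(x)}\!\left[ \mathrm{KL}\!\left( E(z \mid x) \,\|\, r(z) \right) \right] - I_c \right)
```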

As shown in the diagram, the Discriminator is now divided into two parts: an Encoder and the actual Discriminator. The Generator is unchanged. The Encoder is modelled as a ResNet similar in architecture to the Generator, while the Discriminator is a simple linear classifier. Note that the Encoder does not output the internal codes of the images directly; similar to a VAE’s encoder, it outputs the means and standard deviations of the distributions from which the codes are sampled and fed to the Discriminator.
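A minimal NumPy sketch of that forward path (the repository itself uses PyTorch, where the same sampling step keeps gradients flowing back to the Encoder). The function names, the exp parameterization of the standard deviations, and all shapes are assumptions for illustration; in the real model the features would come from the ResNet Encoder.

```python
import numpy as np


def encoder_head(features, w_mu, w_log_sigma):
    """Hypothetical last Encoder layer: maps per-image features to the mean and
    standard deviation of a diagonal Gaussian over the internal code z."""
    mus = features @ w_mu
    sigmas = np.exp(features @ w_log_sigma)  # exp keeps every std strictly positive
    return mus, sigmas


def sample_codes(mus, sigmas, rng):
    """Reparameterization: z = mu + sigma * eps with eps ~ N(0, I), so in an
    autograd framework the sample stays differentiable w.r.t. mu and sigma."""
    eps = rng.standard_normal(mus.shape)
    return mus + sigmas * eps


def linear_discriminator(z, w, b):
    """The Discriminator proper: a linear classifier giving one logit per code."""
    return z @ w + b


rng = np.random.default_rng(0)
batch, feat_dim, code_dim = 4, 16, 8
features = rng.standard_normal((batch, feat_dim))  # stand-in for ResNet features
w_mu = 0.1 * rng.standard_normal((feat_dim, code_dim))
w_log_sigma = 0.1 * rng.standard_normal((feat_dim, code_dim))
w_d = rng.standard_normal(code_dim)

mus, sigmas = encoder_head(features, w_mu, w_log_sigma)
z = sample_codes(mus, sigmas, rng)
logits = linear_discriminator(z, w_d, 0.0)
print(z.shape, logits.shape)  # (4, 8) (4,)
```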

CelebA 128x128 Experiment

I trained VGAN-GP (VGAN with the normal GAN loss replaced by the WGAN-GP loss) on the CelebA dataset; the results are shown below.

[Figure: generated samples]


I used I_c = 0.2, as described in the paper, and the architectures for G and D also follow the paper. The authors trained the model for 300K iterations, but the results displayed here are at 62K iterations, which took me 22.5 hours to train. I will be training the model further, but I would really like readers and enthusiasts to take this forward, since I have made the code open-source.

Loss plot:

[Figure: loss plot]


Running the Code

Running the training is simple: just run the train.py script in the source/ directory. The test/ directory contains the unit tests, in case you would like to change anything about the implementation. Refer to the following parameters to tweak the training for your own use:

-h, --help            show this help message and exit
--generator_file GENERATOR_FILE
                    pretrained weights file for generator
--gen_optim_file GEN_OPTIM_FILE
                    previously saved state of Generator Optimizer
--discriminator_file DISCRIMINATOR_FILE
                    pretrained weights file for discriminator
--dis_optim_file DIS_OPTIM_FILE
                    previously saved state of Discriminator Optimizer
--images_dir IMAGES_DIR
                    path for the images directory
--folder_distributed_dataset FOLDER_DISTRIBUTED_DATASET
                    path for the images directory
--sample_dir SAMPLE_DIR
                    path for the generated samples directory
--model_dir MODEL_DIR
                    path for saved models directory
--loss_function LOSS_FUNCTION
                    loss function to be used: 'hinge', 'relativistic-
                    hinge', 'standard-gan', 'standard-gan_with-sigmoid',
                    'wgan-gp', 'lsgan'
--size SIZE           Size of the generated image (must be a power of 2 and
                    >= 4)
--latent_distrib LATENT_DISTRIB
                    Type of latent distribution to be used 'uniform' or
                    'gaussian'
--latent_size LATENT_SIZE
                    latent size for the generator
--final_channels FINAL_CHANNELS
                    starting number of channels in the networks
--max_channels MAX_CHANNELS
                    maximum number of channels in the network
--init_beta INIT_BETA
                    initial value of beta
--i_c I_C             value of information bottleneck
--batch_size BATCH_SIZE
                    batch_size for training
--start START         starting epoch number
--num_epochs NUM_EPOCHS
                    number of epochs for training
--feedback_factor FEEDBACK_FACTOR
                    number of logs to generate per epoch
--num_samples NUM_SAMPLES
                    number of samples to generate for creating the grid;
                    should preferably be a square number
--checkpoint_factor CHECKPOINT_FACTOR
                    save model per n epochs
--g_lr G_LR           learning rate for generator
--d_lr D_LR           learning rate for discriminator
--data_percentage DATA_PERCENTAGE
                    percentage of data to use
--num_workers NUM_WORKERS
                    number of parallel workers for reading files
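A small sketch of the two value constraints stated in the help text (--size must be a power of 2 and >= 4; --num_samples should preferably be a square number), which can be handy for sanity-checking a configuration before a long run. These helper names are hypothetical and are not part of the repository's train.py.

```python
import math


def is_valid_size(size):
    """--size must be a power of 2 and >= 4; a power of 2 has exactly one set bit."""
    return size >= 4 and (size & (size - 1)) == 0


def grid_side(num_samples):
    """Side of a square sample grid, plus whether num_samples fills it exactly."""
    side = math.isqrt(num_samples)
    return side, side * side == num_samples


print(is_valid_size(128), is_valid_size(96), is_valid_size(2))  # True False False
print(grid_side(64), grid_side(50))  # (8, True) (7, False)
```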

Please note that all the default values are tuned for the CelebA 128x128 experiment. Please refer to the paper for the CIFAR-10 and CelebA-HQ experiments.

Trained weights for generating cool faces / resuming the training :)

Please refer to the shared drive for the saved weights for this model in PyTorch format.

Other links

medium blog -> https://medium.com/@animeshsk3/v-gan-variational-discriminator-bottleneck-an-unfair-fight-between-generator-and-discriminator-972563532dcc
Generated samples video -> https://www.youtube.com/watch?v=-0lBw9z8Ds0
My slack group -> https://join.slack.com/t/amlrldl/shared_invite/enQtNDcyMTIxODg3NjIzLTA3MTlmMDg0YmExYjY5OTgyZTg4MTg5ZGE1YzRlYjljZmM4MzI0MTg1OTcxOTc5NDQ4ZTcwMGVkZjBjZmU5ZWM

Thanks

Please feel free to open Issues / PRs here

Cheers 🍻!
@akanimax :)

variational_discriminator_bottleneck's People

Contributors

akanimax

variational_discriminator_bottleneck's Issues

Don't quite understand the _bottleneck_loss function in vdb.Losses.py

Hi, after reading your codes, I have the following questions:

  1. Is mutual information the same as KL divergence and conditional entropy?

  2. I googled "KL divergence python implementation", but didn't find anything similar to your code. The implementations I found look like the following:
    np.sum(np.where(p != 0, p * np.log(p / q), 0))
    So could you explain why you calc KL divergence using:
    kl_divergence = (0.5 * th.sum((mus ** 2) + (sigmas ** 2) - th.log((sigmas ** 2) + alpha) - 1, dim=1))

or could you provide some relevant web links/blogs/videos explaining this?

Thanks in advance. :) BTW, your code is great, which is exactly what I am looking for.
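The `np.where` formula above is the KL divergence for discrete distributions given as probability vectors. In this model the Encoder's distribution is a diagonal Gaussian N(μ, σ²) and the prior is N(0, I), and for two Gaussians the KL integral has a closed form: per code dimension it is ½(μ² + σ² − log σ² − 1), which is exactly the expression in the code (α presumably being a small constant that guards against log(0)). On the first question: mutual information is not the same quantity, but for any prior r(z) it is bounded above by this KL, I(X;Z) ≤ E_x[KL(E(z|x) ‖ r(z))], and that upper bound is what the paper constrains. The NumPy sketch below (function names are illustrative, not from the repository) checks the closed form against a Monte Carlo estimate.

```python
import numpy as np


def kl_closed_form(mus, sigmas, alpha=1e-8):
    """KL(N(mu, diag(sigma^2)) || N(0, I)) per sample, summed over code dimensions.
    Same expression as the PyTorch line above; alpha only guards log(0)."""
    return 0.5 * np.sum(mus ** 2 + sigmas ** 2 - np.log(sigmas ** 2 + alpha) - 1.0, axis=1)


def kl_monte_carlo(mus, sigmas, num_draws, rng):
    """Estimate KL = E_{z~q}[log q(z) - log p(z)] by sampling z ~ q."""
    eps = rng.standard_normal((num_draws,) + mus.shape)
    z = mus + sigmas * eps
    # Log-densities without the -0.5*log(2*pi) term, which cancels in the difference.
    log_q = np.sum(-0.5 * eps ** 2 - np.log(sigmas), axis=-1)
    log_p = np.sum(-0.5 * z ** 2, axis=-1)
    return np.mean(log_q - log_p, axis=0)


mus = np.array([[0.5, -1.0], [0.0, 0.0]])
sigmas = np.array([[0.8, 1.2], [1.0, 1.0]])
rng = np.random.default_rng(0)
print(kl_closed_form(mus, sigmas))                # approx. 0.706 and 0.0
print(kl_monte_carlo(mus, sigmas, 200_000, rng))  # close to the line above
```

The second row is the prior itself, so its KL is (numerically) zero.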

Some problem

Hello, there is a problem with my generated results: only some color blocks are generated. Do you know anything about this problem?

[Sample images: gen_3_95, gen_11_139, gen_16_128]

It seems to be heading in a good direction, but it is very slow.

Is the KL divergence the same as in a Variational Autoencoder?

Hi,
Your implementation is perfect and very informative. In the paper, they clearly mention how they simplify the mutual information constraint into a KL divergence w.r.t. a Gaussian distribution. So here we take the KL divergence between two distributions rather than working with two random variables. I think this is the same technique used in VAEs, called variational inference. Is my understanding correct?

Actually, the difference is that in a VAE we do not explicitly talk about an information constraint and directly optimize the KL divergence loss. But with I_c we have to use a Lagrange multiplier and update beta adaptively.

Is my understanding correct?

How sensitive is training to the initial value of the beta parameter?

Hi,
Here, you have used a max operation to adaptively update the beta parameter, as given in the paper. So first we update the discriminator using the WGAN-GP loss plus the bottleneck loss multiplied by beta. After that, we update beta with the max operation. But theoretically, when we update the parameters with dual gradients, we need to first optimize the discriminator loss and then use that optimized function to calculate the gradient w.r.t. beta and update it. Why do they use a max operation here? Is it to ensure monotonic improvement? Or is it to keep the effect of the bottleneck loss on the WGAN-GP loss always positive, keeping an upper bound?
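The paper's update for β is a projected dual gradient-ascent step, β ← max(0, β + α_β (E[KL] − I_c)). One standard reading of the max: β is the Lagrange multiplier of the inequality constraint E[KL] ≤ I_c, and multipliers of inequality constraints must be non-negative, so the max can be read as projecting β back onto β ≥ 0. Taking a single β step per discriminator step, instead of fully optimizing the discriminator first, is a common practical approximation of exact dual ascent. A minimal sketch, with a hypothetical function name and made-up KL values:

```python
def update_beta(beta, mean_kl, i_c, step_size):
    """One dual gradient-ascent step on the Lagrange multiplier beta.

    The gradient of beta * (mean_kl - i_c) w.r.t. beta is (mean_kl - i_c): beta
    grows while the KL budget is exceeded and shrinks once it is met.
    max(0, .) projects the result back onto beta >= 0.
    """
    return max(0.0, beta + step_size * (mean_kl - i_c))


beta = 0.1  # hypothetical starting value, cf. --init_beta
for mean_kl in [0.9, 0.6, 0.3, 0.1, 0.05]:  # made-up KL values, with i_c = 0.2
    beta = update_beta(beta, mean_kl, i_c=0.2, step_size=0.1)
    print(round(beta, 4))
```

Since β only moves by step_size times the constraint violation per step, its initial value mainly matters for how quickly it reaches a useful magnitude.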

Resizing support/progressive growing support

Hiya. I've been experimenting with generating anime faces with GANs for years now, and I've been trying out your GAN implementation to compare it with another (more complex) VGAN implementation by nsheppard (currently called PokeGAN), using ~137k anime faces extracted from my Danbooru2017 dataset. The results so far are as good as anything I've gotten aside from Nvidia's ProGAN implementation; here are samples from a few minutes ago, after about 2 days of training:

[Image: Epoch 37 sample of anime faces with akanimax's VGAN]

For speed, I am using 128px. At some point when it's converged, it'd be nice to switch to 256px and then 512px without restarting from scratch. Some support for adding additional layers would be great; the exact 'blending' ProGAN does might be a bit hard to implement, but almost as good would be freezing the original layers for a while and training only the new ones; or even just slapping on more layers would be better than restarting from scratch (given how drastically minibatch sizes shrink at higher resolutions).
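For reference, ProGAN's blending is a fade-in: while a new higher-resolution block is added, the output is a mix of the upsampled output of the old path and the output of the new block, with the weight α ramped from 0 to 1. This is not implemented in this repository; the NumPy sketch below only illustrates the arithmetic on raw image tensors, with hypothetical names. In ProGAN itself both branches are toRGB outputs of the generator.

```python
import numpy as np


def upsample_nearest(images):
    """2x nearest-neighbour upsampling of an (N, C, H, W) batch."""
    return images.repeat(2, axis=2).repeat(2, axis=3)


def faded_output(old_rgb, new_rgb, alpha):
    """ProGAN-style fade-in: blend the upsampled output of the existing lower-res
    path with the output of the newly added higher-res block.
    alpha is ramped from 0 (old network only) to 1 (new block fully active)."""
    return (1.0 - alpha) * upsample_nearest(old_rgb) + alpha * new_rgb


rng = np.random.default_rng(0)
old_rgb = rng.standard_normal((2, 3, 128, 128))  # stand-in for the 128px output
new_rgb = rng.standard_normal((2, 3, 256, 256))  # stand-in for the new 256px block
blended = faded_output(old_rgb, new_rgb, alpha=0.25)
print(blended.shape)  # (2, 3, 256, 256)
```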


On a side note, there's one trick nsheppard found which might be useful. We don't know whether the official VGAN code does this (it seems kind of obvious in retrospect), since they haven't released the source AFAIK. In the bottleneck loss, i_c can occasionally be larger than the mean KL divergence, in which case the loss becomes negative. It's not clear what this means or how it might be desirable, so he just got rid of it by bounding it at zero:

-        bottleneck_loss = (th.mean(kl_divergence) - i_c)
+        bottleneck_loss = max(0, (th.mean(kl_divergence) - i_c)) # EDIT: per nsheppard, ensure that the loss can't be negative

It seems to permit stable training with much higher values of i_c.
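A minimal NumPy sketch contrasting the two versions of the bottleneck term from the diff above (the function name and KL values are hypothetical; the repository computes this on PyTorch tensors):

```python
import numpy as np


def bottleneck_term(kl_per_sample, i_c, clamp=False):
    """Mean KL over the batch minus the information budget I_c.
    clamp=True applies nsheppard's variant: the term is floored at zero."""
    term = float(np.mean(kl_per_sample)) - i_c
    return max(0.0, term) if clamp else term


kl = np.array([0.1, 0.2])  # made-up per-sample KL values
print(bottleneck_term(kl, i_c=0.2))              # about -0.05
print(bottleneck_term(kl, i_c=0.2, clamp=True))  # 0.0
```

With the clamp, β times this term has zero gradient whenever the mean KL is below I_c, so the Encoder is pushed to compress only while it exceeds the budget. If the same clamped value were also fed into the paper's β update, β could never decrease, so a variant like this would presumably keep the unclamped value for the β step.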

Data Loader outputting a different batch_size in the last iteration of an Epoch

I have followed how you coded the custom data loader. It is amazing, and I have actually used it in other code as well. But I found one interesting fact. Let's assume data is the output from the custom data loader. Then len(iter(data)) gives the total number of batches in the entire dataset. Say I have 101 data examples and my batch size is 10. Here len(iter(data)) gives me 11 (I think it is a ceil operation), so the 11th batch only contains 1 example. If you have a dynamic operation such as calculating the GP in the WGAN loss, you need the batch_size; if you fix the batch size, you will get an error. But you have used nice tricks such as limit and getting the batch size dynamically (line 1, line 2). Please tell me whether my understanding is correct or wrong.
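The arithmetic in this issue can be sketched as follows: with 101 examples and a batch size of 10 there are ceil(101/10) = 11 batches, the last holding a single example, so anything sized by a fixed batch_size (such as the per-example WGAN-GP interpolation coefficients) breaks on that batch. The hypothetical NumPy helpers below read the batch size from the data instead. Alternatively, PyTorch's `torch.utils.data.DataLoader(..., drop_last=True)` simply discards the incomplete final batch.

```python
import math

import numpy as np


def batch_sizes(num_examples, batch_size):
    """Batch sizes an epoch yields when the final partial batch is kept."""
    num_batches = math.ceil(num_examples / batch_size)
    return [min(batch_size, num_examples - i * batch_size) for i in range(num_batches)]


def gp_interpolates(real, fake, rng):
    """WGAN-GP interpolation points x_hat = eps * real + (1 - eps) * fake.
    eps gets one value per example, sized from the batch actually received."""
    b = real.shape[0]  # read dynamically instead of trusting a fixed batch_size
    eps = rng.uniform(size=(b,) + (1,) * (real.ndim - 1))
    return eps * real + (1.0 - eps) * fake


print(batch_sizes(101, 10))  # [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1]
rng = np.random.default_rng(0)
last_real = rng.standard_normal((1, 3, 8, 8))  # the 1-example final batch
last_fake = rng.standard_normal((1, 3, 8, 8))
print(gp_interpolates(last_real, last_fake, rng).shape)  # (1, 3, 8, 8)
```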
