
ViTGAN-pytorch

A PyTorch implementation of VITGAN: Training GANs with Vision Transformers

Open In Colab

TODO:

  1. Use vectorized L2 distance in attention for Discriminator
  2. Overlapping Image Patches
  3. DiffAugment
  4. Self-modulated LayerNorm (SLN)
  5. Implicit Neural Representation for Patch Generation
  6. ExponentialMovingAverage (EMA)
  7. Balanced Consistency Regularization (bCR)
  8. Improved Spectral Normalization (ISN)
  9. Equalized Learning Rate
  10. Weight Modulation

Dependencies

  • Python3
  • einops
  • pytorch_ema
  • stylegan2-pytorch
  • tensorboard
  • wandb
pip install einops git+https://github.com/fadel/pytorch_ema stylegan2-pytorch tensorboard wandb

TLDR:

Train the model with the proposed parameters:

python main.py

Tensorboard

tensorboard --logdir runs/

By default, the script uses the parameters proposed in the paper for the CIFAR-10 dataset:

python main.py

Implementation Details

Generator

The Generator follows the architecture below:

ViTGAN Generator architecture

For debugging purposes, the Generator is separated into a Vision Transformer (ViT) model and a SIREN model.

Given a seed, the dimensionality of which is controlled by latent_dim, the ViT model creates an embedding for each patch of the final image. Each embedding, combined with a Fourier position encoding (Jupyter Notebook), is fed to a SIREN network, which outputs the corresponding image patch; the patches are then stitched together into the final image.
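The stitching step can be sketched with plain tensor reshapes; the batch, grid, and patch sizes below are hypothetical, not the paper's settings:

```python
import torch

b, grid, p, c = 2, 8, 4, 3  # hypothetical: 8x8 grid of 4x4 RGB patches
patches = torch.randn(b, grid * grid, p * p * c)  # SIREN output: one row per patch

# (b, h*w, p1*p2*c) -> (b, c, h*p1, w*p2)
image = patches.view(b, grid, grid, p, p, c)     # split the grid and patch axes
image = image.permute(0, 5, 1, 3, 2, 4)          # (b, c, h, p1, w, p2)
image = image.reshape(b, c, grid * p, grid * p)  # merge into the full image
```

The same rearrangement is a one-liner with einops, which the project already depends on.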

The ViT part of the Generator differs from a standard Vision Transformer in the following ways:

  • The input to the Transformer consists only of the position embeddings
  • Self-Modulated Layer Norm (SLN) is used in place of LayerNorm
  • There is no classification head

SLN is the only place where the seed enters the network.
SLN applies a regular LayerNorm, multiplies the result by gamma, and adds beta.
Both gamma and beta are computed by a fully connected layer, a separate one for each place SLN is applied.
Each of these fully connected layers has input and output dimension equal to hidden_dimension.
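A minimal sketch of SLN, assuming the seed has already been mapped to a vector of size hidden_dimension (the class and variable names are mine):

```python
import torch
import torch.nn as nn

class SelfModulatedLayerNorm(nn.Module):
    def __init__(self, hidden_dimension):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dimension, elementwise_affine=False)
        # gamma and beta each come from their own fully connected layer,
        # with input and output dimension equal to hidden_dimension
        self.gamma = nn.Linear(hidden_dimension, hidden_dimension)
        self.beta = nn.Linear(hidden_dimension, hidden_dimension)

    def forward(self, h, w):
        # h: patch embeddings (batch, num_patches, hidden_dimension)
        # w: seed-derived vector (batch, hidden_dimension)
        g = self.gamma(w).unsqueeze(1)  # broadcast over the patch axis
        b = self.beta(w).unsqueeze(1)
        return self.norm(h) * g + b

sln = SelfModulatedLayerNorm(64)
out = sln(torch.randn(2, 16, 64), torch.randn(2, 64))
```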

SIREN

A description of SIREN: [Blog Post] [Paper] [Colab Notebook]

In contrast to a regular SIREN, the desired output is not a single image. For this purpose, each patch embedding is combined with a position embedding.

The positional encoding used in ViTGAN is the Fourier position encoding; the code for it was taken from here: (Jupyter Notebook)

In my implementation, the input to the SIREN is the sum of a patch embedding and a position embedding.
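A minimal sketch of a single SIREN layer applied to that sum; the omega_0 frequency scale follows the SIREN paper, and the sizes are hypothetical:

```python
import torch
import torch.nn as nn

class SirenLayer(nn.Module):
    def __init__(self, in_dim, out_dim, omega_0=30.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.omega_0 = omega_0  # frequency scale from the SIREN paper

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))

# the SIREN input is the sum of a patch embedding and a position embedding
patch_emb = torch.randn(2, 16, 32)  # (batch, points_per_patch, dim)
pos_emb = torch.randn(2, 16, 32)
out = SirenLayer(32, 32)(patch_emb + pos_emb)
```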

Weight Modulation

Weight Modulation usually consists of a modulation and a demodulation module. After testing the network, I concluded that demodulation is not used in ViTGAN.

My implementation of the weight modulation is heavily based on CIPS. I have adjusted it to work for a fully connected network, using a 1D convolution. The reason for using a 1D convolution instead of a linear layer is the groups argument, which improves performance by a factor of batch_size.

Each SIREN layer applies a sine activation to a weight modulation layer. The sizes of the input, hidden, and output layers in a SIREN network can vary; thus, when the input size differs from the size of the patch embedding, I define an additional fully connected layer that converts the patch embedding to the appropriate size.
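The grouped 1D convolution trick can be sketched as follows (modulation only, no demodulation; the function name and sizes are mine):

```python
import torch
import torch.nn.functional as F

def modulated_linear(x, weight, style):
    # x:      (batch, in_dim, n)  -- n points per sample
    # weight: (out_dim, in_dim)   -- shared layer weight
    # style:  (batch, in_dim)     -- per-sample modulation vector
    b, in_dim, n = x.shape
    out_dim = weight.shape[0]
    # scale the shared weight per sample: (b, out_dim, in_dim)
    w = weight.unsqueeze(0) * style.unsqueeze(1)
    # fold the batch into the channel axis; groups=b convolves every sample
    # with its own modulated weight in a single conv1d call
    x = x.reshape(1, b * in_dim, n)
    w = w.reshape(b * out_dim, in_dim, 1)
    out = F.conv1d(x, w, groups=b)
    return out.reshape(b, out_dim, n)

out = modulated_linear(torch.randn(4, 8, 16), torch.randn(12, 8), torch.randn(4, 8))
```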

Discriminator

The Discriminator follows the architecture below:

ViTGAN Discriminator architecture

The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications:

  • DiffAugment
  • Overlapping Image Patches
  • Use vectorized L2 distance in attention for Discriminator
  • Improved Spectral Normalization (ISN)
  • Balanced Consistency Regularization (bCR)

DiffAugment

For implementing DiffAugment, I used the code below:
[GitHub] [Paper]

Overlapping Image Patches

Creation of the overlapping image patches is implemented with the use of a convolution layer.
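A minimal sketch: with the kernel larger than the stride, neighbouring patches share pixels. The sizes below are hypothetical, not the paper's settings:

```python
import torch
import torch.nn as nn

patch_size, stride, hidden = 6, 4, 64
# kernel_size > stride => overlapping patches; padding keeps the grid regular
to_patches = nn.Conv2d(3, hidden, kernel_size=patch_size, stride=stride,
                       padding=(patch_size - stride) // 2)

img = torch.randn(2, 3, 32, 32)
tokens = to_patches(img).flatten(2).transpose(1, 2)  # (batch, num_patches, hidden)
```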

Use vectorized L2 distance in attention for Discriminator

[Paper]
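The idea from the paper is to replace the dot-product similarity with a negative squared L2 distance between queries and keys, computed for all pairs at once. A single-head sketch (ViTGAN additionally ties the query and key projections, modelled here by reusing the same tensor):

```python
import torch

def l2_attention(q, k, v):
    d = q.shape[-1]
    dist = torch.cdist(q, k, p=2) ** 2            # (batch, n, n) pairwise distances
    attn = torch.softmax(-dist / d ** 0.5, dim=-1)
    return attn @ v

q = k = torch.randn(2, 16, 32)  # tied query/key projection
v = torch.randn(2, 16, 32)
out = l2_attention(q, k, v)
```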

Improved Spectral Normalization (ISN)

The ISN implementation is based on the following implementation of Spectral Normalization:
[GitHub] [Paper]
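ISN rescales the spectrally normalized weight by the spectral norm of the weight at initialization, so each layer keeps its initial scale while remaining Lipschitz-bounded. A sketch using a full SVD for clarity (real implementations use power iteration, as in the linked code; the function name is mine):

```python
import torch

def isn_weight(weight, weight_init):
    # W_ISN = sigma(W_init) * W / sigma(W), sigma = largest singular value
    sigma = torch.linalg.matrix_norm(weight, ord=2)
    sigma_init = torch.linalg.matrix_norm(weight_init, ord=2)
    return weight * (sigma_init / sigma)

w0 = torch.randn(16, 16)
w = isn_weight(3.0 * w0, w0)  # a rescaled weight is normalized back to its initial scale
```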

Balanced Consistency Regularization (bCR)

Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 [Paper]
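The bCR penalty encourages the discriminator to give the same output before and after an augmentation, on both real and fake images. A sketch, where the discriminator, the augmentation, and the lambda values are placeholders:

```python
import torch

def bcr_loss(d, real, fake, augment, lambda_real=10.0, lambda_fake=10.0):
    # penalize D for changing its output when the input is augmented
    loss_real = (d(real) - d(augment(real))).pow(2).mean()
    loss_fake = (d(fake) - d(augment(fake))).pow(2).mean()
    return lambda_real * loss_real + lambda_fake * loss_fake

# toy check: a flip-invariant "discriminator" incurs no penalty
d = lambda x: x.sum(dim=(1, 2, 3))
augment = lambda x: torch.flip(x, dims=[3])
loss = bcr_loss(d, torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8), augment)
```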

References

SIREN: Implicit Neural Representations with Periodic Activation Functions
Vision Transformer: [Blog Post]
L2 distance attention: The Lipschitz Constant of Self-Attention
Spectral Normalization reference code: [GitHub] [Paper]
Diff Augment: [GitHub] [Paper]
Fourier Position Embedding: [Jupyter Notebook]
Exponential Moving Average: [GitHub]
Balanced Consistency Regularization (bCR): [Paper]
StyleGAN2 Discriminator: [GitHub]

