
PokeGAN

License Build Status codecov

This repository is an ongoing implementation of shallow GAN architectures to generate Pokemon using PyTorch.

all_samples

Table of Contents

  • Getting started
  • Usage
  • Architecture & training scheme
  • Experiments
  • Results
  • Contributing
  • Credits
  • License

Getting started

Prerequisites

  • Python 3.6 (or more recent)
  • pip
  • Pokemon dataset from Kaggle, or the static fallback provided by this repo. Note that if you use the original Kaggle version, it is better to convert the JPG images to PNG format, to avoid transparency-handling issues later on (a one-off conversion script is sketched below)
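For the conversion, a short one-off script along these lines would do, assuming Pillow is installed ("path/to/pokemon" is a placeholder for your dataset folder):

from pathlib import Path
from PIL import Image

# Convert every JPG in the dataset folder to PNG so that all images
# share a single format; "path/to/pokemon" is a placeholder path.
for jpg in Path("path/to/pokemon").glob("*.jpg"):
    Image.open(jpg).convert("RGBA").save(jpg.with_suffix(".png"))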

Installation

You can install the project requirements as follows:

git clone https://github.com/frgfm/PokeGAN.git
pip install -r PokeGAN/requirements.txt

or install it as a package:

pip install git+https://github.com/frgfm/PokeGAN.git

Usage

There are two available training scripts: main.py for classic DCGAN training, and progan.py for ProGAN training. You can use the --help flag to get more advanced usage instructions.

usage: main.py [-h] [--size SIZE] [--device DEVICE] [--lr LR] [--dropout DROPOUT] [--z-size Z_SIZE]
               [--latent-size LATENT_SIZE] [--wd WEIGHT_DECAY] [--ls LABEL_SMOOTHING]
               [--noise NOISE] [--swap SWAP] [-b BATCH_SIZE] [--epochs EPOCHS] [-j WORKERS]
               [--output-file OUTPUT_FILE]
               data_path

Pokemon GAN Training

positional arguments:
  data_path             path to dataset folder

optional arguments:
  -h, --help            show this help message and exit
  --size SIZE           Image size to produce (default: 64)
  --device DEVICE       device (default: 0)
  --lr LR               initial learning rate (default: 0.001)
  --dropout DROPOUT     dropout rate (default: 0.3)
  --z-size Z_SIZE       number of features fed to the generator (default: 96)
  --latent-size LATENT_SIZE
                        latent feature map size (default: 4)
  --wd WEIGHT_DECAY, --weight-decay WEIGHT_DECAY
                        weight decay (default: 0)
  --ls LABEL_SMOOTHING, --label-smoothing LABEL_SMOOTHING
                        label smoothing (default: 0.1)
  --noise NOISE         Norm of the noise added to labels (default: 0.1)
  --swap SWAP           Probability of swapping labels (default: 0.03)
  -b BATCH_SIZE, --batch-size BATCH_SIZE
                        batch size (default: 32)
  --epochs EPOCHS       number of total epochs to run (default: 400)
  -j WORKERS, --workers WORKERS
                        number of data loading workers (default: 16)
  --output-file OUTPUT_FILE
                        path where to save (default: ./gan.pth)
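For instance, a training run on 64x64 images with mostly default hyperparameters could look like the following (the dataset path is illustrative):

python main.py path/to/pokemon --size 64 --epochs 400 -b 32 --output-file ./gan.pth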

Architecture & training scheme

Architecture

The architecture is similar to DCGAN, but with weights initialized from a normal distribution rather than a uniform one. The discriminator and the generator have mirrored architectures for downsampling and upsampling, respectively.
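As a rough illustration, such an initialization can be sketched as follows; the mean/std values follow the DCGAN paper's convention and are an assumption, not necessarily this repository's exact choice:

import torch.nn as nn

def weights_init(m):
    # Draw conv weights from N(0, 0.02) and batch-norm scales from
    # N(1, 0.02), following the DCGAN convention (assumed values).
    classname = m.__class__.__name__
    if "Conv" in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

Both networks would then apply it with generator.apply(weights_init) and discriminator.apply(weights_init).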

InstanceNorm was tried as an alternative to BatchNorm, but the latter proved more effective.

Training scheme

progan_scheme

Source: Progressive Growing of GANs for Improved Quality, Stability, and Variation, ICLR 2018

Following the idea suggested by ProGAN, the implementation includes a progressive training scheme:

  • We select a target image size and a starting size.
  • Each stage is characterized by a single output size ([16, 32, 64] for instance).
  • Each stage then goes through a training cycle, i.e. a sequence of epochs with identical learning rate, e.g. [dict(lr=5e-4, nb_epochs=100), dict(lr=2e-4, nb_epochs=200)].
  • When the training cycle is over, the output size is doubled. We recreate the discriminator and the generator for this size (which adds a sampling layer to each network), load the weights learned in the previous stage into the matching layers, and freeze those already-trained layers.
  • Training stops once the cycle at the target size is complete (see the sketch after this list).
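A minimal sketch of this stage loop, assuming hypothetical helpers make_models(size) (builds a generator/discriminator pair for a given output size) and train_stage (runs the epochs of one cycle phase), neither of which is this repository's actual API:

def progressive_training(make_models, cycles, start_size=16, target_size=64):
    # Illustration only: make_models, cycles and train_stage are assumed names.
    size = start_size
    prev_g, prev_d = None, None
    while size <= target_size:
        generator, discriminator = make_models(size)
        if prev_g is not None:
            # Load the weights learned at the previous stage into the
            # matching layers of the extended networks...
            generator.load_state_dict(prev_g.state_dict(), strict=False)
            discriminator.load_state_dict(prev_d.state_dict(), strict=False)
            # ...and freeze the already-trained layers (the discriminator
            # would be frozen the same way).
            frozen = set(dict(prev_g.named_parameters()))
            for name, param in generator.named_parameters():
                if name in frozen:
                    param.requires_grad = False
        # Run the training cycle for this stage, e.g.
        # cycles[size] = [dict(lr=5e-4, nb_epochs=100), dict(lr=2e-4, nb_epochs=200)]
        for phase in cycles[size]:
            train_stage(generator, discriminator, **phase)
        prev_g, prev_d = generator, discriminator
        size *= 2  # double the output size for the next stage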

Experiments

Things that were tested to improve training:

  • Normalization: batch normalization, instance normalization, local response normalization, layer normalization, spectral normalization.
  • Kernel size: raising the generator kernel size from 3x3 to 5x5 to get smoother results.
  • Learning rate: scaling down the learning rate by a factor (<= 1) for the generator leads to easier convergence towards a local Nash Equilibrium.
  • GAN loss: standard loss, relativistic loss, relativistic average loss (see the sketch below).
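As an illustration of the last point, the relativistic average losses (RaGAN, Jolicoeur-Martineau, 2018) can be written as follows; this is the generic formulation, not necessarily the exact code used here:

import torch
import torch.nn.functional as F

def ra_disc_loss(real_logits, fake_logits):
    # The discriminator learns to judge real samples as more realistic
    # than the average fake, and fakes as less realistic than the average real.
    real = real_logits - fake_logits.mean()
    fake = fake_logits - real_logits.mean()
    return (F.binary_cross_entropy_with_logits(real, torch.ones_like(real))
            + F.binary_cross_entropy_with_logits(fake, torch.zeros_like(fake))) / 2

def ra_gen_loss(real_logits, fake_logits):
    # Symmetrically, the generator pushes fakes above the average real.
    real = real_logits - fake_logits.mean()
    fake = fake_logits - real_logits.mean()
    return (F.binary_cross_entropy_with_logits(fake, torch.ones_like(fake))
            + F.binary_cross_entropy_with_logits(real, torch.zeros_like(real))) / 2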

And other tricks to be implemented soon:

  • Gradient penalty & consistency term

Results

Stage 1 (16x16 images)

Samples

stage1_samples

Gradient flow

stage1_gradflow

Stage 2 (32x32 images)

Samples

stage2_samples

Gradient flow

stage2_gradflow

Stage 3 (64x64 images)

Samples (mode collapse)

stage3_samples

Gradient flow

stage3_gradflow

Contributing

When opening an issue, please use the following format for the title:

[Topic] Your Issue name

Example:

[State saving] Add a feature to automatically save and load model states

Credits

License

Distributed under the MIT License. See LICENSE for more information.


pokegan's Issues

[Spectral normalization] Inter-stage weight transfer is not working

After a successful first-stage training, the generated samples suddenly become completely grey, as if all neurons had died, whenever a norm_fn is selected.

Leads to investigate: transfer of the weights to the extended networks, parameter freezing.

Last samples of first stage (16x16):
stage16_epoch800

First samples of the second stage (32x32):
stage32_epoch900
