
facebookresearch / fadernetworks


Fader Networks: Manipulating Images by Sliding Attributes - NIPS 2017

License: Other

Python 99.12% Shell 0.88%

fadernetworks's Introduction

FaderNetworks

PyTorch implementation of Fader Networks (NIPS 2017).

Fader Networks can generate different realistic versions of images by modifying attributes such as gender or age group. They can swap multiple attributes at a time, and continuously interpolate between each attribute value. In this repository we provide the code to reproduce the results presented in the paper, as well as trained models.

Single-attribute swap

Below are some examples of different attribute swaps:

Multi-attributes swap

The Fader Networks are also designed to disentangle multiple attributes at a time:

Model

The main branch of the model (Inference Model) is an image autoencoder. Given an image x and an attribute y (e.g. male/female), the decoder is trained to reconstruct the image from the latent state E(x) and y. The other branch (Adversarial Component) is a discriminator trained to predict the attribute from the latent state. The encoder of the Inference Model is trained not only to reconstruct the image but also to fool the discriminator, by removing from E(x) the information related to the attribute. As a result, the decoder must rely on y to reconstruct the image properly. During training, the model uses the real attribute values, but at test time y can be manipulated to generate variations of the original image.
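For a single binary attribute, the two competing objectives can be sketched as follows. We assume the paper's formulation (our reading, not code from this repository): the discriminator minimizes -log p(y | E(x)), while the encoder and decoder minimize the reconstruction error plus lambda * (-log p(1 - y | E(x))). In other words, the encoder and decoder are rewarded when the discriminator predicts the *flipped* attribute. This is a minimal pure-Python illustration on plain lists. The function names are hypothetical, and mean squared error is assumed for the reconstruction term.

```python
import math
from typing import Sequence


def latent_dis_loss(p_attr: Sequence[float], y: Sequence[int]) -> float:
    """Mean of -log p(y | E(x)) for one binary attribute.

    p_attr[i] is the discriminator's probability that sample i has y = 1.
    """
    total = 0.0
    for p, label in zip(p_attr, y):
        total -= math.log(p if label == 1 else 1.0 - p)
    return total / len(y)


def autoencoder_loss(x, x_rec, p_attr, y, lambda_e: float) -> float:
    """Reconstruction MSE plus lambda_e * mean of -log p(1 - y | E(x)).

    The adversarial term is small when the discriminator is fooled into
    predicting the flipped attribute, i.e. when E(x) hides the attribute.
    """
    rec = sum(
        sum((a - b) ** 2 for a, b in zip(xi, ri)) / len(xi)
        for xi, ri in zip(x, x_rec)
    ) / len(x)
    flipped = [1 - label for label in y]
    return rec + lambda_e * latent_dis_loss(p_attr, flipped)
```

With a perfect reconstruction and an uninformative discriminator (p = 0.5), the autoencoder loss reduces to lambda_e * log 2. A discriminator that confidently recovers the true attribute makes the encoder's loss grow. Here lambda_e presumably corresponds to --lambda_lat_dis in train.py.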

Dependencies

Installation

Simply clone the repository:

git clone https://github.com/facebookresearch/FaderNetworks.git
cd FaderNetworks

Dataset

Download the aligned and cropped CelebA dataset from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. Extract all images and move them to the data/img_align_celeba/ folder. There should be 202599 images. The dataset also provides a file list_attr_celeba.txt containing the list of the 40 attributes associated with each image. Move it to data/. Then simply run:

cd data
./preprocess.py

It resizes the images and creates two files: images_256_256.pth and attributes.pth. The first contains a tensor of size (202599, 3, 256, 256) holding all resized images concatenated together. Note that you can change the image size in preprocess.py to work with other resolutions. The second file is a preprocessed version of the attributes.
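The README does not say exactly what attributes.pth stores. Purely as an illustration, the sketch below parses list_attr_celeba.txt into one boolean list per attribute. It assumes the usual CelebA layout: an image count on the first line, the attribute names on the second, then one line per image holding a filename followed by +1/-1 values. The function name is hypothetical, and the sketch makes no claim to reproduce what preprocess.py saves.

```python
from typing import Dict, List


def parse_celeba_attributes(text: str) -> Dict[str, List[bool]]:
    """Parse list_attr_celeba.txt-style text into {attribute: [bool per image]}.

    Assumed layout: line 1 = number of images, line 2 = attribute names,
    then "<filename> <+1/-1> <+1/-1> ..." for each image.
    """
    lines = [line for line in text.splitlines() if line.strip()]
    n_images = int(lines[0])
    names = lines[1].split()
    attrs: Dict[str, List[bool]] = {name: [] for name in names}
    for line in lines[2:2 + n_images]:
        values = line.split()[1:]  # drop the filename column
        if len(values) != len(names):
            raise ValueError(f"expected {len(names)} values, got {len(values)}: {line!r}")
        for name, value in zip(names, values):
            attrs[name].append(int(value) > 0)  # "1"/"+1" -> True, "-1" -> False
    if any(len(v) != n_images for v in attrs.values()):
        raise ValueError("number of rows does not match the header count")
    return attrs
```

For the real file, one would expect 202599 entries in each of the 40 lists.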

Pretrained models

You can download pretrained classifiers and Fader Networks by running:

cd models
./download.sh

Train your own models

Train a classifier

To train your own Fader Network, you first need a classifier that lets the training script evaluate swap quality. Training a good classifier is relatively simple for most attributes, and a good model can be trained in a few minutes. We provide a trained classifier for all attributes in models/classifier256.pth. Note that the classifier does not need to be state-of-the-art: it is not used to update the model during training and only serves to monitor swap quality. If you want to train your own classifier, run classifier.py with the following parameters:

python classifier.py

# Main parameters
--img_sz 256                  # image size
--img_fm 3                    # number of feature maps
--attr "*"                    # attributes list. "*" for all attributes

# Network architecture
--init_fm 32                  # number of feature maps in the first layer
--max_fm 512                  # maximum number of feature maps
--hid_dim 512                 # hidden layer size

# Training parameters
--v_flip False                # randomly flip images vertically (data augmentation)
--h_flip True                 # randomly flip images horizontally (data augmentation)
--batch_size 32               # batch size
--optimizer "adam,lr=0.0002"  # optimizer
--clip_grad_norm 5            # clip gradient L2 norm
--n_epochs 1000               # number of epochs
--epoch_size 50000            # number of images per epoch

# Reload
--reload ""                   # reload a trained classifier
--debug False                 # debug mode (if True, load a small subset of the dataset)
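Optimizers are passed as compact strings such as "adam,lr=0.0002". The helper below is hypothetical and is not the repository's parser. It only illustrates how such a spec splits into a method name and keyword arguments.

```python
from typing import Dict, Tuple


def parse_optimizer_spec(spec: str) -> Tuple[str, Dict[str, float]]:
    """Split "adam,lr=0.0002" into ("adam", {"lr": 0.0002}).

    Hypothetical helper for illustration; every option value is read as a float.
    """
    method, *options = [part.strip() for part in spec.split(",")]
    params: Dict[str, float] = {}
    for option in options:
        key, sep, value = option.partition("=")
        if not sep or not key:
            raise ValueError(f"malformed optimizer option: {option!r}")
        params[key] = float(value)
    return method, params
```

With PyTorch, the resulting dict could then be forwarded as keyword arguments, for example torch.optim.Adam(model.parameters(), **params).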

Train a Fader Network

You can train a Fader Network with train.py. The autoencoder can receive feedback from:

  • The image reconstruction loss
  • The latent discriminator loss
  • The PatchGAN discriminator loss
  • The classifier loss

In the paper, only the first two losses are used, but the other two may improve results further. You can tune the weight of each loss with the lambda_ae, lambda_lat_dis, lambda_ptc_dis, and lambda_clf_dis coefficients. Below is the complete list of parameters:

# Main parameters
--img_sz 256                      # image size
--img_fm 3                        # number of feature maps
--attr "Male"                     # attributes list. "*" for all attributes

# Networks architecture
--instance_norm False             # use instance normalization instead of batch normalization
--init_fm 32                      # number of feature maps in the first layer
--max_fm 512                      # maximum number of feature maps
--n_layers 6                      # number of layers in the encoder / decoder
--n_skip 0                        # number of skip connections
--deconv_method "convtranspose"   # deconvolution method
--hid_dim 512                     # hidden layer size
--dec_dropout 0                   # dropout in the decoder
--lat_dis_dropout 0.3             # dropout in the latent discriminator

# Training parameters
--n_lat_dis 1                     # number of latent discriminator training steps
--n_ptc_dis 0                     # number of PatchGAN discriminator training steps
--n_clf_dis 0                     # number of classifier training steps
--smooth_label 0.2                # smooth discriminator labels
--lambda_ae 1                     # autoencoder loss coefficient
--lambda_lat_dis 0.0001           # latent discriminator loss coefficient
--lambda_ptc_dis 0                # PatchGAN discriminator loss coefficient
--lambda_clf_dis 0                # classifier loss coefficient
--lambda_schedule 500000          # lambda scheduling (0 to disable)
--v_flip False                    # randomly flip images vertically (data augmentation)
--h_flip True                     # randomly flip images horizontally (data augmentation)
--batch_size 32                   # batch size
--ae_optimizer "adam,lr=0.0002"   # autoencoder optimizer
--dis_optimizer "adam,lr=0.0002"  # discriminator optimizer
--clip_grad_norm 5                # clip gradient L2 norm
--n_epochs 1000                   # number of epochs
--epoch_size 50000                # number of images per epoch

# Reload
--ae_reload ""                    # reload pretrained autoencoder
--lat_dis_reload ""               # reload pretrained latent discriminator
--ptc_dis_reload ""               # reload pretrained PatchGAN discriminator
--clf_dis_reload ""               # reload pretrained classifier
--eval_clf ""                     # evaluation classifier (trained with classifier.py)
--debug False                     # debug mode (if True, load a small subset of the dataset)
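The sketch below shows how the lambda_* coefficients and --lambda_schedule might combine, under stated assumptions. Each adversarial coefficient is assumed to ramp linearly from 0 to its target value over the first lambda_schedule iterations, with 0 disabling the ramp. The autoencoder's loss is then assumed to be the weighted sum of the four feedback terms listed earlier, and only the reconstruction coefficient is applied without a ramp. Both function names are hypothetical.

```python
from typing import Dict


def scheduled_lambda(base: float, n_iter: int, schedule: int) -> float:
    """Linearly ramp a coefficient from 0 to `base` over `schedule` iterations.

    schedule == 0 disables the ramp and returns `base` unchanged.
    """
    if schedule == 0:
        return base
    return base * (min(n_iter, schedule) / schedule)


def total_ae_loss(losses: Dict[str, float], lambdas: Dict[str, float],
                  n_iter: int, schedule: int) -> float:
    """Weighted sum of reconstruction and adversarial feedback terms.

    Assumption: only the adversarial terms (lat_dis, ptc_dis, clf_dis) are ramped.
    """
    total = lambdas["ae"] * losses["ae"]
    for name in ("lat_dis", "ptc_dis", "clf_dis"):
        total += scheduled_lambda(lambdas[name], n_iter, schedule) * losses[name]
    return total
```

With the defaults (--lambda_lat_dis 0.0001, --lambda_schedule 500000), the latent discriminator term would contribute nothing at the first iteration, half its weight at iteration 250000, and its full weight from iteration 500000 onwards.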

Generate interpolations

Given a trained model, you can use it to swap attributes of images in the dataset. Below are examples using the pretrained models:

# Narrow Eyes
python interpolate.py --model_path models/narrow_eyes.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path narrow_eyes.png

# Eyeglasses
python interpolate.py --model_path models/eyeglasses.pth --n_images 10 --n_interpolations 10 --alpha_min 2.0 --alpha_max 2.0 --output_path eyeglasses.png

# Age
python interpolate.py --model_path models/young.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path young.png

# Gender
python interpolate.py --model_path models/male.pth --n_images 10 --n_interpolations 10 --alpha_min 2.0 --alpha_max 2.0 --output_path male.png

# Pointy nose
python interpolate.py --model_path models/pointy_nose.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path pointy_nose.png

Each of these commands generates an image with 10 rows and 12 columns. The first column is the original image, the second is its reconstruction (without altering the attribute), and the remaining columns are the interpolated images. alpha_min and alpha_max set the range of the interpolation. Values greater than 1 produce generations beyond the True / False range of the boolean attribute seen by the model. Note that variations of some attributes only become noticeable for high values of alpha. For instance, alpha_max=2 is usually enough for the "eyeglasses" or "gender" attributes, while for the "age" or "narrow eyes" attributes it is better to go up to alpha_max=10.
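One reading of the alpha range, which is our assumption and is not stated on this page, is that the n_interpolations coefficients are spaced linearly from 1 - alpha_min to alpha_max. Each coefficient alpha is then fed to the decoder as the pair [1 - alpha, alpha] in place of the one-hot boolean attribute. Under that reading, alpha = 0 and alpha = 1 reproduce the False and True codes, and anything outside [0, 1] extrapolates. The helper names are hypothetical.

```python
from typing import List


def interpolation_alphas(alpha_min: float, alpha_max: float, n: int) -> List[float]:
    """n evenly spaced values from 1 - alpha_min to alpha_max, both inclusive."""
    start = 1.0 - alpha_min
    if n == 1:
        return [start]
    step = (alpha_max - start) / (n - 1)
    return [start + i * step for i in range(n)]


def attribute_codes(alphas: List[float]) -> List[List[float]]:
    """Encode each alpha as the two-value attribute vector [1 - alpha, alpha]."""
    return [[1.0 - a, a] for a in alphas]
```

For example, interpolation_alphas(2.0, 2.0, 4) gives [-1.0, 0.0, 1.0, 2.0]. The larger setting used for "narrow eyes" and "age" (10.0 / 10.0) would span -9 to 10.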

References

If you find this code useful, please consider citing:

Fader Networks: Manipulating Images by Sliding Attributes - G. Lample, N. Zeghidour, N. Usunier, A. Bordes, L. Denoyer, M'A. Ranzato

@inproceedings{lample2017fader,
  title={Fader Networks: Manipulating Images by Sliding Attributes},
  author={Lample, Guillaume and Zeghidour, Neil and Usunier, Nicolas and Bordes, Antoine and Denoyer, Ludovic and Ranzato, Marc'Aurelio},
  booktitle={Advances in Neural Information Processing Systems},
  pages={5963--5972},
  year={2017}
}

Contact: [email protected], [email protected]

fadernetworks's People

Contributors

glample, sufuf3

fadernetworks's Issues

Memory usage during pre-processing

Hi,
It takes more than 100 GB of RAM to pre-process the PNGs. Is there a less memory-intensive version available somewhere?
Thanks in advance
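For scale, the README says preprocess.py produces a tensor of size (202599, 3, 256, 256). Assuming one byte per pixel value (8-bit images), that tensor alone is about 39.8 GB, as the arithmetic below shows. Any intermediate copies made while resizing and concatenating the images would add to this. Halving the image size in preprocess.py reduces the final tensor by a factor of four.

```python
def dataset_bytes(n_images: int, channels: int, height: int, width: int,
                  bytes_per_value: int = 1) -> int:
    """Size in bytes of a dense (n_images, channels, height, width) tensor."""
    return n_images * channels * height * width * bytes_per_value


full = dataset_bytes(202599, 3, 256, 256)  # assumes 8-bit pixel values
print(f"{full / 1e9:.1f} GB")  # prints 39.8 GB
print(f"{dataset_bytes(202599, 3, 256, 256, 4) / 1e9:.1f} GB as float32")  # prints 159.3 GB as float32
```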

interpolate on own images

Feature request: add an input_path or input_file option so we can run the model on our own images to alter them. It would be best if an easy-to-run example were provided :)

fader network training failure

I used exactly the same params as in README.md to train the Fader Network to manipulate the "Eyeglasses" attribute. But clf_accu_Eyeglasses_1 always drops to a low value at around epoch 750, like this:
INFO - 10/18/18 12:58:35 - 1:41:47 - 048928 - Latent discriminator : 0.11062 / Reconstruction loss : 0.00315
INFO - 10/18/18 12:58:39 - 1:41:52 - 049728 - Latent discriminator : 0.05209 / Reconstruction loss : 0.00319
INFO - 10/18/18 12:58:41 - 1:41:53 -
INFO - 10/18/18 12:59:20 - 1:42:33 - Latent discriminator accuracy:
INFO - 10/18/18 12:59:20 - 1:42:33 - lat_dis_accu : 77.984%
INFO - 10/18/18 12:59:20 - 1:42:33 - lat_dis_accu_Eyeglasses: 77.984%
INFO - 10/18/18 12:59:20 - 1:42:33 -
INFO - 10/18/18 13:00:25 - 1:43:38 - Classifier accuracy:
INFO - 10/18/18 13:00:25 - 1:43:38 - clf_accu : 50.873%
INFO - 10/18/18 13:00:25 - 1:43:38 - clf_accu_Eyeglasses : 50.873%
INFO - 10/18/18 13:00:25 - 1:43:38 - clf_accu_Eyeglasses_0: 94.946%
INFO - 10/18/18 13:00:25 - 1:43:38 - clf_accu_Eyeglasses_1: 6.800%
INFO - 10/18/18 13:00:25 - 1:43:38 -
INFO - 10/18/18 13:00:25 - 1:43:38 - Autoencoder loss: 0.00326

And the best evaluation accuracy ever reached is:
INFO - 10/17/18 21:27:23 - 3 days, 4:35:00 - Latent discriminator accuracy:
INFO - 10/17/18 21:27:23 - 3 days, 4:35:00 - lat_dis_accu : 93.114%
INFO - 10/17/18 21:27:23 - 3 days, 4:35:00 - lat_dis_accu_Eyeglasses: 93.114%
INFO - 10/17/18 21:27:23 - 3 days, 4:35:00 -
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 - Classifier accuracy:
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 - clf_accu : 79.828%
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 - clf_accu_Eyeglasses : 79.828%
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 - clf_accu_Eyeglasses_0: 98.878%
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 - clf_accu_Eyeglasses_1: 60.779%
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 -
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 - Autoencoder loss: 0.00338
INFO - 10/17/18 21:28:28 - 3 days, 4:36:05 - Best evaluation accuracy: 0.79828

Here are the params I used:
--img_sz 256
--img_fm 3
--attr "Eyeglasses"

--instance_norm False
--init_fm 32
--max_fm 512
--n_layers 6
--n_skip 0
--deconv_method "convtranspose"
--hid_dim 512
--dec_dropout 0
--lat_dis_dropout 0.3

--n_lat_dis 1
--n_ptc_dis 0
--n_clf_dis 0
--smooth_label 0.2
--lambda_ae 1
--lambda_lat_dis 0.0001
--lambda_ptc_dis 0
--lambda_clf_dis 0
--lambda_schedule 500000
--v_flip False
--h_flip True
--batch_size 32
--ae_optimizer "adam,lr=0.0002"
--dis_optimizer "adam,lr=0.0002"
--clip_grad_norm 5
--n_epochs 1000
--epoch_size 50000

--ae_reload ""
--lat_dis_reload ""
--ptc_dis_reload ""
--clf_dis_reload ""
--eval_clf "models/default/kjrite0bvw/best.pth"
--debug False

The PyTorch version is 0.4.1. Any idea why this happened? Thanks.
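The logged numbers are consistent with clf_accu being the unweighted mean of the two per-value accuracies, clf_accu_Eyeglasses_0 and clf_accu_Eyeglasses_1, rather than accuracy counted over all images. For example, (94.946 + 6.800) / 2 = 50.873. If that reading is right, the fall from 79.828% to 50.873% comes almost entirely from the _1 term. The check below uses only values copied from the two log excerpts.

```python
def mean_class_accuracy(per_class: list) -> float:
    """Unweighted mean of per-value accuracies, in percent."""
    return sum(per_class) / len(per_class)


latest = mean_class_accuracy([94.946, 6.800])  # Eyeglasses_0, Eyeglasses_1 (later log)
best = mean_class_accuracy([98.878, 60.779])   # Eyeglasses_0, Eyeglasses_1 (best log)
print(f"{latest:.3f}")  # prints 50.873
```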

multiple attributes

Hi! How do you deal with multiple attributes? It seems that we cannot simply flip the attributes (e.g., 0->1 or 1->0) to compute the loss.
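One way to generalize "flip" is sketched below. This is an assumption for illustration, not the maintainers' answer. Each category is replaced with a uniformly random *different* category by shifting it by a random offset in [1, n_cat - 1] modulo n_cat, which reduces to 1 - y for a boolean attribute. Several attributes are flipped independently, and their adversarial losses would then be summed. The function names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def flip_categories(y: Sequence[int], n_cat: int, rng: random.Random) -> List[int]:
    """Replace each category with a different, randomly chosen one.

    A nonzero shift modulo n_cat never returns the original value;
    for n_cat == 2 the shift is always 1, i.e. exactly 1 - y.
    """
    return [(c + rng.randint(1, n_cat - 1)) % n_cat for c in y]


def flip_all(attrs: Sequence[Tuple[Sequence[int], int]],
             rng: random.Random) -> List[List[int]]:
    """Flip every attribute independently; attrs holds (labels, n_cat) pairs."""
    return [flip_categories(labels, n_cat, rng) for labels, n_cat in attrs]
```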

Cannot run pretrained models

Hi, I was not able to run the pre-trained models. In particular, I tried to use only one image from the CelebA dataset and I got this error:
(venv) D:\PycharmProjects\FaderNetwork\venv>python interpolate.py --model_path models/young.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path young.png

INFO - 08/04/18 01:01:49 - 0:00:00 - 1 / 0 / 0 images with attributes for train / valid / test sets
Traceback (most recent call last):
  File "interpolate.py", line 64, in
    data, attributes = load_images(params)
  File "D:\PycharmProjects\FaderNetwork\venv\src\loader.py", line 82, in load_images
    log_attributes_stats(train_attributes, valid_attributes, test_attributes, params)
  File "D:\PycharmProjects\FaderNetwork\venv\src\loader.py", line 39, in log_attributes_stats
    logger.debug('Valid %s: %s' % (attr_name, ' / '.join(['%.5f' % valid_attributes[:, k + i].mean() for i in range(n_cat)])))
  File "D:\PycharmProjects\FaderNetwork\venv\src\loader.py", line 39, in
    logger.debug('Valid %s: %s' % (attr_name, ' / '.join(['%.5f' % valid_attributes[:, k + i].mean() for i in range(n_cat)])))
IndexError: too many indices for tensor of dimension 1

Could you please help me?
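The log line "1 / 0 / 0 images with attributes for train / valid / test sets" points to a likely cause. With a single image, the validation set is empty, so indexing valid_attributes[:, k + i] has nothing two-dimensional to work on. The sketch below assumes, without confirmation from this page, that the loader splits the data at fixed boundaries matching the standard CelebA partition (162,770 / 19,867 / 19,962 images). Under that assumption, anything short of 162,771 images leaves the validation split empty. The constant and function names are hypothetical.

```python
from typing import Tuple

# Standard CelebA partition sizes for the 202,599 aligned images (assumed boundaries).
CELEBA_TRAIN, CELEBA_VALID = 162770, 19867


def split_sizes(n_images: int) -> Tuple[int, int, int]:
    """Train / valid / test sizes when fixed CelebA boundaries are applied to n_images."""
    train_end = min(n_images, CELEBA_TRAIN)
    valid_end = min(n_images, CELEBA_TRAIN + CELEBA_VALID)
    return train_end, valid_end - train_end, n_images - valid_end
```

split_sizes(1) returns (1, 0, 0), matching the log, while the full dataset gives (162770, 19867, 19962).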

The model must use a single boolean attribute only

Hi, thanks for providing the source code of the paper.

I just tried training with the default settings, except that I ran preprocess.py to generate 128x128 images.

I see the default settings of train.py set Smiling and Male as attributes, which I assume are the attributes to disentangle during training. (Or maybe I misunderstand it, and the code always disentangles all attributes available in list_attr_celeba.txt, no matter what is set for --attr.)

parser.add_argument("--attr", type=attr_flag, default="Smiling,Male",
                    help="Attributes to classify")
fadernetworks $ python train.py
INFO - 03/19/18 13:50:56 - 0:00:00 - ============ Initialized logger ============
INFO - 03/19/18 13:50:56 - 0:00:00 - ae_optimizer: adam,lr=0.0002
                                     ae_reload:
                                     attr: [('Male', 2), ('Smiling', 2)]
                                     batch_size: 32
                                     clf_dis_reload:
                                     clip_grad_norm: 5
                                     debug: False
                                     dec_dropout: 0.0
                                     deconv_method: convtranspose
                                     dis_optimizer: adam,lr=0.0002
                                     dump_path: ./models/default/aq8phwtw2o
                                     epoch_size: 50000
                                     eval_clf: models/classifier128.pth
                                     h_flip: True
                                     hid_dim: 512
                                     img_fm: 3 
                                     img_sz: 128
                                     init_fm: 32
                                     instance_norm: False
                                     lambda_ae: 1
                                     lambda_clf_dis: 0
                                     lambda_lat_dis: 0.0001
                                     lambda_ptc_dis: 0
                                     lambda_schedule: 500000
                                     lat_dis_dropout: 0.3
                                     lat_dis_reload:
                                     max_fm: 512
                                     n_attr: 4 
                                     n_clf_dis: 0
                                     n_epochs: 1000
                                     n_lat_dis: 1
                                     n_layers: 6
                                     n_ptc_dis: 0
                                     n_skip: 0 
                                     name: default
                                     ptc_dis_reload:
                                     smooth_label: 0.2
                                     v_flip: False

Then I used my snapshot saved after training, male_smiling.pth, to interpolate images with the default settings of interpolate.py, and got the following error. Any idea how to run interpolate.py, or what I should change?

models $ cp default/aq8phwtw2o/best_rec_ae.pth ./male_smiling.pth
models $ cd ..
fadernetworks $ python interpolate.py 
Traceback (most recent call last):
  File "interpolate.py", line 60, in <module>
    raise Exception("The model must use a single boolean attribute only.")
Exception: The model must use a single boolean attribute only.
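The training log above shows how --attr "Smiling,Male" was interpreted: attr: [('Male', 2), ('Smiling', 2)] and n_attr: 4. That means two boolean attributes, each counted as two categories. The sketch below reproduces that bookkeeping, plus a check of the kind that would raise this exception. Both functions are hypothetical reconstructions from the log and the error message, not code taken from the repository. Under this reading, interpolate.py only accepts models trained on a single boolean attribute (for example --attr "Male"), so male_smiling.pth would be rejected.

```python
from typing import List, Tuple


def parse_attr_flag(s: str, default_n_cat: int = 2) -> List[Tuple[str, int]]:
    """Turn "Smiling,Male" into a sorted list of (name, n_categories) pairs."""
    names = [name.strip() for name in s.split(",") if name.strip()]
    if len(set(names)) != len(names):
        raise ValueError(f"duplicate attribute in {s!r}")
    return sorted((name, default_n_cat) for name in names)


def check_single_boolean(attr: List[Tuple[str, int]]) -> None:
    """Reject anything other than exactly one attribute with two categories."""
    n_attr = sum(n_cat for _, n_cat in attr)
    if not (len(attr) == 1 and n_attr == 2):
        raise Exception("The model must use a single boolean attribute only.")
```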

jpg vs png CelebA images

The CelebA images are in .png format, but preprocess.py assumes they are .jpg, which causes an error.

pytorch version?

Which version of PyTorch is required to run the code? When I try 0.3.0, I get this attribute error when running the interpolate example with the downloaded models:

AttributeError: Can't get attribute '_rebuild_tensor_v2' on <module 'torch._utils'>

When I try 0.4 I get a different attribute error:

AttributeError: 'BatchNorm2d' object has no attribute 'track_running_stats'

Presumably it's some version in between, but I can't work it out.

Thanks
