nvlabs / deepinversion

Official PyTorch implementation of Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion (CVPR 2020)

License: Other

Python 100.00%

deepinversion's Introduction


Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion

This repository is the official PyTorch implementation of Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion presented at CVPR 2020.

The code inverts images from torchvision models (pretrained on ImageNet) and runs the generated images through another model to check generalization. We plan to update the repo with CIFAR10 examples and teacher-student training.

Useful links:

Teaser

License

Copyright (C) 2020 NVIDIA Corporation. All rights reserved.

This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit https://github.com/NVlabs/DeepInversion/blob/master/LICENSE

Updates

  • 2020 July 7. Added a CIFAR10 inversion result for ResNet34 in the cifar10 folder. Code for knowledge distillation will follow soon.
  • 2020 June 16. Added a new scaling factor, first_bn_multiplier, for the first BN layer. This improves fidelity.
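As a rough illustration of what a per-layer scaling factor like this could look like, the sketch below sums the per-layer BN statistic losses and multiplies only the first layer's term by first_bn_multiplier. The helper name weighted_bn_loss is hypothetical, and the assumption that every other layer keeps a weight of 1.0 is ours, not taken from the repository.

```python
from typing import Sequence


def weighted_bn_loss(per_layer_losses: Sequence[float], first_bn_multiplier: float) -> float:
    """Sum per-BN-layer statistic losses, up-weighting the first BN layer.

    Assumption (illustrative only): the first layer's loss is scaled by
    first_bn_multiplier and all remaining layers keep weight 1.0.
    """
    if not per_layer_losses:
        return 0.0
    weights = [first_bn_multiplier] + [1.0] * (len(per_layer_losses) - 1)
    return sum(w * loss for w, loss in zip(weights, per_layer_losses))


# Example: the first layer's 0.5 becomes 5.0, giving 5.0 + 0.25 + 0.125 = 5.375
print(weighted_bn_loss([0.5, 0.25, 0.125], first_bn_multiplier=10.0))
```

The same weighting works unchanged if the per-layer losses are PyTorch scalar tensors instead of floats, since only multiplication and addition are used.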

Requirements

The code was tested in a virtual environment with Python 3.6. Install the requirements:

pip install torch==1.4.0
pip install torchvision==0.5.0
pip install numpy
pip install Pillow

Additionally, install the APEX library for FP16 support (about 2x less memory and 2x faster): Installing NVIDIA APEX

The provided code was originally designed to invert a ResNet50v1.5 model trained for 90 epochs that achieves 77.26% top-1 accuracy on ImageNet. We cannot share this model, but anyone can train it here: ResNet50v1.5. The code also works well with the default ResNet50 from the torchvision package.

The code was tested on NVIDIA V100 and Titan X Pascal GPUs.

Running the code

This snippet generates 84 images by inverting the resnet50 model from the torchvision package.

python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25

Arguments:

  • bs - batch size; ideally close to the batch size used during training, but this is not required.
  • lr - learning rate of the optimizer that updates the input tensor during model inversion.
  • do_flip - applies random flipping between iterations.
  • exp_name - name of the experiment; a folder with this name is created in ./generations/, where intermediate generations are stored every 100 iterations.
  • r_feature - coefficient for the feature distribution regularization; may need adjustment for other networks.
  • arch_name - name of the network architecture; must be one of the pretrained models from the torchvision package: resnet50, resnet18, mobilenet_v2, etc.
  • fp16 - enables FP16 training via APEX AMP (O2 level), if needed.
  • verifier - checks the accuracy of the generated images with another network (default: mobilenet_v2) every 100 iterations. Useful for observing how well the generated images generalize.
  • setting_id - optimization settings: 0 - multi-resolution scheme, 1 - 2k iterations at full resolution, 2 - 20k iterations (closest to the ResNet50 experiments in the paper). setting_id 0 or 1 is recommended.
  • adi_scale - competition coefficient. A positive value produces images that the original model classifies well but the verifier does not. The paper used 0.2.
  • random_label - randomly selects the classes to invert. Without this argument, the code generates images for hand-picked classes.
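The adi_scale argument weights the competition term. Below is a minimal sketch following the paper's form R_compete = 1 - JS(p_teacher, p_verifier), where JS is the Jensen-Shannon divergence between the two networks' softmax outputs. The function names, the natural-log base and the clamp to [0, 1] are assumptions made for illustration, not the repository's exact code.

```python
import torch
import torch.nn.functional as F


def js_divergence(logits_p: torch.Tensor, logits_q: torch.Tensor) -> torch.Tensor:
    """Batch-mean Jensen-Shannon divergence between two softmax distributions (natural log)."""
    p = F.softmax(logits_p, dim=1)
    q = F.softmax(logits_q, dim=1)
    m = 0.5 * (p + q)
    # F.kl_div(input, target) takes log-probabilities as input and returns KL(target || exp(input))
    kl_pm = F.kl_div(m.log(), p, reduction="batchmean")
    kl_qm = F.kl_div(m.log(), q, reduction="batchmean")
    return 0.5 * (kl_pm + kl_qm)


def competition_loss(teacher_logits: torch.Tensor, verifier_logits: torch.Tensor) -> torch.Tensor:
    """Sketch of R_compete = 1 - JS(p_teacher, p_verifier), with JS clamped to [0, 1].

    Minimizing adi_scale * competition_loss pushes the two output distributions apart.
    """
    js = js_divergence(teacher_logits, verifier_logits)
    return 1.0 - torch.clamp(js, 0.0, 1.0)
```

With natural logarithms the divergence is bounded by ln 2, so the clamp only guards against numerical noise; JS is also symmetric in its two arguments.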

After 3k iterations (~6 min on an NVIDIA V100), generation is done: Verifier accuracy: 91.6...% (an experiment with >98% verifier accuracy can be found in /example_logs). We generated these images by inverting a vanilla ResNet50 (not trained for image generation), and MobileNetv2 classifies them with >90% accuracy. A grid of the generated images (from /final_images/, quality reduced by JPEG compression):

Generated grid of images
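The verifier accuracy printed during generation is presumably the top-1 accuracy of the verifier network on the generated batch, measured against the target labels. A minimal sketch, assuming the images are already preprocessed for the verifier; verifier_accuracy is a hypothetical name, not a function from the repository.

```python
import torch


@torch.no_grad()
def verifier_accuracy(verifier: torch.nn.Module, images: torch.Tensor, targets: torch.Tensor) -> float:
    """Top-1 accuracy (in percent) of a verifier network on generated images."""
    verifier.eval()
    preds = verifier(images).argmax(dim=1)
    return (preds == targets).float().mean().item() * 100.0
```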

Optimization is sensitive to hyperparameters, so try tuning them locally for your setup or application. Try changing the r_feature coefficient, the l2 regularization, and the betas of the Adam optimizer (beta=0 works well). Monitor loss_r_feature, as it indicates how close the feature statistics are to the training distribution.
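To make loss_r_feature more concrete, here is a minimal sketch of a feature distribution regularizer: a forward hook on every BatchNorm2d layer compares the per-channel mean and variance of the current batch with the layer's running_mean and running_var. The class and function names are hypothetical, and using an L2 norm as the distance is an assumption (the README notes that MSE was used in the paper, which would be a drop-in alternative).

```python
import torch
import torch.nn as nn


class BNStatsHook:
    """Forward hook measuring how far the batch statistics of a BatchNorm2d
    input are from the layer's stored running statistics."""

    def __init__(self, module: nn.BatchNorm2d):
        self.r_feature = torch.tensor(0.0)
        self.handle = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inputs, output):
        x = inputs[0]  # input to the BN layer, shape (N, C, H, W)
        mean = x.mean(dim=(0, 2, 3))
        # Biased per-channel variance, matching what BN uses to normalize
        var = ((x - mean.view(1, -1, 1, 1)) ** 2).mean(dim=(0, 2, 3))
        self.r_feature = torch.norm(module.running_mean - mean, 2) + torch.norm(
            module.running_var - var, 2
        )

    def remove(self):
        self.handle.remove()


def attach_bn_hooks(model: nn.Module):
    """Register a statistics hook on every BatchNorm2d layer of the model."""
    return [BNStatsHook(m) for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
```

Typical use would be to put the model in eval mode (so the running statistics stay fixed), run a forward pass on the input tensor being optimized, and add r_feature * sum(h.r_feature for h in hooks) to the loss.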

Reduce the batch size if you run out of memory or are not using FP16. In the paper we used a batch size of 152, and larger batch sizes are preferred. By default, the code generates images from 41 hand-picked classes. To randomize the target classes, use the --random_label argument.

Examples of running the code with different arguments, and the resulting images, can be found in /example_logs/.

Check whether you can invert other architectures, or even apply the method to other applications (keypoints, detection, etc.). The method has room for improvement: (a) improving the loss for feature regularization (we used MSE in the paper, but that may not be ideal for distribution matching), (b) making it even faster, (c) generating images for which multiple models are confident, (d) increasing diversity.

Share your most exciting images on Twitter with the hashtags #Deepinversion and #DeepInvert.

Citation

@inproceedings{yin2020dreaming,
	title = {Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion},
	author = {Yin, Hongxu and Molchanov, Pavlo and Alvarez, Jose M. and Li, Zhizhong and Mallya, Arun and Hoiem, Derek and Jha, Niraj K and Kautz, Jan},
	booktitle = {The IEEE/CVF Conf. Computer Vision and Pattern Recognition (CVPR)},
	month = jun,
	year = {2020}
}

deepinversion's People

Contributors

ahatamiz, ayberkydn, francescodisalvo05, holgerroth, hongxuyin, molchanovp, pamolchanov, wyli


deepinversion's Issues

best_cost update

best_cost doesn't seem to be updated, i.e., it remains 1e-4, so the inputs are always being updated.
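A minimal sketch of keeping the best inputs only when the loss actually improves, with the best cost starting at infinity so the first iterate is always stored. track_best is a hypothetical helper written for illustration, not a function from the repository.

```python
import torch


def track_best(inputs: torch.Tensor, loss_value: float, best_cost: float, best_inputs):
    """Return the updated (best_cost, best_inputs) pair.

    A detached copy of the inputs is stored only when loss_value improves on best_cost.
    """
    if loss_value < best_cost:
        return loss_value, inputs.detach().clone()
    return best_cost, best_inputs


# Inside an optimization loop one would start with:
best_cost, best_inputs = float("inf"), None
```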

The style of the generated images

Hi!
Thank you all so much for the significant work.
I have a question: why do the generated images look like cartoons (or have an artistic style) rather than realistic images?
Is there a possible explanation?
Best.

How to search for hypers on a new dataset

Hi, thanks for your inspiring work; the generated ImageNet images are awesome. However, the generated images don't look good when I replace the ImageNet model with a model trained on another dataset. The image quality seems sensitive to the hyperparameters ('tv_l1', 'tv_l2', 'r_feature' and 'l2'). Could you provide any insight into searching for good hyperparameters on a new dataset?
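When tuning these weights on a new dataset, it can help to log each prior separately so that its magnitude can be balanced against loss_r_feature and the main loss. Below is a simplified sketch of the image priors, assuming tv_l1 and tv_l2 weight total-variation terms over neighbouring pixels and l2 weights the image norm. The function name is hypothetical, and it only uses horizontal and vertical differences, which may differ from the repository's implementation.

```python
import torch


def image_priors(x: torch.Tensor):
    """Return (tv_l1, tv_l2, l2) priors for a batch of images shaped (N, C, H, W)."""
    dh = x[:, :, 1:, :] - x[:, :, :-1, :]  # vertical neighbour differences
    dw = x[:, :, :, 1:] - x[:, :, :, :-1]  # horizontal neighbour differences
    tv_l1 = dh.abs().mean() + dw.abs().mean()
    tv_l2 = torch.norm(dh) + torch.norm(dw)
    # Per-image L2 norm, averaged over the batch
    l2 = torch.norm(x.reshape(x.shape[0], -1), dim=1).mean()
    return tv_l1, tv_l2, l2
```

Printing the three terms (each multiplied by its coefficient) every few hundred iterations shows which prior dominates, which is usually the first thing to rebalance.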

How to use the code for ADI?

I am having trouble figuring out how to use the code for performing ADI. What are the settings that we need to set for performing ADI?

Calculate batch norm statistic loss on parallel training

Hello, I have one question about batch norm statistic loss.

Consider parallel training: I have 8 GPUs, and each GPU can handle a batch size of 128.

However, the batch norm statistic loss is calculated on each GPU separately, and the GPUs share only their gradients, not statistics over the whole batch (1024). I think this can degrade image quality.

So, here is my question: how can I calculate the batch norm statistic loss in parallel training over the whole batch rather than per mini-batch?
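One possible approach, sketched here as an assumption rather than an official answer: all-reduce the per-channel sum, sum of squares, and element count across processes, then derive the global mean and biased variance from them before comparing with the running statistics. The function name is hypothetical, and it falls back to local statistics when torch.distributed is not initialized.

```python
import torch
import torch.distributed as dist


def global_channel_stats(x: torch.Tensor):
    """Per-channel mean and biased variance of x (N, C, H, W), computed over
    the batches of all processes when torch.distributed is initialized."""
    c = x.shape[1]
    flat = x.transpose(0, 1).reshape(c, -1)  # (C, N*H*W) local values
    count = torch.tensor(float(flat.shape[1]), device=x.device)
    s = flat.sum(dim=1)
    sq = (flat * flat).sum(dim=1)
    if dist.is_available() and dist.is_initialized():
        # Differentiable all-reduce so gradients flow back to every process's inputs
        # (assumption: available in your PyTorch version)
        from torch.distributed.nn.functional import all_reduce

        s = all_reduce(s)
        sq = all_reduce(sq)
        dist.all_reduce(count)  # the count needs no gradient, so the in-place op is fine
    mean = s / count
    var = sq / count - mean * mean
    return mean, var
```

Depending on the PyTorch version, the backward pass of the differentiable all-reduce may sum gradients across ranks, which scales them by the world size; dividing the loss by the world size or lowering the learning rate compensates for that.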

VGG Architectures

Hi there,
What are the dimensions of the fully connected layers for the VGG's used in this paper?
Thanks!

A minor typo

This work is amazing! The generated images are so realistic!

line 156 in imagenet_inversion.py

parameters["store_best_images"] = self.store_best_images

=>

parameters["store_best_images"] = args.store_best_images

Details about adversarial losses in the code

Thank you for your work. I noticed that the student model does not use pre-trained weights when the ADI method in the code optimizes the inputs, and the optimizer only updates the input tensors. This means that the student model's weights are never optimized and stay at their initial values.

Does this result in the student model not being able to output appropriate logits to measure JS distance throughout the training process?

Having trouble installing apex and getting the result for the basic snippet

Hi,
After running this basic snippet (which I think does not require apex installation) python imagenet_inversion.py --bs=84 --do_flip --exp_name="rn50_inversion" --r_feature=0.01 --arch_name="resnet50" --verifier --adi_scale=0.0 --setting_id=0 --lr 0.25, I am getting only noisy images in the /generations folder and the verifier accuracy is only 0.0%. No final result in the final_images folder. I am having trouble installing apex as well.

Questions on the KD process on CIFAR10 dataset

Hi there,

Your work is great! I have some questions about the knowledge distillation process on the CIFAR10 dataset in your experiments.

  1. How many CIFAR10-like images did you generate to reach the accuracies in Table 1 of your paper? We tried 3000 and 10000 generated images (with DI, ResNet34, alpha_f = 10), using vanilla KD to distill from ResNet34 to ResNet18, and only reached 25% and 55% validation accuracy respectively.

  2. We encountered problems when trying ADI.
    The description of Table 1 says "for ADI, we generate one new batch of images every 50 KD iterations and merge the newly generated images into the existing set of generated images". Could you please explain this in more detail? Does "50 KD iterations" mean 50 KD epochs? Does "one new batch of images" mean a batch of, say, 256 images that is merged into the existing generated dataset? Does the KD process have to pause and wait for the new batch to be generated every 50 KD iterations (epochs, if I understand correctly)?

Thanks
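One possible reading of the quoted Table 1 description is sketched below, assuming a "KD iteration" is a single optimizer step on a mini-batch drawn from the pool (not an epoch) and that distillation pauses while a new batch is generated. kd_step and generate_batch are hypothetical callables supplied by the caller; this is not the authors' training code.

```python
from typing import Any, Callable, List, Sequence


def distill_with_growing_pool(
    initial_pool: Sequence[Any],
    kd_step: Callable[[List[Any]], None],
    generate_batch: Callable[[], Any],
    total_iterations: int,
    regenerate_every: int = 50,
) -> List[Any]:
    """Run KD steps on a pool of generated batches, appending one freshly
    generated batch every `regenerate_every` iterations."""
    pool = list(initial_pool)
    for it in range(1, total_iterations + 1):
        kd_step(pool)  # one distillation step using batches sampled from the pool
        if it % regenerate_every == 0:
            pool.append(generate_batch())  # synchronous: KD waits for generation
    return pool
```

Generation could also run asynchronously in a separate process, but the synchronous version is the simplest way to reproduce "one new batch every 50 iterations".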

Question about R_compete term

Thank you for your interesting research,

I have a question about the R_{compete} term. The paper states as follows about R_{compete}:

"During optimization, this new term leads to new images the student cannot easily classify whereas the teacher can."

Since the Jensen-Shannon divergence is symmetric in p(x^{hat}) and q(x^{hat}), the opposite might also be true: the images could be generated such that the teacher cannot classify them but the student can.

Is that correct? And does it contradict the purpose of the R_{compete} term?

Code for Other Experiments

Hi,

Your work is amazing! I'm wondering whether you are going to release the code for: 1. data-free pruning, 2. data-free knowledge transfer, and 3. data-free continual learning.

Thanks!

Why the student model is *pretrained* for ImageNet for ADI?

Hi, thanks for your great work! I noticed the student network is pretrained for the ADI experiment on ImageNet. This is quite strange since for data-free knowledge distillation, the goal is to train a student with the synthetic samples. If you already have a pretrained student, the problem does not exist from the beginning.

Meanwhile, for the cifar10 experiment, the student is not pretrained, which I think should be the normal setting though. But there is an inconsistency here. Could you explain a little what makes you choose different schemes for cifar10 and ImageNet? Thanks!

Reproducibility on CIFAR-10

Hi, I followed Section 4.1 of the paper on CIFAR-10 but only got results like this after 5k iterations (convergence):
output_00050_gpu_0

I have modified these parts compared with the code for ImageNet: --r_feature=2.5e-2 --tv_l2 2e-4 --l2 2e-5 --lr 0.01 --setting_id=2 --random_label --main_loss_multiplier 0.25

The idea behind these hyperparameters is to balance the different losses at magnitudes similar to those in the released ImageNet code. I used a pre-trained ResNet-34 with 95.5% accuracy, trained with https://github.com/mbsariyildiz/resnet-pytorch. I have also tried the hyperparameters given in Section 4.1 of the paper but still couldn't get better synthesized images.

Additionally, the image resolution has been set to 32, and I trained the model without adaptive inversion.

Are there any other changes that should be made in the code for CIFAR-10? Thank you for your help!

Here is the training log for the first 2k iterations. The printed losses have been multiplied by coefficients, and I used the same pre-trained ResNet-34 as the verifier:

------------iteration 100----------
total loss 1.5887961387634277
loss_r_feature 0.8383388519287109
main criterion 0.20274938642978668
loss var l2 0.54662197265625
loss l2 0.001085866928100586
Verifier accuracy: 67.85713958740234
------------iteration 200----------
total loss 1.1957563161849976
loss_r_feature 0.6756953239440918
main criterion 0.00945814698934555
loss var l2 0.50958203125
loss l2 0.0010206340789794923
Verifier accuracy: 100.0
------------iteration 300----------
total loss 1.0136175155639648
loss_r_feature 0.5451815128326416
main criterion 0.002356192795559764
loss var l2 0.46513969726562504
loss l2 0.000940232162475586
Verifier accuracy: 100.0
------------iteration 400----------
total loss 0.8616381883621216
loss_r_feature 0.4419695377349854
main criterion 0.0011746699456125498
loss var l2 0.41763930664062504
loss l2 0.0008546984863281251
Verifier accuracy: 100.0
------------iteration 500----------
total loss 0.7258148789405823
loss_r_feature 0.354432487487793
main criterion 0.0014118807157501578
loss var l2 0.36920244140625
loss l2 0.0007680741882324219
Verifier accuracy: 100.0
------------iteration 600----------
total loss 0.6115573644638062
loss_r_feature 0.284948205947876
main criterion 0.004634861368685961
loss var l2 0.32129179687500004
loss l2 0.0006825406646728516
Verifier accuracy: 100.0
------------iteration 700----------
total loss 0.5235512256622314
loss_r_feature 0.24538052082061768
main criterion 0.000998093979433179
loss var l2 0.276566259765625
loss l2 0.000606338119506836
Verifier accuracy: 100.0
------------iteration 800----------
total loss 0.46245500445365906
loss_r_feature 0.22423591613769533
main criterion 0.0008074002689681947
loss var l2 0.2368692626953125
loss l2 0.0005424297332763672
Verifier accuracy: 100.0
------------iteration 900----------
total loss 0.4182308316230774
loss_r_feature 0.2140077829360962
main criterion 0.000744551420211792
loss var l2 0.2029859130859375
loss l2 0.0004926132965087891
Verifier accuracy: 100.0
------------iteration 1000----------
total loss 0.38464871048927307
loss_r_feature 0.20906946659088135
main criterion 0.0006658562342636287
loss var l2 0.1744591552734375
loss l2 0.00045424114227294926
Verifier accuracy: 100.0
------------iteration 1100----------
total loss 0.3591064214706421
loss_r_feature 0.2067859172821045
main criterion 0.0003871832450386137
loss var l2 0.151506787109375
loss l2 0.0004265190887451172
Verifier accuracy: 100.0
------------iteration 1200----------
total loss 0.3402350842952728
loss_r_feature 0.20384662151336672
main criterion 0.0008285769145004451
loss var l2 0.13515140380859375
loss l2 0.0004085089111328125
Verifier accuracy: 100.0
------------iteration 1300----------
total loss 0.32719239592552185
loss_r_feature 0.2029275894165039
main criterion 0.0004118723445571959
loss var l2 0.12345615234375
loss l2 0.0003967760086059571
Verifier accuracy: 100.0
------------iteration 1400----------
total loss 0.32120072841644287
loss_r_feature 0.20393846035003663
main criterion 0.0007077441550791264
loss var l2 0.11616383056640625
loss l2 0.00039069297790527346
Verifier accuracy: 100.0
------------iteration 1500----------
total loss 0.31567034125328064
loss_r_feature 0.20350909233093262
main criterion 0.00039401365211233497
loss var l2 0.11137900390625001
loss l2 0.00038823162078857425
Verifier accuracy: 100.0
------------iteration 1600----------
total loss 0.31221890449523926
loss_r_feature 0.20267820358276367
main criterion 0.000622493855189532
loss var l2 0.10852938232421876
loss l2 0.0003888259506225586
Verifier accuracy: 100.0
------------iteration 1700----------
total loss 0.31075066328048706
loss_r_feature 0.20341982841491701
main criterion 0.0003112043777946383
loss var l2 0.10662841796875
loss l2 0.00039120994567871096
Verifier accuracy: 100.0
------------iteration 1800----------
total loss 0.30613452196121216
loss_r_feature 0.199886953830719
main criterion 0.0005381234805099666
loss var l2 0.10531446533203126
loss l2 0.0003950079727172852
Verifier accuracy: 100.0
------------iteration 1900----------
total loss 0.30858731269836426
loss_r_feature 0.20306947231292727
main criterion 0.0002929923066403717
loss var l2 0.10482397460937501
loss l2 0.00040085716247558595
Verifier accuracy: 100.0
------------iteration 2000----------
total loss 0.30381497740745544
loss_r_feature 0.1986505150794983
main criterion 0.0004578616062644869
loss var l2 0.104299462890625
loss l2 0.00040712715148925784
Verifier accuracy: 100.0

Reproducing Imagenet Knowledge Transfer Top-1 Accuracy

Hi,
Very interesting work!
According to Table 6 in the paper, training for 90 epochs with the 140K generated dataset should reach top-1 accuracy of 68.0%.
I'm trying to train ResNet50v1.5 with the 140k dataset, following the protocol at https://github.com/NVIDIA/DeepLearningExamples, but I can't get past 10% top-1 accuracy.

Can you please elaborate on the training process using the generated 140k images? What protocol or additional work was required to reach the mentioned accuracy?

Thanks!

What is the KD temperature in the CIFAR10 experiment

Hi, great thanks for your inspiring work! I am trying to reproduce the data-free KD experiment on CIFAR10. The descriptions in the paper and supplementary material help a lot, but I cannot find the temperature of KD you used. Could you help with this?

The code in the cifar10 directory can only generate one batch of data; there is no KD training there. If possible, could you add the remaining training part? I am afraid unofficial implementations may not reproduce your results (as noted in other issues). I would really appreciate it!
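For reference, the standard temperature-scaled distillation loss is sketched below: the KL divergence from the teacher's softened distribution to the student's, scaled by T^2 so gradient magnitudes stay comparable across temperatures. The temperature is left as a parameter because the value used in the CIFAR10 experiments is exactly what this issue asks about; nothing here should be read as the authors' setting.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
```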

Regards,

loss about segmentation task

Thanks for your work!
I want to generate images from a segmentation model (like deeplabv3plus), but using your original loss gives bad results.
Could you help me with this issue?
Thanks a lot!
