
robust_overfitting's Introduction

Overfitting in adversarially robust deep learning

A repository implementing the experiments for exploring the phenomenon of robust overfitting, where robust performance on the test set degrades significantly over training. Created by Leslie Rice, Eric Wong, and Zico Kolter. See our paper on arXiv here.

News

  • 04/10/2020 - The AutoAttack framework of Croce & Hein (2020) evaluated our released models using this repository here. On CIFAR10, our model trained with standard PGD and early stopping ranks #5 overall, and #1 among defenses that do not rely on additional data.
  • 02/26/2020 - arXiv paper posted and repository released

Robust overfitting hurts - early stopping is essential!

A large amount of research over the past couple of years has looked into defending deep networks against adversarial examples, with significant improvements over the well-known PGD-based adversarial training defense. However, adversarial training doesn't always behave similarly to standard training. The main observation we find is that, unlike in standard training, training to convergence can significantly harm robust generalization, and robust test error actually begins to increase well before training has converged, as seen in the following learning curve:

[Figure: robust overfitting learning curves; robust test error rises after the first learning rate decay]

After the initial learning rate decay, the robust test error actually increases! As a result, training to convergence is bad for adversarial training, and oftentimes, simply training for one epoch after decaying the learning rate achieves the best robust error on the test set. This behavior is reflected across multiple datasets, different approaches to adversarial training, and both L-infinity and L-2 threat models.
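As an illustration of this early-stopping recipe (not the exact code in this repository's training scripts), here is a minimal sketch that evaluates robust error on held-out data after each epoch and keeps the best checkpoint; adv_train_epoch and robust_error are hypothetical helper names standing in for one epoch of adversarial training and a robust evaluation pass.

import copy

def train_with_early_stopping(model, opt, scheduler, train_loader, val_loader, epochs):
    # adv_train_epoch / robust_error are hypothetical helpers, not functions from this repo.
    best_err, best_state = float('inf'), None
    for epoch in range(epochs):
        adv_train_epoch(model, opt, train_loader)   # one epoch of adversarial training
        scheduler.step()                            # e.g. decay the learning rate on a fixed schedule
        err = robust_error(model, val_loader)       # robust error on held-out data
        if err < best_err:                          # keep the lowest robust error seen so far
            best_err, best_state = err, copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)               # return the early-stopped model, not the final one
    return model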

No algorithmic improvements over PGD-based adversarial training

We can apply this knowledge to PGD-based adversarial training (e.g. as done by the original paper here), and find that early stopping can improve the robust test error by a substantial 8%! As a result, we find that PGD-based adversarial training is as good as existing SOTA methods for adversarial robustness (e.g. on par with or slightly better than TRADES). On the flip side, we note that the results reported by TRADES also rely on early stopping, as training the TRADES approach to convergence results in a significant increase in robust test error. Unfortunately, this means that all of the algorithmic gains over PGD in adversarially robust training can be equivalently obtained with early stopping.
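For reference, a rough, self-contained sketch of the L-infinity PGD attack that such adversarial training is built on (not the exact code in train_cifar.py; the epsilon and step size below are common CIFAR-10 choices rather than guaranteed paper settings):

import torch
import torch.nn.functional as F

def pgd_linf(model, X, y, epsilon=8/255, alpha=2/255, num_iter=10):
    # Random start inside the epsilon-ball, signed-gradient steps, then projection
    # back onto the ball. Clamping to the valid input range is omitted here since
    # it depends on the data normalization used.
    delta = torch.empty_like(X).uniform_(-epsilon, epsilon)
    delta.requires_grad_(True)
    for _ in range(num_iter):
        loss = F.cross_entropy(model(X + delta), y)
        loss.backward()
        with torch.no_grad():
            delta += alpha * delta.grad.sign()
            delta.clamp_(-epsilon, epsilon)
            delta.grad.zero_()
    return (X + delta).detach()

During adversarial training, each minibatch (X, y) would then be replaced by (pgd_linf(model, X, y), y) before the usual cross-entropy update.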

What is in this repository?

  • The experiments for CIFAR-10, CIFAR-100, and SVHN are in train_cifar.py, train_cifar100.py, and train_svhn.py, respectively.
  • CIFAR-10 training with semisupervised data is done in train_cifar_semisupervised_half.py, and uses the 500K pseudo-labeled TinyImages data from https://github.com/yaircarmon/semisup-adv
  • TRADES training is done with the repository located at https://github.com/yaodongyu/TRADES, with the only modification being the changes to the learning rate schedule to train to convergence (to decay at epochs 100 and 150 out of 200 total epochs).
  • For ImageNet training, we used the repository located at https://github.com/MadryLab/robustness with no modifications. The resulting logged data is stored in .pth files which can be loaded with torch.load() and are simply dictionaries of logged data (a loading sketch follows this list). The scripts containing the parameters for resuming the ImageNet experiments can be found in imagenet_scripts/.
  • Training logs are all located in the experiments folder, and each subfolder corresponds to a set of experiments carried out in the paper.
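For the logged .pth files mentioned above, loading is a plain torch.load call; the path and key names below are illustrative, since the exact dictionary contents depend on the experiment:

import torch

log = torch.load('experiments/imagenet/example_run.pth', map_location='cpu')  # illustrative path
print(type(log))           # a plain dict of logged quantities
print(sorted(log.keys()))  # e.g. per-epoch train/test robust error, depending on the run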

Model weights for the following models can be found in this drive folder:

  • The best checkpoints for the CIFAR-10 WideResNets defined in wideresnet.py, for width factors 10 and 20 (from the double descent curve, trained against L-infinity); a loading sketch follows this list.
  • The best checkpoints for SVHN / CIFAR-10 (L2) / CIFAR-100 / ImageNet models reported in Table 1 (the ImageNet checkpoints are in the format directly used by https://github.com/MadryLab/robustness). The remaining models are for the Preactivation ResNet18 defined in preactresnet.py.
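A minimal sketch of loading one of the released CIFAR-10 WideResNet checkpoints, assuming the constructor in wideresnet.py takes the usual depth/num_classes/widen_factor arguments and that the file stores a state_dict (possibly saved from an nn.DataParallel wrapper); the filename is illustrative, and the keys should be checked against the actual download:

import torch
from wideresnet import WideResNet  # defined in this repository

model = WideResNet(depth=34, num_classes=10, widen_factor=10)  # assumed constructor arguments

ckpt = torch.load('cifar10_wrn34-10_linf.pth', map_location='cpu')  # illustrative filename
state_dict = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
# Strip a possible 'module.' prefix left over from nn.DataParallel training.
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()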


robust_overfitting's Issues

About the computation required for the CIFAR10 experiment

Hello,

Thanks for sharing your code, and for your outstanding contributions to adversarial learning.

I wonder how much computing power is needed to run the default settings for CIFAR10, e.g., batch_size=128.

Could you list the GPU architectures that are able to run it? How much GPU memory is needed?

Thanks & Regards!
Momo

Is it better to use stratified sampling to divide the validation set?

I noticed that the labels of the validation set are slightly unbalanced, something like this: Counter({3: 113, 1: 112, 5: 107, 0: 100, 6: 99, 7: 98, 9: 98, 2: 94, 8: 90, 4: 89}) with seed 0 under my environment settings. I haven't tested it yet, but maybe stratified sampling would be better?
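(Not part of the repository; a hypothetical sketch of how a stratified 1,000-image validation split could be drawn, in case it helps frame the question:)

from collections import Counter
from sklearn.model_selection import train_test_split
from torchvision import datasets

# Hold out 1,000 images stratified by class, so each CIFAR-10 class appears ~100 times.
full_train = datasets.CIFAR10('./data', train=True, download=True)
labels = full_train.targets
train_idx, val_idx = train_test_split(
    list(range(len(labels))), test_size=1000, stratify=labels, random_state=0)
print(Counter(labels[i] for i in val_idx))  # roughly 100 per class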

Experiments on SVHN output abnormally high adversarial accuracy

Hello. Thank you for releasing your code and experiment logs.

While running your code to train on SVHN, I found that training gives strangely high adversarial accuracy on the test set.

In your paper and elsewhere, SVHN usually shows adversarial accuracy near 55~60%.

However, when I ran your code 4 times with different seeds (0~3), 3 of the runs gave accuracy near 90%.

The only change I made to the original code was to add 3 lines at the beginning to assign the GPU.

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"   #can be "1", "2", "3"

I am sharing the logs of the runs trained so far.

(The training runs are not finished yet, but since the paper also reports the "Best" performance, a strange best performance can already be an issue.)

output1.log
output2.log
output3.log
output0.log

I have never seen a paper claiming adversarial accuracy near 90% on SVHN, so I presume this is the result of a bug, but I am not certain.

How was the validation split derived?

Great work and thanks for the repo!

I'm seeing this line in the code to define the validation set:

dataset = torch.load("cifar10_validation_split.pth")

I couldn't find this file in the codebase. Would you mind pointing me to the file, or explaining how the split of 1,000 images was defined?

Thanks,
Tony
