Augmented Random Forest

Simple Post-Training Robustness Using Test Time Augmentations and Random Forest

This repo reproduces all the results shown in our paper.

Init project

  1. Run in the project dir:
source ./init_project.sh

Create the validation-set and test-set indices for all datasets by running:

python src/scripts/set_val_test_inds.py

This generates the 'test' and 'test-val' subsets (as explained in the paper) for each dataset.
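The exact split is defined in the paper and in set_val_test_inds.py. Purely as an illustration of what "disjoint, reproducible subsets" means, here is a minimal sketch. The helper name `split_test_indices`, the 50/50 ratio, and the seed are assumptions, not values taken from the repository.

```python
import random


def split_test_indices(num_test_samples, val_fraction=0.5, seed=0):
    """Hypothetical helper: split a test set's indices into disjoint
    'test-val' and 'test' subsets, reproducibly for a fixed seed."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    inds = list(range(num_test_samples))
    rng.shuffle(inds)
    n_val = int(num_test_samples * val_fraction)  # assumed 50/50 split
    val_inds = sorted(inds[:n_val])   # 'test-val' subset
    test_inds = sorted(inds[n_val:])  # 'test' subset
    return val_inds, test_inds


val_inds, test_inds = split_test_indices(10000)  # CIFAR-10 has 10,000 test images
print(len(val_inds), len(test_inds))  # 5000 5000
```

Using a fixed seed means every later step (attacks, random forest fitting, evaluation) sees the same subsets.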

Train

Train ResNet networks for cifar10, cifar100, svhn, and tiny_imagenet using src/scripts/train.py.

For example, for CIFAR-10 run:

  1. Regular network:
python src/scripts/train.py --dataset cifar10 --net resnet34 --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00
  1. TRADES:
python src/scripts/train.py --dataset cifar10 --net resnet34 --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/adv_robust_trades --adv_trades True
  1. VAT:
python src/scripts/train.py --dataset cifar10 --net resnet34 --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/adv_robust_vat --adv_vat True

To also reproduce the ensemble results, train 9 more networks with the following checkpoint directories:

/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_01
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_02
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_03
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_04
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_05
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_06
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_07
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_08
/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_09
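The nine extra ensemble members differ only in their checkpoint directory, so their training commands can be generated in a loop. This is a sketch with a hypothetical helper `ensemble_train_commands`; generalizing the path to `{dataset}/{net}` is an assumption extrapolated from the CIFAR-10 paths above.

```python
import shlex


def ensemble_train_commands(dataset="cifar10", net="resnet34", first=1, last=9):
    """Hypothetical helper: build train.py invocations for the extra ensemble
    members {net}_01 ... {net}_09 ({net}_00 is assumed to be trained already)."""
    cmds = []
    for i in range(first, last + 1):
        # Assumed path pattern, matching the CIFAR-10 directories listed above
        ckpt = f"/tmp/adversarial_robustness/{dataset}/{net}/regular/{net}_{i:02d}"
        cmds.append(["python", "src/scripts/train.py", "--dataset", dataset,
                     "--net", net, "--checkpoint_dir", ckpt])
    return cmds


for cmd in ensemble_train_commands():
    # Print the shell command; subprocess.run(cmd, check=True) would execute it.
    print(shlex.join(cmd))
```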

Attack

To attack a network, use src/scripts/attack.py.

For example, to attack CIFAR-10 with the $FGSM^2$ attack (defined in the paper), run:

python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack fgsm --targeted True --attack_dir fgsm2 --eps 0.031
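To show what `--targeted True` and `--eps` control, here is a toy, pure-Python illustration of a single targeted FGSM step on a linear softmax model. It is not the repository's attack implementation. It assumes pixels are scaled to [0, 1], where 0.031 is roughly 8/255, and the model weights are made up for the example.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def sign(g):
    return (g > 0) - (g < 0)


def targeted_fgsm(x, W, b, target, eps):
    """One targeted FGSM step on a toy linear softmax model z = W x + b.

    Steps *against* the gradient of the cross-entropy loss for the target
    class (making the target more likely), then clips to [0, 1].
    """
    z = [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]
    p = softmax(z)
    # dL/dz_i = p_i - 1[i == target];  dL/dx_j = sum_i W[i][j] * dL/dz_i
    dz = [pi - (1.0 if i == target else 0.0) for i, pi in enumerate(p)]
    grad = [sum(W[i][j] * dz[i] for i in range(len(W))) for j in range(len(x))]
    return [min(1.0, max(0.0, xj - eps * sign(gj))) for xj, gj in zip(x, grad)]


x_adv = targeted_fgsm([0.5, 0.5], [[1.0, -1.0], [-1.0, 1.0]], [0.0, 0.0],
                      target=1, eps=0.031)
print(x_adv)  # approximately [0.469, 0.531]
```

Every pixel moves by exactly eps (unless clipped), which is why FGSM with eps 0.031 is a stronger perturbation than with eps 0.01.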

Before training the random forest classifier, you must generate all the non-adaptive attacks from Section 4 of the paper: fgsm1, fgsm2, jsma, pgd1, pgd2, deepfool, cw, cw_Linf, square, and boundary. The complete set of attacks to run is:

  1. [$FGSM^1$]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack fgsm --targeted True --attack_dir fgsm1 --eps 0.01
  1. [$FGSM^2$]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack fgsm --targeted True --attack_dir fgsm2 --eps 0.031
  1. [JSMA]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack jsma --targeted True --attack_dir jsma
  1. [$PGD^1$]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack pgd --targeted True --attack_dir pgd1 --eps 0.01 --eps_step 0.003
  1. [$PGD^2$]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack pgd --targeted True --attack_dir pgd2 --eps 0.031 --eps_step 0.003
  1. [Deepfool]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack deepfool --targeted False --attack_dir deepfool
  1. [$CW_{L_2}$]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack cw --targeted True --attack_dir cw
  1. [$CW_{L_\infty}$]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack cw_Linf --targeted True --attack_dir cw_Linf --eps 0.031
  1. [Square]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack square --targeted False --attack_dir square --eps 0.031
  1. [Boundary]:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack boundary --targeted True --attack_dir boundary
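The ten commands above share everything except a few flags, so a small driver can generate them from a table. In this sketch, `attack_commands` is a hypothetical helper (not part of the repository). It builds exactly the commands listed above and only prints them; running them with `subprocess.run` is left to the reader.

```python
import shlex

CKPT = "/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00"

# (attack_dir, --attack, --targeted, extra flags), copied from the list above
ATTACKS = [
    ("fgsm1", "fgsm", True, ["--eps", "0.01"]),
    ("fgsm2", "fgsm", True, ["--eps", "0.031"]),
    ("jsma", "jsma", True, []),
    ("pgd1", "pgd", True, ["--eps", "0.01", "--eps_step", "0.003"]),
    ("pgd2", "pgd", True, ["--eps", "0.031", "--eps_step", "0.003"]),
    ("deepfool", "deepfool", False, []),
    ("cw", "cw", True, []),
    ("cw_Linf", "cw_Linf", True, ["--eps", "0.031"]),
    ("square", "square", False, ["--eps", "0.031"]),
    ("boundary", "boundary", True, []),
]


def attack_commands(checkpoint_dir=CKPT):
    """Hypothetical helper: build one attack.py invocation per attack."""
    cmds = []
    for attack_dir, attack, targeted, extra in ATTACKS:
        cmds.append(["python", "src/scripts/attack.py",
                     "--checkpoint_dir", checkpoint_dir,
                     "--attack", attack,
                     "--targeted", str(targeted),  # renders as True / False
                     "--attack_dir", attack_dir, *extra])
    return cmds


for cmd in attack_commands():
    # subprocess.run(cmd, check=True) would execute each command in turn
    print(shlex.join(cmd))
```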

Fit the random forest

After running the 10 attacks above on the network, train the random forest by running:

python src/scripts/train_random_forest.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00

The random forest parameters will be saved under: /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00/random_forest/random_forest_classifier.pkl
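ARF combines test-time augmentations (TTA) with a random forest. As a rough sketch of how such an input could be built, assume (an assumption; see the paper and train_random_forest.py for the actual features) that the forest's input for one image is the concatenation of the network's class probabilities over several augmented views. The augmentations and `toy_predict_proba` below are hypothetical placeholders that only make the sketch executable.

```python
from typing import Callable, List, Sequence

Image = List[float]  # toy stand-in for an image tensor


def tta_features(x: Image,
                 predict_proba: Callable[[Image], List[float]],
                 augmentations: Sequence[Callable[[Image], Image]]) -> List[float]:
    """Concatenate the model's class probabilities over all augmented views of x."""
    feats: List[float] = []
    for aug in augmentations:
        feats.extend(predict_proba(aug(x)))
    return feats


# Hypothetical toy model and augmentations, only to make the sketch executable.
def toy_predict_proba(x: Image) -> List[float]:
    s = sum(x) / len(x)
    return [s, 1.0 - s]


def identity(x: Image) -> Image:
    return list(x)


def flip(x: Image) -> Image:
    return x[::-1]


def brighten(x: Image) -> Image:
    return [min(1.0, v + 0.1) for v in x]


features = tta_features([0.2, 0.4, 0.6], toy_predict_proba, [identity, flip, brighten])
print(len(features))  # 6 = 3 augmentations x 2 classes
```

Stacking one such vector per image gives a matrix X on which a classifier such as scikit-learn's `RandomForestClassifier` could be fit with `fit(X, y)`. The saved .pkl file can presumably be reloaded with `pickle.load`, which requires a compatible scikit-learn version to be installed.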

Adaptive white-box BPDA attack

After saving the random forest weights, you can attack the ARF defense.

  1. First, create a substitute model using:
python src/scripts/train_random_forest_sub.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00
  1. Second, call the BPDA attack:
python src/scripts/attack.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --attack bpda --targeted True --eps 0.031 --eps_step 0.007 --max_iters 10
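BPDA (Backward Pass Differentiable Approximation) is needed because the random forest has no useful gradient, so gradients are taken from a differentiable substitute instead. The sketch below shows the general idea as a targeted L_inf PGD loop. It is not the repository's implementation. `substitute_grad` stands in for the substitute model trained by train_random_forest_sub.py, and `defense_predict` stands in for the full ARF pipeline. The defaults mirror the flags `--eps 0.031 --eps_step 0.007 --max_iters 10`. Early stopping and the projection order are assumptions.

```python
from typing import Callable, List

Vec = List[float]


def _sign(g: float) -> int:
    return (g > 0) - (g < 0)


def bpda_targeted_pgd(x0: Vec, target: int,
                      substitute_grad: Callable[[Vec, int], Vec],
                      defense_predict: Callable[[Vec], int],
                      eps: float = 0.031, eps_step: float = 0.007,
                      max_iters: int = 10) -> Vec:
    """Targeted L_inf PGD in the spirit of BPDA: gradients come from a
    differentiable substitute, while success is checked on the real defense."""
    x = list(x0)
    for _ in range(max_iters):
        if defense_predict(x) == target:
            break  # the actual (non-differentiable) defense is already fooled
        g = substitute_grad(x, target)  # gradient of the target-class loss w.r.t. x
        x = [xi - eps_step * _sign(gi) for xi, gi in zip(x, g)]
        # Project onto the eps-ball around x0, then onto the valid range [0, 1].
        x = [min(1.0, max(0.0, min(x0i + eps, max(x0i - eps, xi))))
             for xi, x0i in zip(x, x0)]
    return x


# Toy stand-ins: a hard-threshold "defense" with no useful gradient, and a
# linear substitute whose target-1 loss decreases as x[0] grows.
def toy_defense(x: Vec) -> int:
    return 1 if x[0] > 0.52 else 0


def toy_substitute_grad(x: Vec, target: int) -> Vec:
    return [-1.0 if target == 1 else 1.0, 0.0]


x_adv = bpda_targeted_pgd([0.5, 0.5], target=1,
                          substitute_grad=toy_substitute_grad,
                          defense_predict=toy_defense)
print(toy_defense(x_adv))  # 1
```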

Evaluation

Use src/scripts/eval.py to evaluate the defenses.

To evaluate a plain model without any defense, run:

python src/scripts/eval.py --checkpoint_dir /tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00 --method simple --attack_dir <YOUR_SELECTED_ATTACK> --dump_dir simple

To compute the accuracy of the ensemble, TTA, or ARF, replace "simple" in the command above with "ensemble", "tta", or "random_forest", respectively.
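To evaluate all four settings against one attack, the commands can be generated in a loop. This is a sketch with a hypothetical helper `eval_commands`. Reusing the method name as `--dump_dir` is an assumption based on the plain-model command above, and `fgsm2` is just an example attack directory.

```python
import shlex

CKPT = "/tmp/adversarial_robustness/cifar10/resnet34/regular/resnet34_00"
METHODS = ["simple", "ensemble", "tta", "random_forest"]


def eval_commands(attack_dir, checkpoint_dir=CKPT):
    """Hypothetical helper: one eval.py call per defense, reusing the
    method name as the --dump_dir (an assumption)."""
    return [["python", "src/scripts/eval.py", "--checkpoint_dir", checkpoint_dir,
             "--method", m, "--attack_dir", attack_dir, "--dump_dir", m]
            for m in METHODS]


for cmd in eval_commands("fgsm2"):
    print(shlex.join(cmd))
```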
