wgrathwohl / jem

Project site for "Your Classifier is Secretly an Energy-Based Model and You Should Treat it Like One"

License: Apache License 2.0

Python 100.00%

jem's Introduction

JEM - Joint Energy Models

Official code for the paper Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One.

Includes scripts for training JEM (Joint-Energy Model), evaluating models at various tasks, and running adversarial attacks.

A pretrained model on CIFAR10 can be found here.

For more info on me and my work, please check out my website, Twitter, or Google Scholar.

Many thanks to my amazing co-authors: Jackson (Kuan-Chieh) Wang, Jörn-Henrik Jacobsen, David Duvenaud, Mohammad Norouzi, and Kevin Swersky.

Usage

Training

To train a model on CIFAR10 as in the paper

python train_wrn_ebm.py --lr .0001 --dataset cifar10 --optimizer adam --p_x_weight 1.0 --p_y_given_x_weight 1.0 --p_x_y_weight 0.0 --sigma .03 --width 10 --depth 28 --save_dir /YOUR/SAVE/DIR --plot_uncond --warmup_iters 1000

Evaluation

To evaluate the classifier (on CIFAR10):

python eval_wrn_ebm.py --load_path /PATH/TO/YOUR/MODEL.pt --eval test_clf --dataset cifar_test

To do OOD detection (on CIFAR100)

python eval_wrn_ebm.py --load_path /PATH/TO/YOUR/MODEL.pt --eval OOD --ood_dataset cifar_100

To generate a histogram of OOD scores like Table 2

python eval_wrn_ebm.py --load_path /PATH/TO/YOUR/MODEL.pt --eval logp_hist --datasets cifar10 svhn --save_dir /YOUR/HIST/FOLDER

To generate new unconditional samples

python eval_wrn_ebm.py --load_path /PATH/TO/YOUR/MODEL.pt --eval uncond_samples --save_dir /YOUR/SAVE/DIR --n_sample_steps {THE_MORE_THE_BETTER (1000 minimum)} --buffer_size 10000 --n_steps 40 --print_every 100 --reinit_freq 0.05

To generate conditional samples from a saved replay buffer

python eval_wrn_ebm.py --load_path /PATH/TO/YOUR/MODEL.pt --eval cond_samples --save_dir /YOUR/SAVE/DIR

To generate new conditional samples

python eval_wrn_ebm.py --load_path /PATH/TO/YOUR/MODEL.pt --eval cond_samples --save_dir /YOUR/SAVE/DIR --n_sample_steps {THE_MORE_THE_BETTER (1000 minimum)} --buffer_size 10000 --n_steps 40 --print_every 10 --reinit_freq 0.05 --fresh_samples

Attacks

To run Linf attacks on JEM-1

python attack_model.py --start_batch 0 --end_batch 6 --load_path /PATH/TO/YOUR/MODEL.pt --exp_name /YOUR/EXP/NAME --n_steps_refine 1 --distance Linf --random_init --n_dup_chains 5 --base_dir /PATH/TO/YOUR/EXPERIMENTS/DIRECTORY

To run L2 attacks on JEM-1

python attack_model.py --start_batch 0 --end_batch 6 --load_path /cloud_storage/BEST_EBM.pt --exp_name rerun_ebm_1_step_5_dup_l2_no_sigma_REDO --n_steps_refine 1 --distance L2 --random_init --n_dup_chains 5 --sigma 0.0 --base_dir /cloud_storage/adv_results &

Happy Energy-Based Modeling!

jem's Issues

How to compute the IS using the tensorflow code from Du & Mordatch

Hi, thank you for your excellent work, the code and the pre-trained model.

I have a simple question. I want to reproduce the Inception Score in the paper. But the code is written in TensorFlow (https://github.com/openai/ebm_code_release/blob/master/test_inception.py), and your model is a PyTorch model.

Should I convert the PyTorch model into a TensorFlow model using some library, or is there another way to do it?

Thank you so much!
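For readers with the same question: the Inception Score depends only on the class probabilities that a classifier assigns to the generated images, not on the framework that produced those images. The sketch below shows the standard formula, IS = exp(E_x[KL(p(y|x) || p(y))]), in plain Python. It is an illustration only; the function name `inception_score` and the input layout are assumptions made here, and it is not the evaluation code used for the paper.

```python
import math

def inception_score(probs):
    """Inception Score from per-image class probabilities.

    probs: list of lists, one row p(y|x) per generated image,
           each row summing to 1 (hypothetical input layout).
    """
    n = len(probs)
    k = len(probs[0])
    # Marginal p(y): average of p(y|x) over all images.
    marginal = [sum(p[j] for p in probs) / n for j in range(k)]
    # Mean KL(p(y|x) || p(y)); terms with p(y|x) = 0 contribute 0.
    mean_kl = 0.0
    for p in probs:
        mean_kl += sum(
            pj * (math.log(pj) - math.log(mj))
            for pj, mj in zip(p, marginal)
            if pj > 0
        )
    mean_kl /= n
    return math.exp(mean_kl)

# Confident and class-balanced predictions give the maximum score (number of classes).
print(inception_score([[1.0, 0.0], [0.0, 1.0]]))
# Uninformative predictions give the minimum score of 1.
print(inception_score([[0.5, 0.5], [0.5, 0.5]]))
```

In practice the rows of `probs` would come from an Inception network applied to the saved samples; how those samples are exported from the PyTorch model is left open here.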

Evaluating the model on adversarial attacks

The attack.py file only generates the adversarial samples, as mentioned in the paper. What is the difference between generating adversarial samples with the attack.py file and generating them directly with foolbox? Also, is there a script that can be used to evaluate the adversarial images?

How to train JEM with batch norm

Hi

Thanks for the exciting work and open-sourced code.

I have a problem. As you noted in the paper, "we have been able to successfully train Joint-EBMs with Batch Normalization". I ran your code with batch norm enabled (--norm batch), but it doesn't seem to work at all.

Am I missing anything? Do I need to change some code to enable batch norm?

Any help is appreciated. Thank you.

Adapting JEM to high-resolution images

Dear Will Grathwohl,

Thanks so much for the inspiring work!

I am now trying to adapt your joint energy-based model to a more challenging dataset containing images resized to between 256^2 and 512^2. However, JEM does not converge well no matter which regularization method I use (the L2 and L2 gradient regularizations described in #4). Could you please give me some ideas for scaling JEM to high-resolution images?

Thanks in advance,

Best,
Jun-Pu

Training time 36 hours

Hi,

I'm trying to run the JEM training algorithm in train_wrn_ebm.py, using

python train_wrn_ebm.py --lr .0001 --dataset cifar10 --optimizer adam --p_x_weight 1.0 --p_y_given_x_weight 1.0 --p_x_y_weight 0.0 --sigma .03 --width 10 --depth 28 --save_dir /YOUR/SAVE/DIR --plot_uncond --warmup_iters 1000

However, it's taking ~2.2s/iteration, which works out to at least ~80 hours of training time (assuming at least 700 steps per epoch for a train batch size of 64 on CIFAR10) rather than the 36 hours stated in the paper (https://arxiv.org/pdf/1912.03263.pdf, pg 4). I'm running on a p3.2xlarge instance on AWS. Could you please help explain the discrepancy?

Thanks!
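The estimate in this issue is simple arithmetic, sketched below so the numbers can be checked. The seconds per iteration and steps per epoch come from the issue; the epoch count is not stated there, so `N_EPOCHS = 200` is purely an assumption for illustration, as are the function names.

```python
def training_hours(sec_per_iter, iters_per_epoch, n_epochs):
    """Total wall-clock training time in hours."""
    return sec_per_iter * iters_per_epoch * n_epochs / 3600.0

def epochs_within(budget_hours, sec_per_iter, iters_per_epoch):
    """Number of epochs (possibly fractional) that fit in a time budget."""
    return budget_hours * 3600.0 / (sec_per_iter * iters_per_epoch)

SEC_PER_ITER = 2.2     # reported in the issue
ITERS_PER_EPOCH = 700  # lower bound reported in the issue
N_EPOCHS = 200         # ASSUMPTION: not stated in the issue

print(f"{training_hours(SEC_PER_ITER, ITERS_PER_EPOCH, N_EPOCHS):.1f} h")
print(f"{epochs_within(36, SEC_PER_ITER, ITERS_PER_EPOCH):.1f} epochs in 36 h")
```

Under these assumed settings the total comes to about 85.6 hours, and roughly 84 epochs would fit in 36 hours at the reported speed, so the gap could come from a different epoch count, different hardware, or both.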

Volatile accuracy

Hi Will, first of all congratulations on all your success! Your work is great and inspiring and you really opened up a whole new realm of modeling possibilities for me (didn't know anything about ebms before).

Second of all, I'm currently exploring another EBM application based on your work in JEM and I was wondering whether you could help me understand a phenomenon. While running my model, I noticed that the training accuracy I calculate every few iterations sometimes decreases. When I noticed this, I ran JEM again and realized it happens there too. The situations aren't directly comparable because the data is different, but what is the intuition behind this? Is it bad? How can I work against it (by changing the number of SGLD steps, perhaps)? I usually noticed this in the first few epochs (also in JEM), so perhaps it changes later (I'm quite impatient).

Another question: at approximately what epoch could you see that the generated samples were becoming something other than noise? I realize EBMs are volatile and take a long time to train, but I'd like to get an idea of the point at which I can say that a setup isn't working and I need to find new parameter settings.

Thanks very much!
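As context for the "number of SGLD steps" mentioned in this issue, here is a minimal toy sketch of Stochastic Gradient Langevin Dynamics on a one-dimensional energy. It is not the sampler from this repository: the energy function, step size, and all names (`sgld_sample`, `grad_energy`) are assumptions chosen for illustration, and the noise scale follows the textbook choice of variance equal to the step size.

```python
import random

def sgld_sample(grad_energy, x0, n_steps, step_size, rng):
    """Run n_steps of SGLD on a scalar x.

    Update: x <- x - (step_size / 2) * dE/dx + N(0, step_size).
    """
    x = x0
    noise_std = step_size ** 0.5
    for _ in range(n_steps):
        x = x - 0.5 * step_size * grad_energy(x) + rng.gauss(0.0, noise_std)
    return x

# Toy energy E(x) = (x - 2)^2 / 2, so dE/dx = x - 2 and low energy sits near x = 2.
def toy_grad(x):
    return x - 2.0

rng = random.Random(0)
# Each chain starts from uniform noise in [-1, 1] and runs 100 steps.
samples = [
    sgld_sample(toy_grad, rng.uniform(-1.0, 1.0), n_steps=100, step_size=0.1, rng=rng)
    for _ in range(200)
]
mean = sum(samples) / len(samples)
print(f"sample mean after 100 steps: {mean:.2f}")
```

With too few steps the chains stay close to their noise initialization, while longer chains drift toward low-energy regions; this is the trade-off behind tuning the step count.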

Model training terminated

Hi, one other issue I wanted to point out was that the training process seemed to terminate about 27 epochs in, due to a diverging loss.

Thanks!

Estimate log p(x,y)

Hi,

Thank you very much for the code as well as the pretrained model.

I am trying to estimate log p(x,y), but unfortunately I could not figure out from the code whether this is already implemented.
If it is, could you please point me in the right direction?

Thanks in advance,
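For reference, the paper defines the joint density through the classifier logits as p(x,y) = exp(f(x)[y]) / Z, so the logit for class y equals log p(x,y) up to the unknown constant log Z. The sketch below illustrates that relationship in plain Python; the function names are hypothetical and nothing here is taken from the repository's code.

```python
import math

def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def jem_log_quantities(logits, y):
    """Quantities derived from the logits f(x) of a single input x.

    Returns (log p(x,y) + log Z, log p(x) + log Z, log p(y|x)).
    The first two are unnormalized: log Z is intractable and left unknown.
    """
    log_p_xy_unnorm = logits[y]         # f(x)[y]
    log_p_x_unnorm = logsumexp(logits)  # log sum_y exp(f(x)[y])
    log_p_y_given_x = log_p_xy_unnorm - log_p_x_unnorm  # log Z cancels
    return log_p_xy_unnorm, log_p_x_unnorm, log_p_y_given_x

print(jem_log_quantities([2.0, 0.5, -1.0], y=0))
```

Because log Z is the same for every input, the unnormalized values can still be compared across inputs, but they are not calibrated log-probabilities.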

Dealing with divergence

Hello,

Your work is inspiring!
I have the following problem when I try to run your code.
During training, the loss often blows up and diverges. Could you help me deal with such divergences? Training diverges even after turning off BatchNorm and using warmup iterations, often after 2 epochs.

Any help is appreciated. Thank you.

How to generate Distal Adversarial Examples in paper?

Hi, thank you for your great research. I have a question about the distal adversaries.

Did you use code like https://github.com/bethgelab/AnalysisBySynthesis/blob/441479b231fbd6a43615c10c7c68ccc86c31ae44/scripts/attacks.py
for them?

If so, could you share your configuration for the distal adversaries?

opti = torch.optim.SGD([a_helper], lr=1, momentum=0.95)
confidence_level = model.confidence_level    # abs 0.0000031, CNN 1439000, madry 60, 1-NN 0.000000000004
logits_scale = model.logit_scale                      # ABS 430, madry 1, CNN 1, 1-NN 5

Thank you so much!
