vibnet's People

Contributors

zexinli0w0, zhuchen03

vibnet's Issues

dimension issue

When I try to run the first provided command, the following dimension problem shows up:

RuntimeError: The size of tensor a (64) must match the size of tensor b (3) at non-singleton dimension 3

The complete error output:

python ib_vgg_train.py --gpu 0 --batch-norm --resume-vgg-pt baseline/cifar10/checkpoint_299_nocrop.tar --ban-crop --opt adam --cfg D4 --epochs 300 --lr 1.4e-3 --weight-decay 5e-5 --kl-fac 1.4e-5 --save-dir ib_vgg_chk/D4
Namespace(ban_crop=True, ban_flip=False, batch_norm=True, batchsize=128, cfg='D4', data_set='cifar10', epochs=300, gpu=0, ib_lr=-1, ib_wd=-1, init_var=0.01, kl_fac=1.4e-05, lr=0.0014, lr_epoch=30, lr_fac=0.5, mag=9, momentum=0.9, no_ib=False, opt='adam', print_freq=50, reg_weight=0, resume='', resume_vgg_pt='baseline/cifar10/checkpoint_299_nocrop.tar', resume_vgg_vib='', sample_test=0, sample_train=1, save_dir='ib_vgg_chk/D4', tb_path='tb_ib_vgg', threshold=0, val=False, weight_decay=5e-05, workers=1)
Files already downloaded and verified
Using structure [(64, 0.03125), (64, 0.03125), 'M', (128, 0.0625), (128, 0.0625), 'M', (256, 0.125), (256, 0.125), (256, 0.125), 'M', (512, 0.25), (512, 0.25), (512, 0.25), 'M', (512, 0.5), (512, 0.5), (512, 0.5), 'M']
detected VIB params (45): ['conv_layers.3.prior_z_logD', 'conv_layers.3.post_z_mu', 'conv_layers.3.post_z_logD', 'conv_layers.7.prior_z_logD', 'conv_layers.7.post_z_mu', 'conv_layers.7.post_z_logD', 'conv_layers.12.prior_z_logD', 'conv_layers.12.post_z_mu', 'conv_layers.12.post_z_logD', 'conv_layers.16.prior_z_logD', 'conv_layers.16.post_z_mu', 'conv_layers.16.post_z_logD', 'conv_layers.21.prior_z_logD', 'conv_layers.21.post_z_mu', 'conv_layers.21.post_z_logD', 'conv_layers.25.prior_z_logD', 'conv_layers.25.post_z_mu', 'conv_layers.25.post_z_logD', 'conv_layers.29.prior_z_logD', 'conv_layers.29.post_z_mu', 'conv_layers.29.post_z_logD', 'conv_layers.34.prior_z_logD', 'conv_layers.34.post_z_mu', 'conv_layers.34.post_z_logD', 'conv_layers.38.prior_z_logD', 'conv_layers.38.post_z_mu', 'conv_layers.38.post_z_logD', 'conv_layers.42.prior_z_logD', 'conv_layers.42.post_z_mu', 'conv_layers.42.post_z_logD', 'conv_layers.47.prior_z_logD', 'conv_layers.47.post_z_mu', 'conv_layers.47.post_z_logD', 'conv_layers.51.prior_z_logD', 'conv_layers.51.post_z_mu', 'conv_layers.51.post_z_logD', 'conv_layers.55.prior_z_logD', 'conv_layers.55.post_z_mu', 'conv_layers.55.post_z_logD', 'fc_layers.2.prior_z_logD', 'fc_layers.2.post_z_mu', 'fc_layers.2.post_z_logD', 'fc_layers.5.prior_z_logD', 'fc_layers.5.post_z_mu', 'fc_layers.5.post_z_logD']
detected VGG params (58): ['conv_layers.0.weight', 'conv_layers.0.bias', 'conv_layers.1.weight', 'conv_layers.1.bias', 'conv_layers.4.weight', 'conv_layers.4.bias', 'conv_layers.5.weight', 'conv_layers.5.bias', 'conv_layers.9.weight', 'conv_layers.9.bias', 'conv_layers.10.weight', 'conv_layers.10.bias', 'conv_layers.13.weight', 'conv_layers.13.bias', 'conv_layers.14.weight', 'conv_layers.14.bias', 'conv_layers.18.weight', 'conv_layers.18.bias', 'conv_layers.19.weight', 'conv_layers.19.bias', 'conv_layers.22.weight', 'conv_layers.22.bias', 'conv_layers.23.weight', 'conv_layers.23.bias', 'conv_layers.26.weight', 'conv_layers.26.bias', 'conv_layers.27.weight', 'conv_layers.27.bias', 'conv_layers.31.weight', 'conv_layers.31.bias', 'conv_layers.32.weight', 'conv_layers.32.bias', 'conv_layers.35.weight', 'conv_layers.35.bias', 'conv_layers.36.weight', 'conv_layers.36.bias', 'conv_layers.39.weight', 'conv_layers.39.bias', 'conv_layers.40.weight', 'conv_layers.40.bias', 'conv_layers.44.weight', 'conv_layers.44.bias', 'conv_layers.45.weight', 'conv_layers.45.bias', 'conv_layers.48.weight', 'conv_layers.48.bias', 'conv_layers.49.weight', 'conv_layers.49.bias', 'conv_layers.52.weight', 'conv_layers.52.bias', 'conv_layers.53.weight', 'conv_layers.53.bias', 'fc_layers.0.weight', 'fc_layers.0.bias', 'fc_layers.3.weight', 'fc_layers.3.bias', 'fc_layers.6.weight', 'fc_layers.6.bias']
Learning rate of IB: 0.0014, learning rate of others: 0.0014
loaded pretraind model with acc 91.94
Traceback (most recent call last):
  File "ib_vgg_train.py", line 395, in <module>
    main()
  File "ib_vgg_train.py", line 104, in main
    model.state_dict()[ib_keys[i*9+j]].copy_(state_dict['state_dict'][vgg_keys[i*6+j]])
RuntimeError: The size of tensor a (64) must match the size of tensor b (3) at non-singleton dimension 3
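
In case it helps with debugging, here is a small shape audit one could run just before the failing line. This is only a sketch: it pairs the two key lists directly with zip, which is not the i*9+j / i*6+j indexing the script actually uses, so treat the pairing as illustrative.

    # compare destination (VIB model) and source (VGG checkpoint) shapes for
    # each key pair; model, state_dict, ib_keys and vgg_keys are the objects
    # visible in the traceback above
    for k_ib, k_vgg in zip(ib_keys, vgg_keys):
        dst = model.state_dict()[k_ib].shape
        src = state_dict['state_dict'][k_vgg].shape
        if dst != src:
            print(k_ib, tuple(dst), '<-', k_vgg, tuple(src))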

Crash when running on CPU

Thanks to the author for posting this very easy-to-read code.

While testing, I noticed a compatibility bug that prevents the code from running on the CPU.
At line 17 of ib_layers.py, in the reparameterize function:

eps = torch.FloatTensor(batch_size, std.size(0)).cuda(mu.get_device()).normal_()

If the model is on the CPU, mu.get_device() returns -1, and torch.FloatTensor(...).cuda(-1) then raises a runtime error like the following:

RuntimeError: Device index must not be negative

I suggest the author add a conditional statement here to avoid this runtime error.
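
One possible fix, sketched under the assumption that batch_size, std, and mu are exactly as in the original function, is to create the tensor directly on mu's device instead of calling .cuda() unconditionally:

    # draw eps on whatever device mu lives on, so the same line
    # works on both CPU and GPU
    eps = torch.randn(batch_size, std.size(0), device=mu.device)

With the device= argument there is no need for an explicit if/else, since mu.device covers both the CPU and CUDA cases.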

Why are you using weight decay on IB parameters?

Hello,

In the paper, a regularization term based on information bottleneck theory is introduced.
However, I noticed that in the code you also set a weight decay coefficient for the IB parameters in the optimizer. Weight decay penalizes the L2 norm of both \mu and \sigma, which means you are adding another regularization term on top of the one introduced in the paper (i.e. log(1+\alpha^2)).
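
For reference, this is the pattern I mean, written as a runnable sketch with placeholder names (vgg_params and ib_params are stand-ins for however the script actually partitions the parameters; the lr and weight-decay values are just examples):

    import torch

    # placeholder parameter lists: stand-ins for the VGG weights and the
    # IB parameters (\mu, \sigma) of the real model
    vgg_params = [torch.nn.Parameter(torch.randn(3, 3))]
    ib_params = [torch.nn.Parameter(torch.randn(3))]

    # a nonzero weight_decay on the IB group adds an L2 penalty on \mu and
    # \sigma on top of the paper's log(1+\alpha^2) regularizer
    optimizer = torch.optim.Adam([
        {'params': vgg_params, 'weight_decay': 5e-5},
        {'params': ib_params, 'weight_decay': 5e-5},  # <- the part I am asking about
    ], lr=1.4e-3)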

Could you kindly let me know the reason?

question about layer-by-layer IB loss

Hi, here is a simple network fed with a random input x:

x --> A --> B --> classifier --> y^hat

I want to squeeze x through layers A and B, get the logit vector from the classifier, and finally output the predicted label y^hat. What is the IB loss for each layer? Is it L = I(B;A) - I(A;y) + I(classifier;B) - I(classifier;y)? Why or why not? Thanks!
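
For comparison, my reading of the paper's layer-wise objective in equations (9)-(10), which you should treat as my assumption rather than a quotation of the paper, is

    \tilde{\mathcal{L}} = \sum_{i=1}^{L} \gamma_i \, \mathrm{KL}\big[ p(h_i \mid h_{i-1}) \,\|\, q(h_i) \big] - L \, \mathbb{E}\big[ \log q(y \mid h_L) \big]

i.e. every layer contributes a weighted compression term, but only the last representation h_L carries the prediction term. I would like to know whether my per-layer decomposition above is consistent with this.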

Questions about log q(y|h_L)

Hi!
I'm reading the code for your paper and have a question about the calculation of log q(y|h_L) in equation (9). In the code, you compute the cross-entropy between the outputs and the labels and use it to represent the average of log q(y|h_L), but I can't see why the two are equal. Could you explain this for me?
Looking forward to your reply! Thank you!
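
To make the question concrete, here is the identity I am trying to verify, written as a small self-contained check in plain PyTorch (nothing here is specific to this repo):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(8, 10)           # h_L mapped to class scores for a batch of 8
    labels = torch.randint(0, 10, (8,))   # ground-truth labels y

    # cross-entropy as the training code computes it
    ce = F.cross_entropy(logits, labels)

    # average negative log-probability of the true class, where q(y|h_L)
    # is the softmax distribution over the logits
    log_q = F.log_softmax(logits, dim=1)
    nll = -log_q[torch.arange(8), labels].mean()

    print(torch.allclose(ce, nll))  # True: CE equals -E[log q(y|h_L)]

Is this numerical identity the reason the code uses cross-entropy, or is there something more to it?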

Questions about the calculation of the loss

Hi,

Thanks for releasing this implementation! I have been looking through this code and your paper (arXiv version) and there are a couple of things that are not clear to me:

  1. kl_mult (an argument to each InformationBottleneck layer) corresponds to \gamma_i in equation (10). Then, after you sum the \gamma_i-weighted KLDs of the IB layers, you multiply the resulting kl_total by kl_fac before adding it to the cross-entropy loss:

     ce_loss = criterion(output, target_var)
     
     loss = ce_loss
     if kl_fac > 0:
         loss += kl_total * kl_fac
    

    What is the motivation for kl_fac? It does not appear to derive from (10).

  2. In the code I can't find the weighting term L on the data term in (10). Can you point me to where you scale this term? It occurred to me that kl_fac might be scaling down the KL term instead of scaling up the data term, but in that case I would expect the value of kl_fac to be 1/L, much larger than the default value of 1e-6 (see the sketch after this list).
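
Putting both points together, here is how I currently read the code against equation (10), written as a runnable sketch with toy values (ce_loss, kl_total, kl_fac, and num_layers are placeholders, not necessarily the script's identifiers):

    # toy values just so the sketch runs
    ce_loss, kl_total, kl_fac, num_layers = 0.5, 120.0, 1e-6, 16

    # what the training code appears to compute (as quoted above):
    # CE + kl_fac * sum_i gamma_i * KL_i
    loss_code = ce_loss + kl_fac * kl_total

    # what I read off equation (10), after dividing through by the constant L:
    # CE + (1/L) * sum_i gamma_i * KL_i
    loss_paper = ce_loss + kl_total / num_layers

    # the two coincide only when kl_fac == 1/num_layers, which is far larger
    # than the default 1e-6 -- hence question 2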

Thanks again!
