ruqizhang / csgmcmc
Cyclical Stochastic Gradient MCMC for Bayesian Deep Learning
Hello Ruqi Zhang,
I have read part of your paper and code, and it is really interesting work! At the top of page 3, in the second paragraph, where you write "use a minibatch to approximate U(\theta)", shouldn't the scale coefficient N'/N be N/N', since we are using the minibatch to approximate the sum over the full dataset?
Another question concerns the code file "cifar_csghmc.py": on line 78 you wrote "d_p.add_(weight_decay, p.data)". My naive understanding is that you are computing the gradient of the potential energy term, which involves both the log-prior (from a normal distribution) and the log-likelihood. If that's the case, we seem to be missing the coefficient N mentioned above...
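To make the scaling question concrete, here is a minimal numeric sketch of the minibatch approximation as I understand it (the function name and numbers are mine, not from the repo): with the factor N/N', the estimate is unbiased, and for identical data points it is exact.

```python
def minibatch_potential(log_liks_batch, log_prior, N):
    """Minibatch estimate of U(theta) = -sum_i log p(x_i|theta) - log p(theta),
    scaling the minibatch log-likelihood sum by N/N' (dataset size over batch size)."""
    N_prime = len(log_liks_batch)
    return -(N / N_prime) * sum(log_liks_batch) - log_prior

# With a "dataset" of 10 identical points (log-likelihood -0.5 each,
# log-prior -1.0), the full-batch and minibatch estimates coincide:
full = minibatch_potential([-0.5] * 10, log_prior=-1.0, N=10)
mini = minibatch_potential([-0.5] * 2, log_prior=-1.0, N=10)
print(full, mini)  # both 6.0
```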
I'm looking forward to your reply! Thank you very much!
Hi,
Thanks for an interesting paper and open-source code!
I'm struggling somewhat to match your SGLD/SGHMC implementation with the update equations in the paper. In https://github.com/ruqizhang/csgmcmc/blob/51e511478d607b2523fc803d82a26edd39b14b6d/experiments/cifar_csgmcmc.py (line 79), you have noise_std = 2*lr*alpha**0.5. The gradient of this noise term is taken and then multiplied by lr, right? Will that actually match the SGLD update equation then?
Here's how I would implement SGLD (with prior p(theta) = N(0, alpha*I) and N = datasize):

```python
import math

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
loss_likelihood = loss_fn(logits, y)

# Gaussian prior N(0, alpha*I), scaled by 1/N to match the per-example
# scale of the likelihood loss.
loss_prior = 0.0
for param in network.parameters():
    loss_prior += (1.0 / 2.0) * (1.0 / N) * (1.0 / alpha) * torch.sum(torch.pow(param, 2))

# Injected noise, written as a loss term whose gradient w.r.t. each
# parameter is the desired Gaussian noise sample.
loss_noise = 0.0
for param in network.parameters():
    noise = torch.randn_like(param)  # replaces the deprecated Variable(...).cuda()
    loss_noise += (1.0 / math.sqrt(N)) * math.sqrt(2.0 / lr) * torch.sum(param * noise)

loss = loss_likelihood + loss_prior + loss_noise
```
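For reference, the update I am trying to match is the plain SGLD step theta <- theta - (lr/2) * grad_U(theta) + N(0, lr); a scalar sketch (my own code, not the repo's):

```python
import math
import random

def sgld_step(theta, grad_U, lr, rng=random):
    """One SGLD step on a scalar parameter:
    theta <- theta - (lr/2) * grad_U(theta) + Gaussian noise of variance lr."""
    noise = rng.gauss(0.0, math.sqrt(lr))
    return theta - 0.5 * lr * grad_U(theta) + noise

# Example target: N(0, 1), i.e. U(theta) = theta**2 / 2, grad_U(theta) = theta.
theta = 0.0
for _ in range(1000):
    theta = sgld_step(theta, lambda t: t, lr=0.01)
```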
Hope you can help resolve my confusion.
Regards
//
Fredrik
Hi! cifar100_ensemble.py
does the ensembling in logit space, while the textbook (and common-sense) algorithm does the ensembling in prediction space, i.e. by averaging the softmax probabilities. Is that OK?
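To make the distinction concrete, a small self-contained sketch (plain Python with made-up logits, not the repo's code): averaging softmax outputs (prediction space) and softmaxing averaged logits (logit space) generally give different predictive distributions.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    s = sum(exps)
    return [e / s for e in exps]

def mean(vectors):
    """Elementwise mean of equal-length lists."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

# Two ensemble members that disagree on a 2-class problem.
logits_a = [4.0, 0.0]
logits_b = [0.0, 2.0]

prob_space = mean([softmax(logits_a), softmax(logits_b)])  # average predictions
logit_space = softmax(mean([logits_a, logits_b]))          # average logits, then softmax
print(prob_space, logit_space)  # the two distributions differ
```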
Hello! I'm having a bit of trouble reproducing the results for cyclic HMC with your code. Could you please help me with that?
I'm seeing parser.add_argument('--alpha', type=int, default=1, help='1: SGLD; <1: SGHMC')
but shouldn't alpha here be a float in [0, 1]?
Also, if I allow it to be a float and train with a value of alpha that seems reasonable (e.g. alpha = 0.05, giving a momentum term of 0.95, which should be OK?), the network weights soon go to NaN.
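For what it's worth, here is the argument definition I would have expected (a suggested sketch, not a patch against the repo):

```python
import argparse

parser = argparse.ArgumentParser()
# alpha = 1 recovers SGLD; alpha < 1 gives SGHMC with momentum term (1 - alpha).
parser.add_argument('--alpha', type=float, default=1.0,
                    help='1: SGLD; <1: SGHMC')

# With type=float, fractional values parse as intended:
args = parser.parse_args(['--alpha', '0.05'])
print(args.alpha)  # 0.05
```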
I have some questions on the implementation of csghmc and on how to test uncertainty:
1. For csghmc, is the method from "Stochastic Gradient Hamiltonian Monte Carlo" used at the sampling stage? I would like to know how you implement it, and which function in the code it corresponds to.
2. How can I reproduce the uncertainty estimation results shown in the paper?
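To state the first question precisely, this is the SGHMC update I would expect at the sampling stage, written in the SGD-with-momentum parameterization (my own scalar sketch; the variable names and exact noise scale are my assumptions, not taken from the repo):

```python
import math
import random

def sghmc_step(theta, v, grad_U, lr, alpha, rng=random):
    """One SGHMC step in the momentum-as-velocity parameterization:
    v     <- (1 - alpha) * v - lr * grad_U(theta) + N(0, 2 * alpha * lr)
    theta <- theta + v
    alpha = 1 reduces the velocity update to an SGLD-style step."""
    noise = rng.gauss(0.0, math.sqrt(2.0 * alpha * lr))
    v = (1.0 - alpha) * v - lr * grad_U(theta) + noise
    return theta + v, v

# Example: one step on U(theta) = theta**2 / 2, i.e. grad_U(theta) = theta.
theta, v = sghmc_step(1.0, 0.0, lambda t: t, lr=0.1, alpha=0.1)
```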