Comments (9)
You don't have to use the Turing... types. If you load DistributionsAD, you also get the differentiation rules for MvNormal (see e.g. https://github.com/TuringLang/DistributionsAD.jl/blob/d6aaa6452c033312ebd26aaf4f241b3fac9bebb7/src/multivariate.jl#L309). So with this type piracy, both approaches in your example should use the same differentiation rules.
from advancedvi.jl.
Interesting result - did you enable tape caching for ReverseDiff? In my experience it provides a substantial speedup over non-caching ReverseDiff.
@mohamed82008 @willtebbutt @devmotion
Do you use DistributionsAD? Also, d is a non-constant global; do you get the same results if you fix this? And is there any difference if you use

gradient(X) do x
    loglikelihood(d, x)
end

(with constant d)? IIRC DistributionsAD contains optimized differentiation rules for loglikelihood.
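A self-contained version of that suggestion might look as follows. This is only a sketch: the name d_const, the dimensions, and the choice of Zygote.gradient are my assumptions, not part of the original comment.

```julia
using Distributions, DistributionsAD, LinearAlgebra
using Zygote

# Hypothetical constant distribution (names and sizes are illustrative)
const d_const = MvNormal(zeros(3), Matrix(1.0I, 3, 3))
X = rand(d_const, 10)

# Differentiate loglikelihood over the whole sample matrix at once
g, = Zygote.gradient(X) do x
    loglikelihood(d_const, x)
end
```

For a standard normal, the gradient of the log-likelihood with respect to each sample is -x, which makes the result easy to sanity-check.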
@yebai I don't know how to do the tape caching, so I cannot try. But @devmotion, here are the results with constant d, and with constant d from TuringMvNormal:
using BenchmarkTools
using Distributions, DistributionsAD
using Zygote
using ReverseDiff
m = rand(50)
C = rand(50, 50) |> x -> x * x'
d = MvNormal(m, C)
const c_d = MvNormal(m, C)
ad_d = TuringDenseMvNormal(m, C)
const c_ad_d = TuringDenseMvNormal(m, C)
f(x) = logpdf(d, x)
c_f(x) = logpdf(c_d, x)
ad_f(x) = loglikelihood(ad_d, x)
c_ad_f(x) = loglikelihood(c_ad_d, x)
X = rand(d, 40)
## Previous result
@btime ReverseDiff.gradient($X) do x
    sum(f, eachcol(x))
end
# 18.051 ms (634979 allocations: 26.17 MiB)
@btime Zygote.gradient($X) do x
    sum(f, eachcol(x))
end
# 2.029 ms (14810 allocations: 4.15 MiB)
## With constant d
@btime ReverseDiff.gradient($X) do x
    sum(c_f, eachcol(x))
end
# 18.774 ms (634938 allocations: 26.17 MiB)
@btime Zygote.gradient($X) do x
    sum(c_f, eachcol(x))
end
# 1.481 ms (13480 allocations: 3.32 MiB)
## Same thing but using DistributionsAD/loglikelihood
@btime ReverseDiff.gradient($X) do x
    sum(ad_f, eachcol(x))
end
# 16.594 ms (533099 allocations: 22.31 MiB)
@btime Zygote.gradient($X) do x
    sum(ad_f, eachcol(x))
end
# 2.650 ms (17890 allocations: 4.29 MiB)
## Same thing but using DistributionsAD/loglikelihood and constant distribution
@btime ReverseDiff.gradient($X) do x
    sum(c_ad_f, eachcol(x))
end
# 16.560 ms (532938 allocations: 22.31 MiB)
@btime Zygote.gradient($X) do x
    sum(c_ad_f, eachcol(x))
end
# 2.188 ms (17174 allocations: 3.50 MiB)
It also seems you still use sum and eachcol with loglikelihood? loglikelihood is already defined as the sum of logpdf over a set of samples (such as the columns of a matrix), so both should be removed.
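For reference, this equivalence is easy to check directly; the small example below uses hypothetical names (d2, X2) of my own:

```julia
using Distributions, LinearAlgebra

# A small illustrative multivariate normal and a matrix of samples
d2 = MvNormal(zeros(3), Matrix(1.0I, 3, 3))
X2 = rand(d2, 5)

# loglikelihood over a sample matrix already sums logpdf over the columns,
# so wrapping it in sum(..., eachcol(...)) is redundant
loglikelihood(d2, X2) ≈ sum(logpdf(d2, c) for c in eachcol(X2))
```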
That's fair, but I just use the logpdf of a Gaussian as an example. In practice f might be much more complex and take only vectors. That's actually why I am posting this issue on the AdvancedVI repo.
In your code, ReverseDiff builds a new Wengert list / gradient tape at every call. My guess is that if you build and compile the tape once, the results will look quite a bit different.
See: https://github.com/JuliaDiff/ReverseDiff.jl/blob/master/examples/gradient.jl
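Adapting the benchmark above along those lines could look roughly like the sketch below, which is based on the linked example. The setup names (m2, C2, d_tape, h, X2) and the smaller dimensions are illustrative assumptions of mine, not code from the thread.

```julia
using ReverseDiff, Distributions, DistributionsAD, LinearAlgebra

# Illustrative setup mirroring the benchmark above (smaller dimensions)
m2 = rand(5)
C2 = rand(5, 5) |> x -> x * x' + I  # + I to ensure positive definiteness
d_tape = MvNormal(m2, C2)
h(x) = sum(c -> logpdf(d_tape, c), eachcol(x))
X2 = rand(d_tape, 8)

# Build the Wengert list / tape once, compile it, then reuse it on every call
tape  = ReverseDiff.GradientTape(h, X2)
ctape = ReverseDiff.compile(tape)
out = similar(X2)
ReverseDiff.gradient!(out, ctape, X2)  # no re-taping on subsequent calls
```

The key point is that GradientTape and compile run once up front, so the per-call cost of ReverseDiff.gradient! excludes tape construction entirely.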
True, but Zygote has the advantage of doing that automatically. The only annoying thing with the tape approach is that when samples are stored in a matrix, it forces you to output a vector of vectors (I think). I will get back to my refactoring PR #25 soon and try to improve this aspect.
I'll close this issue for now since the topic is quite outdated. Please re-open if anybody feels the issue is still relevant.