The `println` behavior is correct, except for the segfault, which seems to be a base Julia bug that Cassette is somehow triggering. When you're doing nested overdubbing, the target program for the outer context is the underlying program overdubbed with your inner context, which is why you see so many Cassette methods in your log.
The nested AD example isn't a Cassette bug, just a "bug" in that AD implementation built on Cassette. It's only a toy AD, so the subset of Julia it supports is very limited. However, this is a good test case! I've expanded the toy AD implementation so that it handles it. Here are the two things needed to support this test case:
- It needs the non-tagged fallbacks for its primitives:

```julia
# Untagged fallbacks: arguments that are plain `Real`s from this context's
# point of view are handled as primitives and evaluated directly,
# rather than traced through.
Cassette.execute(ctx::DiffCtx, ::typeof(sin), x::Real) = sin(x)
Cassette.execute(ctx::DiffCtx, ::typeof(cos), x::Real) = cos(x)
Cassette.execute(ctx::DiffCtx, ::typeof(*), x::Real, y::Real) = x * y
Cassette.execute(ctx::DiffCtx, ::typeof(+), x::Real, y::Real) = x + y
```

  This prevents accidentally tracing through these methods; e.g. `*(::Real, ::Tagged{T})` looks like `*(::Real, ::Real)` when the "dominating" context isn't tagged with `T`.
- It needs to intercept `*(x, y, z)` (right now it only handles the two-argument version), so it needs methods like:

```julia
# Treat the three-argument forms as primitives by reducing them
# to nested two-argument primitive calls.
Cassette.execute(ctx::DiffCtx, ::typeof(*), x, y, z) = Cassette.execute(ctx, *, Cassette.execute(ctx, *, x, y), z)
Cassette.execute(ctx::DiffCtx, ::typeof(+), x, y, z) = Cassette.execute(ctx, +, Cassette.execute(ctx, +, x, y), z)
```

  (You can do something similar to define the arbitrary varargs version as a primitive if you want.)
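Julia parses a chained product such as `5*x*y` into a single three-argument call `*(5, x, y)`, so the three-argument method is what a tracer encounters first; unless it is declared a primitive, the tracer has to go into Base's generic varargs fallback instead. The sketch below checks the parse with `Meta.parse` and mimics the nested-binary reduction the `execute` methods above perform. `binary_reduce` is a hypothetical helper for illustration, not part of Cassette.

```julia
# `5*x*y` parses to ONE call with three arguments, not two nested binary calls.
ex = Meta.parse("5*x*y")
@assert ex.head == :call
@assert ex.args == Any[:*, 5, :x, :y]

# Hypothetical helper (not part of Cassette): rewrite an n-ary call as nested
# binary calls, the same reduction the three-argument execute methods perform.
binary_reduce(op, x, y) = op(x, y)
binary_reduce(op, x, y, zs...) = binary_reduce(op, op(x, y), zs...)

@assert binary_reduce(*, 5, 2, 3) == 30
@assert binary_reduce(+, 1, 2, 3, 4) == 10
```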
Thanks for the response!
The reason I mentioned the `println` case is just to show that there is a weird `Segmentation fault: 11`; I guess that might be a bug somehow.
The nested AD example is pretty cool. The original version is weird but makes sense if I think twice. However, Cassette seems to be able to handle it correctly. But I'm a little confused about the new one.
As I understand it, if I try to do `D(x -> x * D(y -> 5*x*y, 3), 2)`, the `x` in the inner `D` isn't traced, so its value is just calculated and returned to the outer `D`, which is why we get the wrong answer in the original version.
I'm not sure why the non-tagged fallbacks actually help in this case.
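For reference, the expected value of this nested derivative can be worked out by hand, assuming `D(f, a)` denotes the derivative of `f` at `a`. The correct result is 20. If the inner result 5x were instead treated as a constant with respect to x (which is what an untraced `x` amounts to), the outer derivative would come out as 10.

$$
\frac{\partial}{\partial y}\bigl(5xy\bigr)\Big|_{y=3} = 5x,
\qquad
f(x) = x \cdot 5x = 5x^2,
\qquad
f'(2) = 10 \cdot 2 = 20
$$

$$
\text{inner result frozen as } c = 5 \cdot 2 = 10:
\qquad
\frac{d}{dx}\bigl(x \cdot c\bigr) = c = 10
$$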
The intuition for nested tagged overdubbing is a bit tricky to come by, especially since I haven't really documented it very well yet.
Maybe this example will help:
```julia
julia> ctx1 = enabletagging(DiffCtx(), 1);

julia> ctx2 = overdub(ctx1, () -> enabletagging(DiffCtx(), 2));

julia> x = tag(1, ctx1, 1)
Tagged(Tag{nametype(DiffCtx),18292539232925712936,Nothing}(), 1, Meta(1, _))

julia> overdub(ctx1, execute, ctx2, +, x, x)
Tagged(Tag{nametype(DiffCtx),18292539232925712936,Nothing}(), 2, Meta(2, _))
```
That's with the untagged fallbacks. Without them:
```julia
julia> overdub(ctx1, execute, ctx2, +, x, x)
Cassette.OverdubInstead()

julia> overdub(ctx1, overdub, ctx2, +, x, x)
2
```
Here, `x` is tagged with respect to `ctx1`. From `ctx2`'s perspective, however, `x` is essentially an untagged `<:Real`. Thus, without the untagged fallbacks, `+(x, x)` is not a primitive for `ctx2`, and so it gets traced through. This is why we really do want `+(x, x)` to be a primitive: simply tracing through it would skip the code we want the outer context to also intercept!
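A loosely analogous situation arises in any tag-based nested forward-mode AD, with or without Cassette. Below is a minimal pure-Julia sketch using dual numbers where each `D` call gets a fresh integer tag; it does not use Cassette's API, and `Dual`, `D`, `add`, `mul` and `TAG_COUNTER` are hypothetical names chosen for illustration. A value tagged for an outer `D` looks like a constant to an inner `D` (its tangent is 0), but the inner arithmetic on it is still routed through `add`/`mul` as a primitive rather than unwrapped, so the outer tag keeps tracking it and the nested example comes out as 20.

```julia
# Minimal sketch of tag-based nested forward-mode AD (plain Julia, no Cassette).
# Dual, D, add, mul and TAG_COUNTER are hypothetical names for illustration.

struct Dual <: Number
    tag::Int   # which D call owns this perturbation
    val        # primal value; may itself be a Dual with a smaller (outer) tag
    eps        # tangent w.r.t. `tag`; may also be an outer-tag Dual
end

const TAG_COUNTER = Ref(0)

tagof(x::Dual) = x.tag
tagof(x) = 0                  # plain numbers carry no perturbation

# Anything not tagged with t is a constant w.r.t. t (tangent 0), much like
# a value tagged for ctx1 looks like an untagged Real to ctx2.
primal(x, t) = (x isa Dual && x.tag == t) ? x.val : x
tangent(x, t) = (x isa Dual && x.tag == t) ? x.eps : 0

function add(x, y)
    t = max(tagof(x), tagof(y))   # innermost (most recent) tag present
    t == 0 && return x + y        # no perturbations: ordinary arithmetic
    Dual(t, add(primal(x, t), primal(y, t)), add(tangent(x, t), tangent(y, t)))
end

function mul(x, y)
    t = max(tagof(x), tagof(y))
    t == 0 && return x * y
    a, b = primal(x, t), primal(y, t)
    # Product rule; recursing through mul/add lets outer tags see this arithmetic.
    Dual(t, mul(a, b), add(mul(a, tangent(y, t)), mul(tangent(x, t), b)))
end

Base.:+(x::Dual, y::Dual) = add(x, y)
Base.:+(x::Dual, y::Number) = add(x, y)
Base.:+(x::Number, y::Dual) = add(x, y)
Base.:*(x::Dual, y::Dual) = mul(x, y)
Base.:*(x::Dual, y::Number) = mul(x, y)
Base.:*(x::Number, y::Dual) = mul(x, y)

# Derivative of f at a: seed a fresh tag, run f, read off that tag's tangent.
function D(f, a)
    t = (TAG_COUNTER[] += 1)
    return tangent(f(Dual(t, a, 1)), t)
end

@assert D(x -> x * x, 3) == 6
@assert D(x -> x * D(y -> 5*x*y, 3), 2) == 20
```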
Fun fact: at one point, I actually had the tagging interface auto-define these fallbacks upon new primitive definitions, but it turned out to be quite messy and too "auto-magical".
The `println` weirdness here is #77; closing since the AD part of the issue is resolved. Thanks for filing!