
Comments (4)

jrevels commented on August 16, 2024

The println behavior is correct, except for the segfault, which seems to be a base Julia bug that Cassette is somehow triggering. When you do nested overdubbing, the target program for the outer context is the underlying program overdubbed with your inner context. That is why you see so many Cassette methods in your log.

The nested AD example isn't a Cassette bug, just a "bug" in that Cassette-based AD implementation. It's only a toy AD, so the subset of Julia it supports is extremely limited. Still, this is a good test case! I've expanded the toy AD implementation so that it handles it. Here are the two things needed to support this test case:

  1. It needs the non-tagged fallbacks for its primitives:
Cassette.execute(ctx::DiffCtx, ::typeof(sin), x::Real) = sin(x)
Cassette.execute(ctx::DiffCtx, ::typeof(cos), x::Real) = cos(x)
Cassette.execute(ctx::DiffCtx, ::typeof(*), x::Real, y::Real) = x * y
Cassette.execute(ctx::DiffCtx, ::typeof(+), x::Real, y::Real) = x + y

This prevents accidentally tracing through these methods: e.g. *(::Real, ::Tagged{T}) looks like *(::Real, ::Real) when the "dominating" context isn't tagged with T.

  2. It needs to intercept *(x, y, z) (right now it only handles the two-argument version), so it needs methods like:
Cassette.execute(ctx::DiffCtx, ::typeof(*), x, y, z) = Cassette.execute(ctx, *, Cassette.execute(ctx, *, x, y), z)
Cassette.execute(ctx::DiffCtx, ::typeof(+), x, y, z) = Cassette.execute(ctx, +, Cassette.execute(ctx, +, x, y), z)

(You can also do something similar to define the arbitrary varargs version as a primitive if you want.)
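The varargs version amounts to a left fold onto the two-argument primitive. A minimal, Cassette-free sketch in plain Julia: `prim` is a hypothetical stand-in for a context's primitive hook (playing the role of the `Cassette.execute(ctx::DiffCtx, ...)` methods above) and `CALLS` is an illustrative log; neither is part of Cassette's API.

```julia
# Hypothetical stand-in for a context's primitive hook; NOT Cassette's API.
const CALLS = Any[]   # records every binary primitive call, in order

# Binary case: the only place the primitive is actually applied (and recorded).
prim(op, x, y) = (push!(CALLS, (op, x, y)); op(x, y))

# Any arity >= 3 folds left into binary calls, e.g. *(x, y, z) == (x * y) * z.
prim(op, x, y, z, rest...) = prim(op, prim(op, x, y), z, rest...)

prim(*, 2, 3, 4, 5)   # 120, computed via three binary calls
CALLS                 # holds (*, 2, 3), (*, 6, 4), (*, 24, 5)
```

Because every n-ary call bottoms out in the two-argument method, only that method has to implement the actual primitive logic.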

from cassette.jl.

chengchingwen commented on August 16, 2024

Thanks for the response!

I only mentioned the println case to show that there is a weird Segmentation fault: 11; I guess that might be a bug somewhere.

The nested AD example is pretty cool. The behavior of the original version is weird, but it makes sense on second thought. Now Cassette seems to handle it correctly, but I'm a little confused about the new version.

My understanding is that when I do D(x -> x * D(y -> 5*x*y, 3), 2), the x in the inner D isn't traced, so its value is just computed and returned to the outer D, which is why we get the wrong answer in the original version.
I'm not sure why the non-tagged fallbacks actually help in this case.

jrevels commented on August 16, 2024

In my thought, if I try to do D(x -> x * D(y -> 5*x*y, 3), 2), the x in the inner D isn't traced, so the value is just calculate and return to the outer D so we get the wrong answer in the origin version.
I'm not sure why non-tagged fallbacks could actually help in this case.

The intuition for nested tagged overdubbing is a bit tricky to come by, especially since I haven't really documented it very well yet.

Maybe this example will help:

julia> ctx1 = enabletagging(DiffCtx(), 1);

julia> ctx2 = overdub(ctx1, () -> enabletagging(DiffCtx(), 2));

julia> x = tag(1, ctx1, 1)
Tagged(Tag{nametype(DiffCtx),18292539232925712936,Nothing}(), 1, Meta(1, _))

julia> overdub(ctx1, execute, ctx2, +, x, x)
Tagged(Tag{nametype(DiffCtx),18292539232925712936,Nothing}(), 2, Meta(2, _))

That's with the untagged fallbacks. Without them:

julia> overdub(ctx1, execute, ctx2, +, x, x)
Cassette.OverdubInstead()

julia> overdub(ctx1, overdub, ctx2, +, x, x)
2

Here, x is tagged with respect to ctx1. From ctx2's perspective, however, x is essentially an untagged <:Real. So, without the untagged fallbacks, +(x, x) is not a primitive for ctx2, and it gets traced through. That is why we really do want +(x, x) to be a primitive: simply tracing through it would skip the very code we want the outer context to intercept as well!

Fun fact: at one point, I actually had the tagging interface auto-define these fallbacks upon new primitive definitions, but it turned out to be quite messy and too "auto-magical".
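The same principle (a value tagged for the outer derivative must look like a plain constant to the inner one, while arithmetic on it is still recorded for the outer one) can be illustrated without Cassette. The following is a minimal sketch assuming hand-rolled tagged dual numbers in plain Julia; `Dual`, `D`, `D_untagged`, `parts`, `add`, `mul`, and `NEXT_TAG` are hypothetical names, not Cassette API, and this is forward-mode dual-number AD rather than the overdubbing-based toy AD from this thread. It evaluates the nested example D(x -> x * D(y -> 5*x*y, 3), 2).

```julia
# Hypothetical sketch: tagged dual numbers in plain Julia, NOT Cassette's API.
struct Dual
    tag::Int   # which call to `D` owns this perturbation
    val        # primal part: a Number, or a Dual with a smaller tag
    eps        # tangent part, same kind of value as `val`
end

tagof(x::Dual) = x.tag
tagof(x::Number) = 0

# Split `x` into (primal, tangent) with respect to `tag`. A value that does not
# carry `tag` (a plain number, or a Dual owned by an outer derivative) is a
# constant from this tag's point of view, so its tangent is 0.
parts(x, tag) = (x isa Dual && x.tag == tag) ? (x.val, x.eps) : (x, 0)

# Work on the newest (largest) tag first; the parts may themselves be Duals for
# older tags, so the arithmetic below recurses into them.
function add(a, b)
    t = max(tagof(a), tagof(b))
    av, ae = parts(a, t)
    bv, be = parts(b, t)
    return Dual(t, av + bv, ae + be)
end

function mul(a, b)
    t = max(tagof(a), tagof(b))
    av, ae = parts(a, t)
    bv, be = parts(b, t)
    return Dual(t, av * bv, ae * bv + av * be)  # product rule
end

Base.:+(a::Dual, b::Dual) = add(a, b)
Base.:+(a::Dual, b::Number) = add(a, b)
Base.:+(a::Number, b::Dual) = add(a, b)
Base.:*(a::Dual, b::Dual) = mul(a, b)
Base.:*(a::Dual, b::Number) = mul(a, b)
Base.:*(a::Number, b::Dual) = mul(a, b)

const NEXT_TAG = Ref(0)

# Derivative of `f` at `x`, using a fresh tag so nested calls cannot mix up
# their perturbations.
function D(f, x)
    tag = (NEXT_TAG[] += 1)
    return parts(f(Dual(tag, x, 1)), tag)[2]
end

# For contrast: every call reuses tag 1, so the inner derivative mistakes the
# outer perturbation for its own ("perturbation confusion").
D_untagged(f, x) = parts(f(Dual(1, x, 1)), 1)[2]

D(x -> x * D(y -> 5 * x * y, 3), 2)                    # 20
D_untagged(x -> x * D_untagged(y -> 5 * x * y, 3), 2)  # 25 (wrong)
```

With a fresh tag per `D` call, the inner derivative of 5*x*y with respect to y is 5x (still carrying the outer perturbation), so the outer result is d/dx 5x^2 = 10x = 20 at x = 2. Reusing a single tag lets the inner call absorb the outer perturbation and yields 25.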

jrevels commented on August 16, 2024

The println weirdness here is #77; closing since the AD part of the issue is resolved. Thanks for filing!