So here's a convenient way to do this if you wanted to be easier on the compiler:
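using Cassette
Cassette.@context Ctx

# Mutable holder for the accumulated callback (untyped field)
mutable struct IOCallback
    f::Any
end

# Intercept println: instead of printing, chain the call onto the stored callback.
# (Early Cassette API; later releases overload Cassette.overdub instead of execute.)
function Cassette.execute(ctx::Ctx, ::typeof(println), args...)
    previous = ctx.metadata.f
    ctx.metadata.f = () -> (previous(); println(args...))
    return nothing
end

# `add` from the original post
function add(a, b)
    c = a + b
    println("c = $c")
    return c
end

ctx = Ctx(metadata = IOCallback(() -> nothing))
a = rand(3)
b = rand(3)
c = Cassette.overdub(ctx, add, a, b)  # computes a + b; nothing is printed yet
ctx.metadata.f()                      # prints the thing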
This approach places an inference barrier on closure access/construction, but does not interfere with inference of the sliced function or the callback:
julia> @code_typed Cassette.overdub(ctx, add, a, b)
CodeInfo(
2 1 ── %1 = (Core.getfield)(##overdub_arguments#369, 2)::Array{Float64,1} │
│ %2 = (Core.getfield)(##overdub_arguments#369, 3)::Array{Float64,1} │
└─── goto #3 if not true │
2 ── %4 = invoke Cassette.overdub(_2::Cassette.Context{nametype(Ctx),IOCallback,Cassette.NoPass,Nothing,Nothing}, Main.:+::typeof(+), %1::Array{Float64,1}, %2::Array{Float64,1})::Array{Float64,1}
3 ── %5 = φ (#2 => %4, #1 => $(QuoteNode(Cassette.OverdubInstead())))::Union{OverdubInstead, Array{Float64,1}}
│ %6 = π (%5, Array{Float64,1}) │
3 └─── goto #14 if not true │
4 ── %8 = (Core.tuple)("c = ", %6)::Tuple{String,Array{Float64,1}} │╻ string
└─── goto #12 if not true ││
5 ── %10 = Core._apply::typeof(_apply) ││╻ overdub
│ %11 = (%10)(tuple, %8)::Tuple{String,Array{Float64,1}} │││┃ apply_args
│ %12 = (getfield)(%11, 1)::String ││╻ overdub
│ %13 = (getfield)(%11, 2)::Array{Float64,1} │││
│ %14 = (Core.tuple)(%12, %13)::Tuple{String,Array{Float64,1}} │││╻ print_to_string
└─── goto #7 if not true ││││
6 ── %16 = Core.tuple::typeof(tuple) ││││
│ %17 = Base.nothing::Nothing ││││
└─── %18 = (%16)(%17, print_to_string)::Tuple{Nothing,typeof(print_to_string)} ││││╻╷╷ overdub
7 ── %19 = φ (#6 => %18, #5 => $(QuoteNode(Cassette.OverdubInstead())))::Union{OverdubInstead, Tuple{Nothing,typeof(print_to_string)}}
│ %20 = π (%19, Tuple{Nothing,typeof(print_to_string)}) ││││
└─── goto #9 if not true ││││
8 ── %22 = Base.:(#print_to_string#330)::##print_to_string#330 ││││
│ %23 = Core._apply::typeof(_apply) │││││╻ apply_args
│ %24 = (%23)(tuple, %20, %14)::Tuple{Nothing,typeof(print_to_string),String,Array{Float64,1}}
│ %25 = (getfield)(%24, 1)::Nothing ││││╻ overdub
│ %26 = (getfield)(%24, 2)::typeof(print_to_string) │││││
│ %27 = (getfield)(%24, 3)::String │││││
│ %28 = (getfield)(%24, 4)::Array{Float64,1} │││││
└─── %29 = invoke Cassette.overdub(_2::Cassette.Context{nametype(Ctx),IOCallback,Cassette.NoPass,Nothing,Nothing}, %22::getfield(Base, Symbol("##print_to_string#330")), %25::Nothing, %26::typeof(Base.print_to_string), %27::String, %28::Array{Float64,1})::Any
9 ── %30 = φ (#8 => %29, #7 => $(QuoteNode(Cassette.OverdubInstead())))::Any ││││
└─── goto #10 ││││
10 ─ goto #11 ││╻ overdub
11 ─ nothing::Nothing │
12 ─ %34 = φ (#11 => %30, #4 => $(QuoteNode(Cassette.OverdubInstead())))::Any ││
└─── goto #13 ││
13 ─ nothing::Nothing │
14 ─ %37 = φ (#13 => %34, #3 => $(QuoteNode(Cassette.OverdubInstead())))::Any │
│ %38 = (Core.tuple)(%37)::Tuple{Any} │
│ %39 = (Base.getfield)(##overdub_context#368, :metadata)::IOCallback ││╻ getproperty
│ %40 = (Base.getfield)(%39, :f)::Any │││
│ %41 = (Base.getfield)(##overdub_context#368, :metadata)::IOCallback ││╻ getproperty
│ %42 = Main.:(##4#5)::Type{##4#5} ││
│ %43 = (Core.typeof)(%38)::Type{#s55} where #s55<:Tuple{Any} ││
│ %44 = (Core.typeof)(%40)::DataType ││
│ %45 = (Core.apply_type)(%42, %43, %44)::Type{##4#5{_1,_2}} where _2 where _1 ││
│ %46 = %new(%45, %38, %40)::##4#5{_1,_2} where _2 where _1 ││
│ (Base.setfield!)(%41, :f, %46)::##4#5{_1,_2} where _2 where _1 ││╻ setproperty!
└─── goto #16 if not false │
15 ─ nothing::Nothing │
4 16 ─ return %6 │
) => Array{Float64,1}
julia> @code_typed ctx.metadata.f()
CodeInfo(
3 1 ─ %1 = (Core.getfield)(#self#, :args)::Tuple{String} │
│ %2 = (getfield)(%1, 1)::String │
│ %3 = invoke Main.println(%2::String)::Const(nothing, false) │
└── return %3 │
) => Nothing
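The chaining trick in the Cassette.execute overload above does not itself depend on Cassette. Below is a minimal Cassette-free sketch of it; DeferredIO and defer_println! are hypothetical names, and the callback takes the output stream as an argument (an assumption made so the output can be captured with sprint). Each call wraps the previous callback, so replaying it runs the deferred printlns in their original order.

```julia
# Hypothetical, Cassette-free sketch of the closure-chaining trick.
mutable struct DeferredIO
    f::Any  # untyped field, like IOCallback above: reads of it infer as Any
end

# Start with a no-op callback.
DeferredIO() = DeferredIO(io -> nothing)

# Instead of printing now, wrap the previous callback in a new closure that
# first replays it and then performs this println.
function defer_println!(cb::DeferredIO, args...)
    previous = cb.f
    cb.f = io -> (previous(io); println(io, args...))
    return nothing
end
```

As in the Cassette version, the untyped field means the accumulated closure sits behind an inference barrier, even though each individual closure has a concrete type.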
There is also a low-level way to do this with a manual pass; here is a gist implementing the pass. This approach is basically what you did manually in your example; it introduces a slot in order to accumulate the closure, thus removing the barrier to fully inferring the closure construction/access. You could modify this pass to e.g. move argument construction into the callback itself if you wanted.
Of course, actually doing this comes at the cost of being much rougher on the compiler. There are likely many cases with this approach where the compiler will actually give up on tightening an inference result due to the usual suspects (splatting, recursion limiting, etc.), even though there is no theoretical barrier to inferring things exactly.
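For comparison, here is a hand-written sketch of roughly the shape of code such a pass is described as producing for add: a local slot accumulates the callback instead of a mutable field, which should let inference see concrete closure types. This is an illustration under assumptions, not the gist's actual output; add_with_slot is a hypothetical name, and the callback again takes the output stream as an argument.

```julia
# Hypothetical hand-written equivalent of a "slot accumulates the closure" pass.
function add_with_slot(a, b)
    cb = io -> nothing                 # slot starts as a no-op callback
    c = a + b
    # In place of `println("c = $c")`: extend the callback instead of printing.
    # The `let` gives the new closure its own bindings, so no captured variable
    # is ever reassigned (which would force Julia to box it).
    cb = let prev = cb, args = ("c = $c",)
        io -> (prev(io); println(io, args...))
    end
    return c, cb
end
```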
Aside: I'm in the process of writing docs, and this is actually a pretty good example for writing passes...
from cassette.jl.
Thanks for the example.
How do we achieve this via Cassette's interface?
To paraphrase a former U.S. president, "it depends on what your definition of 'it' is" 😛
Is your goal to use Cassette to implement the continuation handling, or just massage the code into the form you need for your scheduler?
Happy to give other examples.
Yes please 🙂 Or just expanding the one in the OP a bit, whichever is easier.
Ok, here's a really contrived but simple example. Imagine I want to do program slicing to pull out all I/O commands from a function. If I have something like
function add(a, b)
    c = a + b
    println("c = $c")
    return c
end
Then I want to call
result, callback = slice_io(add, 2, 3)
Here result will be 5, as if we had run the original function, and callback() will print "c = 5" to stdout. We can implement this as a compiler pass by transforming to something like (simplifying drastically):
function add(a, b)
    c = a + b
    return c, () -> println("c = $c")
end
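At the source level, where closures are fine, the drastically simplified version can be wired up to the slice_io call described above. In this sketch the transformed method is written by hand: sliced and slice_io are hypothetical names standing in for what the compiler pass would generate, and the callback takes the output stream as an argument (an assumption, so its output can be captured with sprint).

```julia
# Original function from the post.
function add(a, b)
    c = a + b
    println("c = $c")
    return c
end

# Hypothetical stand-in for the pass output: the arithmetic runs now, and the
# println is returned as a closure over the computed value.
function sliced(::typeof(add), a, b)
    c = a + b
    return c, io -> println(io, "c = $c")
end

# Hypothetical entry point; a real version would generate `sliced` from the
# IR of `f` rather than rely on hand-written methods.
slice_io(f, args...) = sliced(f, args...)
```

Usage: result, callback = slice_io(add, 2, 3) gives result == 5, and callback(stdout) prints c = 5.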
Of course, we can't actually put closures in the IR, so instead we might do something like
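# F: signature of the sliced method; data: values the deferred I/O needs
struct IOCallback{F,T}
    data::T
end
# Outer constructor so only F has to be given explicitly; T comes from the data
IOCallback{F}(data::T) where {F,T} = IOCallback{F,T}(data)
function add(a, b)
    c = a + b
    return c, IOCallback{Tuple{typeof(add),Int,Int}}((c,))
end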
Calling an IOCallback is then handled by another generated function that does essentially the same thing as overdub: it grabs the IR for the method type F and returns a modified version.
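Writing the generated function that retrieves and rewrites the IR for F is beyond a short example. As a stand-in, the sketch below dispatches on F to a hand-written replay method, which shows that the type parameter alone carries enough information to pick the right I/O code at call time. replay_io and add_split are hypothetical names. The outer constructor is needed because Julia does not define IOCallback{F}(data) automatically when F cannot be determined from the field types.

```julia
# Callback carrying the method signature F as a type parameter and the
# captured values as data.
struct IOCallback{F,T}
    data::T
end
# Let T be inferred from the data while F is given explicitly.
IOCallback{F}(data::T) where {F,T} = IOCallback{F,T}(data)

function add_split(a, b)
    c = a + b
    return c, IOCallback{Tuple{typeof(add_split),Int,Int}}((c,))
end

# Hypothetical stand-in for the generated function: one hand-written
# replay method per sliced signature F.
replay_io(::Type{Tuple{typeof(add_split),Int,Int}}, io::IO, data) =
    println(io, "c = ", data[1])

# Calling the callback dispatches on its signature parameter F.
(cb::IOCallback{F})(io::IO) where {F} = replay_io(F, io, cb.data)
(cb::IOCallback)() = cb(stdout)
```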
Cassette could pretty easily abstract some of that away by providing a generic Callback{Ctx,F,T} or something equivalent, but I haven't much considered how to build an API around it beyond that.
I think there might be a decently easy answer to this depending on which semantics you need. If you actually only need to slice out IO calls, you might be able to simply update a trace-local pointer to a closure object allocated at the beginning of the trace.
However, your example is slightly different from that: you seem to be slicing out a whole statement rather than a call. If that is a representative example, then the next question becomes "what is the pointcut for this transformation?"
To clarify, what would your transformed code for this function need to be to get the semantics you want:
function addstr(a, b)
    c = a + b
    cstr = "$a + $b = $c"
    println(cstr)
    return cstr
end
For the sake of argument, we can say that println just closes over whatever arguments it's given, so the string is built in the forward pass. This shouldn't be a critical piece of information, though; the pass should be able to decide to recompute everything, or special-case strings, or whatever, if it chooses to.
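Under those semantics only the println call is sliced out, so the transformed addstr might look like the sketch below. addstr_sliced is a hypothetical name, and the callback takes the output stream as an argument so it can be captured. The string is still built in the forward pass; the callback only closes over the arguments println would have received.

```julia
# Hypothetical sliced form of `addstr`: `println` only captures its arguments,
# so `cstr` is still computed eagerly.
function addstr_sliced(a, b)
    c = a + b
    cstr = "$a + $b = $c"
    args = (cstr,)                      # what println would have received
    return cstr, io -> println(io, args...)
end
```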