Function splitting about Cassette.jl (open)

julialabs commented on July 17, 2024
Function splitting

from cassette.jl.

Comments (5)

jrevels commented on July 17, 2024

So here's a convenient way to do this if you wanted to be easier on the compiler:

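using Cassette

Cassette.@context Ctx

# Holder for the accumulated callback; the field is untyped, which is where the
# inference barrier discussed below comes from.
mutable struct IOCallback
    f::Any
end

# Intercept println during the trace: instead of printing, chain the call onto the
# callback. (`Cassette.execute` is the hook from early Cassette releases; later
# releases overload `Cassette.overdub` directly for this.)
function Cassette.execute(ctx::Ctx, ::typeof(println), args...)
    previous = ctx.metadata.f
    ctx.metadata.f = () -> (previous(); println(args...))
    return nothing
end

ctx = Ctx(metadata = IOCallback(() -> nothing))
a = rand(3)
b = rand(3)
c = Cassette.overdub(ctx, add, a, b)  # `add` is the example function from MikeInnes's comment
ctx.metadata.f() # prints the thing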

This approach places an inference barrier on closure access/construction, but does not interfere with inference of the sliced function or of the callback itself:

julia> @code_typed Cassette.overdub(ctx, add, a, b)
CodeInfo(
2 1 ── %1  = (Core.getfield)(##overdub_arguments#369, 2)::Array{Float64,1}                │
  │    %2  = (Core.getfield)(##overdub_arguments#369, 3)::Array{Float64,1}                │
  └───       goto #3 if not true                                                          │
  2 ── %4  = invoke Cassette.overdub(_2::Cassette.Context{nametype(Ctx),IOCallback,Cassette.NoPass,Nothing,Nothing}, Main.:+::typeof(+), %1::Array{Float64,1}, %2::Array{Float64,1})::Array{Float64,1}
  3 ── %5  = φ (#2 => %4, #1 => $(QuoteNode(Cassette.OverdubInstead())))::Union{OverdubInstead, Array{Float64,1}}
  │    %6  = π (%5, Array{Float64,1})                                                     │
3 └───       goto #14 if not true                                                         │
  4 ── %8  = (Core.tuple)("c = ", %6)::Tuple{String,Array{Float64,1}}                     │╻      string
  └───       goto #12 if not true                                                         ││
  5 ── %10 = Core._apply::typeof(_apply)                                                  ││╻      overdub
  │    %11 = (%10)(tuple, %8)::Tuple{String,Array{Float64,1}}                             │││┃      apply_args
  │    %12 = (getfield)(%11, 1)::String                                                   ││╻      overdub
  │    %13 = (getfield)(%11, 2)::Array{Float64,1}                                         │││
  │    %14 = (Core.tuple)(%12, %13)::Tuple{String,Array{Float64,1}}                       │││╻      print_to_string
  └───       goto #7 if not true                                                          ││││
  6 ── %16 = Core.tuple::typeof(tuple)                                                    ││││
  │    %17 = Base.nothing::Nothing                                                        ││││
  └─── %18 = (%16)(%17, print_to_string)::Tuple{Nothing,typeof(print_to_string)}          ││││╻╷╷    overdub
  7 ── %19 = φ (#6 => %18, #5 => $(QuoteNode(Cassette.OverdubInstead())))::Union{OverdubInstead, Tuple{Nothing,typeof(print_to_string)}}
  │    %20 = π (%19, Tuple{Nothing,typeof(print_to_string)})                              ││││
  └───       goto #9 if not true                                                          ││││
  8 ── %22 = Base.:(#print_to_string#330)::##print_to_string#330                          ││││
  │    %23 = Core._apply::typeof(_apply)                                                  │││││╻      apply_args
  │    %24 = (%23)(tuple, %20, %14)::Tuple{Nothing,typeof(print_to_string),String,Array{Float64,1}}
  │    %25 = (getfield)(%24, 1)::Nothing                                                  ││││╻      overdub
  │    %26 = (getfield)(%24, 2)::typeof(print_to_string)                                  │││││
  │    %27 = (getfield)(%24, 3)::String                                                   │││││
  │    %28 = (getfield)(%24, 4)::Array{Float64,1}                                         │││││
  └─── %29 = invoke Cassette.overdub(_2::Cassette.Context{nametype(Ctx),IOCallback,Cassette.NoPass,Nothing,Nothing}, %22::getfield(Base, Symbol("##print_to_string#330")), %25::Nothing, %26::typeof(Base.print_to_string), %27::String, %28::Array{Float64,1})::Any
  9 ── %30 = φ (#8 => %29, #7 => $(QuoteNode(Cassette.OverdubInstead())))::Any            ││││
  └───       goto #10                                                                     ││││
  10 ─       goto #11                                                                     ││╻      overdub
  11 ─       nothing::Nothing
  12 ─ %34 = φ (#11 => %30, #4 => $(QuoteNode(Cassette.OverdubInstead())))::Any           ││
  └───       goto #13                                                                     ││
  13 ─       nothing::Nothing
  14 ─ %37 = φ (#13 => %34, #3 => $(QuoteNode(Cassette.OverdubInstead())))::Any           │
  │    %38 = (Core.tuple)(%37)::Tuple{Any}                                                │
  │    %39 = (Base.getfield)(##overdub_context#368, :metadata)::IOCallback                ││╻      getproperty
  │    %40 = (Base.getfield)(%39, :f)::Any                                                │││
  │    %41 = (Base.getfield)(##overdub_context#368, :metadata)::IOCallback                ││╻      getproperty
  │    %42 = Main.:(##4#5)::Type{##4#5}                                                   ││
  │    %43 = (Core.typeof)(%38)::Type{#s55} where #s55<:Tuple{Any}                        ││
  │    %44 = (Core.typeof)(%40)::DataType                                                 ││
  │    %45 = (Core.apply_type)(%42, %43, %44)::Type{##4#5{_1,_2}} where _2 where _1       ││
  │    %46 = %new(%45, %38, %40)::##4#5{_1,_2} where _2 where _1                          ││
  │          (Base.setfield!)(%41, :f, %46)::##4#5{_1,_2} where _2 where _1               ││╻      setproperty!
  └───       goto #16 if not false                                                        │
  15 ─       nothing::Nothing
4 16 ─       return %6                                                                    │
) => Array{Float64,1}

julia> @code_typed ctx.metadata.f()
CodeInfo(
3 1 ─ %1 = (Core.getfield)(#self#, :args)::Tuple{String}                                                          │
  │   %2 = (getfield)(%1, 1)::String                                                                              │
  │   %3 = invoke Main.println(%2::String)::Const(nothing, false)                                                 │
  └──      return %3                                                                                              │
) => Nothing
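To see the inference barrier described above in isolation, here is a small plain-Julia sketch without Cassette. `Holder`, `call_holder` and `make_callback` are hypothetical names; `Holder` mimics the untyped `f::Any` metadata field, and the callback returns a string rather than printing so its inferred type is easy to check.

```julia
# Hypothetical stand-in for the untyped IOCallback metadata holder.
mutable struct Holder
    f::Any
end

# Calls whatever is stored in the Any-typed field.
call_holder(h::Holder) = h.f()

# The callback on its own: a closure over a concrete Int.
make_callback(c::Int) = () -> "c = $c"

cb = make_callback(5)
h = Holder(cb)

# Through the field, inference can only conclude the result is Any...
through_field = Base.return_types(call_holder, (Holder,))
# ...while the closure itself still infers to String.
direct = Base.return_types(cb, ())
```

The barrier sits only at the field access; whatever is stored there keeps its own fully inferred body, which matches what the second `@code_typed` shows.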

There is also a low-level way to do this with a manual pass; here is a gist implementing the pass. This approach is basically what you did manually in your example; it introduces a slot in order to accumulate the closure, thus removing the barrier to fully inferring the closure construction/access. You could modify this pass to e.g. move argument construction into the callback itself if you wanted.
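As a rough, hand-written analogue of what such a slot-introducing pass might produce for `add` (not the gist's actual output), the callback below is accumulated in an ordinary local variable instead of a mutable `Any` field, so inference can see its concrete type. `add_sliced` is a hypothetical name, and the `io` argument is an assumption added to make the deferred output easy to capture.

```julia
# Hypothetical result of a slot-introducing pass applied to `add`:
# the callback lives in a local variable (a "slot"), not in context metadata.
function add_sliced(a, b)
    callback = io -> nothing                           # empty callback at trace start
    c = a + b
    prev = callback
    callback = io -> (prev(io); println(io, "c = $c")) # println sliced out, closing over c
    return c, callback
end

result, cb = add_sliced(2, 3)      # nothing is printed here
# With a concrete slot, the returned (result, callback) tuple infers to a concrete type.
rt = only(Base.return_types(add_sliced, (Int, Int)))
```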

Of course, actually doing this comes at the cost of being much rougher on the compiler. There are likely many cases with this approach where the compiler will actually give up on tightening an inference result due to the usual suspects (splatting, recursion limiting, etc.), even though there is no theoretical barrier to inferring things exactly.

Aside: I'm in the process of writing docs, and this is actually a pretty good example for writing passes...

jrevels commented on July 17, 2024

Thanks for the example.

> how do we achieve it via Cassette's interface

To paraphrase a former U.S. president, "it depends on what your definition of is it is" 😛

Is your goal to use Cassette to implement the continuation handling, or just massage the code into the form you need for your scheduler?

> happy to give other examples

Yes please 🙂 Or just expand the one in the OP a bit, whichever is easier.

MikeInnes commented on July 17, 2024

Ok, here's a really contrived but simple example. Imagine I want to do program slicing to pull out all I/O commands from a function. If I have something like

function add(a, b)
  c = a + b
  println("c = $c")
  return c
end

Then I want to call

result, callback = slice_io(add, 2, 3)

result will be 5, as if we had run the original function, and callback() will print c = 5 to stdout. We can implement this as a compiler pass by transforming to something like (simplifying drastically):

function add(a, b)
  c = a + b
  return c, () -> println("c = $c")
end
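The drastically simplified transformation above runs as ordinary Julia. In this sketch `add_transformed` and `slice_io` are hypothetical stand-ins: no compiler pass is involved, `slice_io` simply dispatches to the hand-written version, and the callback takes an `io` argument (an assumption) so its output is easy to capture.

```julia
# The original example function.
function add(a, b)
    c = a + b
    println("c = $c")
    return c
end

# Hand-written stand-in for what the slicing pass would produce for `add`:
# the arithmetic runs now, the println is deferred into a closure over `c`.
function add_transformed(a, b)
    c = a + b
    return c, io -> println(io, "c = $c")
end

# Hypothetical entry point; a real pass would derive `add_transformed` from `add`'s IR.
slice_io(::typeof(add), a, b) = add_transformed(a, b)

result, callback = slice_io(add, 2, 3)   # nothing is printed here
callback(stdout)                         # prints c = 5
```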

Of course, we can't actually put closures in the IR, so instead we might do something like

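struct IOCallback{F,T}
  data::T
end
# Lets callers give F explicitly while T is inferred from the data
# (the default constructor requires both parameters).
IOCallback{F}(data::T) where {F,T} = IOCallback{F,T}(data)

function add(a, b)
  c = a + b
  return c, IOCallback{Tuple{typeof(add),Int,Int}}((c,))
end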

Calling an IOCallback is then another generated function which does essentially the same thing as overdub; it grabs the IR for the method type F and returns a modified version.
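A rough sketch of the IOCallback idea with a hand-written method standing in for the generated function: the deferred behaviour is selected by dispatch on the signature parameter `F`, and only the captured data travels in the struct. `add_sliced_ir` and the `io` argument are assumptions for illustration; the proposal would instead derive the callback body from `F`'s IR.

```julia
# F encodes the sliced method's signature; T is the type of the captured data.
struct IOCallback{F,T}
    data::T
end
# Lets callers give F explicitly and have T inferred from the data.
IOCallback{F}(data::T) where {F,T} = IOCallback{F,T}(data)

add(a, b) = a + b   # stand-in for the original `add`; only its type is used below

# Forward part: compute the result and package the data the deferred I/O needs.
function add_sliced_ir(a, b)
    c = a + b
    return c, IOCallback{Tuple{typeof(add),Int,Int}}((c,))
end

# Deferred part, selected by dispatch on F. Written by hand here; in the proposal this
# would be a generated function that fetches the IR for F and keeps only the I/O.
function (cb::IOCallback{Tuple{typeof(add),Int,Int}})(io::IO = stdout)
    (c,) = cb.data
    println(io, "c = $c")
end

result, cb = add_sliced_ir(2, 3)
cb()   # prints c = 5
```

Because `F` is a type parameter, each sliced method gets its own callback type, which is what would let a generated function look up the right IR.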

Cassette could pretty easily abstract some of that away by providing a generic Callback{Ctx,F,T} or something equivalent, but I haven't given much thought to how to build an API around it beyond that.

jrevels commented on July 17, 2024

I think there might be a decently easy answer to this depending on which semantics you need. If you actually only need to slice out IO calls, you might be able to simply update a trace-local pointer to a closure object allocated at the beginning of the trace.
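For the IO-only case, here is a minimal sketch of the trace-local pointer idea, instrumented by hand rather than through Cassette: a `Ref` allocated at the start of the trace points at the current callback, and each sliced `println` replaces it with a closure that runs the previous callback first, so the deferred output keeps its original order. `new_trace`, `record_println!` and `addmul` are hypothetical names, and the `io` argument is an assumption for testability.

```julia
# Trace-local pointer to the accumulated callback, allocated when the trace starts.
new_trace() = Ref{Any}(io -> nothing)

# What a sliced `println(args...)` does: extend the callback instead of printing.
function record_println!(trace::Ref{Any}, args...)
    previous = trace[]
    trace[] = io -> (previous(io); println(io, args...))
    return nothing
end

# Hypothetical function with two I/O calls, instrumented by hand.
function addmul(trace, a, b)
    s = a + b
    record_println!(trace, "s = ", s)
    p = a * b
    record_println!(trace, "p = ", p)
    return s, p
end

trace = new_trace()
result = addmul(trace, 2, 3)   # nothing is printed here
trace[](stdout)                # prints "s = 5" then "p = 6"
```

Since the pointer is typed `Any`, this variant carries the same inference barrier as the Cassette metadata version earlier in the thread.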

However, your example is slightly different from that: you seem to be slicing out a whole statement rather than a call. If that is a representative example, then the next question becomes "what is the pointcut for this transformation?"

To clarify, what would your transformed code for this function need to be to get the semantics you want:

function addstr(a, b)
    c = a + b
    cstr = "$a + $b = $c"
    println(cstr)
    return cstr
end

MikeInnes commented on July 17, 2024

For the sake of argument, we can say that println just closes over whatever arguments it's given, so the string is built in the forward pass. This shouldn't be a critical piece of information, though; the pass should be able to decide to recompute everything, or special-case strings, or whatever, if it chooses to.
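Under those semantics, the transformed `addstr` might look like the hand-written sketch below, alongside the "recompute everything" alternative a pass could choose instead. `addstr_sliced`, `addstr_sliced_recompute` and the `io` argument are assumptions for illustration, not output of any existing pass.

```julia
# Forward pass builds the string; the deferred println closes over it.
function addstr_sliced(a, b)
    c = a + b
    cstr = "$a + $b = $c"
    return cstr, io -> println(io, cstr)
end

# Alternative a pass could choose: close over the inputs and rebuild the string later.
function addstr_sliced_recompute(a, b)
    c = a + b
    cstr = "$a + $b = $c"
    return cstr, io -> println(io, "$a + $b = $(a + b)")
end

result, callback = addstr_sliced(2, 3)
_, callback2 = addstr_sliced_recompute(2, 3)
```

Both callbacks print the same line; they differ only in what the closure captures (the finished string versus the inputs), which is the choice the pass is free to make.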
