torfjelde commented on June 3, 2024

This is not an issue with ADVI + condition syntax, but an issue with your model definition.

Internally, every ~ statement is converted into = + extras, and similarly .~ is converted into .= + extras.

This means that when you write

x .~ Normal(...)

this will result in some expression involving

x .= ...

But, as this is just standard Julia code, this won't work if x is not yet defined!
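The same failure can be reproduced in plain Julia, without Turing. The sketch below is only an illustration: the function names and the variable `buffer_never_defined` are made up for this example and are not part of Turing or of the code that `@model` generates.

```julia
# `.=` is an in-place broadcast: it writes into an array that must already exist.

function fill_without_allocating()
    # `buffer_never_defined` is not defined anywhere, so there is nothing to write into.
    buffer_never_defined .= 1.0
    return buffer_never_defined
end

function fill_after_allocating(n)
    buffer = Vector{Float64}(undef, n)  # allocate the destination first
    buffer .= 1.0                       # now the in-place broadcast has a target
    return buffer
end

# Capture the exception from the first version instead of letting it propagate.
err = try
    fill_without_allocating()
catch e
    e
end

println(err isa UndefVarError)
println(fill_after_allocating(3))
```

The first function only fails when it is called, which matches what happens inside a model: the error shows up when the model body runs, not when the model is defined.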

In your "old" model, i.e.

@model function model_old(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0.0, sqrt(s))
    x .~ Normal(m, sqrt(s))
end

x is provided as an argument, and is thus defined. In your new model, x is not defined before we hit the .~ (and thus .=).

Unfortunately this is not clear from the exception thrown.

The first thing to do when something fails with a model is to just check if you can run it without any inference, e.g.

julia> using Turing

julia> @model function model()
           s ~ InverseGamma(2, 3)
           m ~ Normal(0.0, sqrt(s))
           x .~ Normal(m, sqrt(s))
       end
model (generic function with 2 methods)

julia> model_instance = model();

julia> model_instance()
ERROR: UndefVarError: `x` not defined
...

Here we get a much more informative error message :)

So, to use .~ together with the new condition syntax, you need to explicitly allocate x before the .~ statement:

julia> @model function model_v2(n)  # need to specify the length as input
           s ~ InverseGamma(2, 3)
           m ~ Normal(0.0, sqrt(s))
           x = Vector(undef, n)
           x .~ Normal(m, sqrt(s))
       end
model_v2 (generic function with 2 methods)

julia> model_instance = model_v2(10);

julia> model_instance()
10-element Vector{Float64}:
  1.4266801924189414
  0.8200920046396959
  0.7113019610151704
  1.231743385599
 -0.5561762370549463
  1.4947248221581675
 -0.9162359604360499
 -1.2980392578414817
 -1.3348312021509032
 -0.44033058315337587

Now, using Vector(undef, n) is not a great idea in general: it creates a Vector{Any}, which leads to type instabilities.

Following the advice in the docs (https://turinglang.org/v0.30/docs/using-turing/performancetips#ensure-that-types-in-your-model-can-be-inferred), we should instead do

@model function model_v3(n, ::Type{TV}=Vector{Float64}) where {TV}
    s ~ InverseGamma(2, 3)
    m ~ Normal(0.0, sqrt(s))
    x = TV(undef, n)
    x .~ Normal(m, sqrt(s))
end

Now this will be performant and usable with AD.

As a final note: if you know that a particular variable is always going to be conditioned on and you don't need the flexibility of easily being able to change the values, etc., using the "old" syntax is probably still the way to go as it will be slightly more performant in these cases. If you want performance, .~ is always going to be slower than, say, x ~ filldist(Normal(m, sqrt(s)), length(x)).
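A minimal sketch of the filldist variant combined with the condition syntax might look like the following. The model name `model_filldist` and the placeholder `data` are assumptions made for illustration; only `filldist` and the `|` conditioning operator come from Turing itself.

```julia
using Turing

# Hypothetical model name; `n` is the number of observations.
@model function model_filldist(n)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0.0, sqrt(s))
    # A single multivariate statement replaces the broadcasted `.~`,
    # so `x` does not need to be allocated beforehand.
    x ~ filldist(Normal(m, sqrt(s)), n)
end

# Placeholder observations, for illustration only.
data = randn(10)

# Condition on the observations with the `|` syntax.
conditioned = model_filldist(length(data)) | (x = data,)

# As suggested earlier, first check that the model runs without any inference.
conditioned()
```

Because x is a single vector-valued variable here, conditioning supplies the whole vector at once and its length has to match n.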

Hope this helps!

from turing.jl.
