ChrisRackauckas commented on June 14, 2024

We can also hack it with `getproperty` overloading.
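A minimal sketch of what such a `getproperty` hack could look like, assuming a hypothetical `PropIntegrator` type (not OrdinaryDiffEq's `ODEIntegrator`) that stores a DDE-style right-hand side `g(u, h, p, t)`. Accessing `integ.f` builds an ODE-style closure `f(u, p, t)` whose history `h` comes from the integrator, so code that only sees `integ.f` never notices the extra argument.

```julia
# Hypothetical stand-in for an integrator; not OrdinaryDiffEq's ODEIntegrator.
struct PropIntegrator{G}
    g::G            # DDE-style right-hand side g(u, h, p, t)
    uhist::Float64  # constant history value, for illustration only
end

# `integ.f` is a virtual property: it returns an ODE-style closure f(u, p, t)
# in which the history h is built from the integrator itself.
function Base.getproperty(integ::PropIntegrator, name::Symbol)
    if name === :f
        g = getfield(integ, :g)
        h = (p, t) -> getfield(integ, :uhist)
        return (u, p, t) -> g(u, h, p, t)
    else
        return getfield(integ, name)
    end
end

# Example DDE right-hand side: u'(t) = u(t - 1) - p * u(t)
dde_rhs(u, h, p, t) = h(p, t - 1.0) - p * u

integ = PropIntegrator(dde_rhs, 2.0)
f = integ.f             # callers see an ordinary f(u, p, t)
du = f(3.0, 0.5, 0.0)   # 2.0 - 0.5 * 3.0 == 0.5
```

Each access to `integ.f` creates a new closure, so whether this is free in practice depends on how well the compiler optimizes it.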

from delaydiffeq.jl.

ChrisRackauckas commented on June 14, 2024

Yes, this makes sense. I am a little worried about compile times, but maybe it all just quickly compiles away.

devmotion commented on June 14, 2024

Yes, hopefully the compiler is smart enough.

However, there's another issue: likewise, we have to pass the integrator (or not) to the `tgrad`, `analytic`, etc. functions. Of course, this could be handled at the level of `DiffEqFunction`, as in the example above, by using overloads such as `f(Val{:analytic}, ...)`. But since we switched away from this form, I guess that's not a good idea 😄
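To illustrate the older overload form mentioned here, a minimal sketch assuming a hypothetical wrapper type `ValFunction` whose `unpack` type parameter records whether the user's analytic solution expects `p` (`true`) or the whole integrator (`false`). The names are illustrative only and are not DiffEqBase's API.

```julia
# Hypothetical wrapper; `unpack` is a type parameter, so the choice is made by dispatch.
struct ValFunction{unpack,A}
    analytic::A
end
ValFunction{unpack}(analytic::A) where {unpack,A} = ValFunction{unpack,A}(analytic)

# Hypothetical integrator holding only the parameters
struct ValIntegrator{P}
    p::P
end

# Overloads in the f(Val{:analytic}, ...) style: the first argument selects the function
(f::ValFunction{true})(::Type{Val{:analytic}}, u0, integ, t) = f.analytic(u0, integ.p, t)
(f::ValFunction{false})(::Type{Val{:analytic}}, u0, integ, t) = f.analytic(u0, integ, t)

# Solution of u' = -p * u, written once for each calling convention
val_sol_p(u0, p, t) = u0 * exp(-p * t)
val_sol_integ(u0, integ, t) = u0 * exp(-integ.p * t)

vinteg = ValIntegrator(2.0)
r_p = ValFunction{true}(val_sol_p)(Val{:analytic}, 1.0, vinteg, 0.5)
r_integ = ValFunction{false}(val_sol_integ)(Val{:analytic}, 1.0, vinteg, 0.5)
```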

Alternatively, one could define functions such as

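```julia
# `unpack` is the proposed type parameter: true means f.analytic expects p,
# false means it expects the full integrator
function DiffEqBase.analytic(f::ODEFunction{iip,unpack}, u, integrator, t) where {iip,unpack}
    has_analytic(f) || error("analytical solution is not defined")

    unpack ? f.analytic(u, get_p(integrator), t) : f.analytic(u, integrator, t)
end
```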

for all such overloads, but I don't know if this makes any difference.
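The proposal above relies on types and helpers (`ODEFunction{iip,unpack}`, `get_p`) that belong to the design under discussion, so it cannot run on its own. A self-contained mock of the same idea follows; `MockODEFunction`, `MockIntegrator`, `mock_get_p`, `mock_has_analytic` and `mock_analytic` are hypothetical stand-ins, not DiffEqBase functions.

```julia
# Hypothetical mock of an ODE function type with an `unpack` type parameter
struct MockODEFunction{iip,unpack,F,A}
    f::F
    analytic::A   # `nothing` if no analytical solution is known
end
MockODEFunction{iip,unpack}(f::F, analytic::A) where {iip,unpack,F,A} =
    MockODEFunction{iip,unpack,F,A}(f, analytic)

struct MockIntegrator{P}
    p::P
end
mock_get_p(integ::MockIntegrator) = integ.p

mock_has_analytic(f::MockODEFunction) = f.analytic !== nothing

function mock_analytic(f::MockODEFunction{iip,unpack}, u, integrator, t) where {iip,unpack}
    mock_has_analytic(f) || error("analytical solution is not defined")
    # `unpack` is a type parameter, so this branch is a constant in each method instance
    unpack ? f.analytic(u, mock_get_p(integrator), t) : f.analytic(u, integrator, t)
end

decay(u, p, t) = -p * u   # right-hand side (never called below)
decay_sol_p(u0, p, t) = u0 * exp(-p * t)
decay_sol_integ(u0, integ, t) = u0 * exp(-mock_get_p(integ) * t)

minteg = MockIntegrator(2.0)
f_p = MockODEFunction{false,true}(decay, decay_sol_p)
f_integ = MockODEFunction{false,false}(decay, decay_sol_integ)
f_none = MockODEFunction{false,true}(decay, nothing)
```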

I still like the idea of attacking this problem at the lowest level, but of course an alternative would be to explicitly define `p` before every function call (or chunk of function calls), e.g., by defining

```julia
function perform_step!(integrator, cache::BS3ConstantCache)
    p = unpack_params(integrator, integrator.f)
    # ...
end

# Pass the full integrator or only its parameters, depending on the `unpack` type parameter
unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,false}) where iip = integrator
unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,true}) where iip = get_p(integrator)
```
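Dispatching on a type parameter, as `unpack_params` does, lets the compiler pick the branch when it specializes the method, which bears on the compile-time concern raised earlier. A small self-contained sketch compares this with a run-time `Bool` flag; `StaticFlag`, `DynamicFlag` and `select_params` are hypothetical names, and `Base.return_types` (an internal reflection utility) is used only to show what type inference concludes.

```julia
# Hypothetical flags: one stores `unpack` as a type parameter, the other as a field
struct StaticFlag{unpack} end
struct DynamicFlag
    unpack::Bool
end

# Choice made by dispatch, in the spirit of `unpack_params`
select_params(::StaticFlag{true}, p, integrator) = p
select_params(::StaticFlag{false}, p, integrator) = integrator
# Choice made at run time
select_params(flag::DynamicFlag, p, integrator) = flag.unpack ? p : integrator

# What does type inference conclude for p::Float64, integrator::Vector{Float64}?
static_rt = Base.return_types(select_params, (StaticFlag{true}, Float64, Vector{Float64}))
dynamic_rt = Base.return_types(select_params, (DynamicFlag, Float64, Vector{Float64}))
```

Here `static_rt` should contain only `Float64`, whereas `dynamic_rt` should contain a `Union` of both argument types, so code downstream of the run-time flag has to handle either.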

devmotion commented on June 14, 2024

I'm working on a prototype for `ODEFunction`, and I still hope that not too many changes will be necessary in OrdinaryDiffEq.

However, I'm not sure how to deal with the fact that `p` is used to construct the cache in https://github.com/JuliaDiffEq/OrdinaryDiffEq.jl/blob/master/src/solve.jl#L246 before the `ODEIntegrator` exists. As far as I can see, `p` is mostly (or only) used there to construct the Jacobian with respect to `u` for the nonlinear solvers, in lines such as https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/nlsolve/utils.jl#L195 that evaluate `f.jac(uprev, p, t)`. If `jac` is given, we want to use it, but I don't know how to determine the type of the Jacobian if `jac` expects a full integrator.

Could we get around this problem by not caching `W` and instead passing it around once it has been created?
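A minimal sketch of building `W` only once the integrator exists and passing it along instead of caching it, using implicit Euler on a linear ODE `u' = A * u` so that `W = I - dt * J`. `StepIntegrator`, `jac_from_integrator`, `linear_solve` and `implicit_euler_step` are hypothetical names; the Jacobian deliberately takes the whole integrator, which is the case that makes pre-allocating `W` awkward.

```julia
using LinearAlgebra

# Hypothetical integrator; here the parameters p are the matrix A of u' = A * u
struct StepIntegrator{U,P,T}
    u::U
    p::P
    t::T
    dt::T
end

# Jacobian that expects the full integrator rather than just p
jac_from_integrator(u, integrator, t) = integrator.p

# Stand-in for the solver that receives W as an argument
linear_solve(W, b) = W \ b

function implicit_euler_step(integrator)
    J = jac_from_integrator(integrator.u, integrator, integrator.t)
    # W is built here, so its type follows from J and never has to be known in advance
    W = I - integrator.dt * J
    # For a linear ODE, implicit Euler solves (I - dt * A) * u_new = u
    return linear_solve(W, integrator.u)
end

A = [-1.0 0.0; 0.0 -2.0]
sinteg = StepIntegrator([1.0, 1.0], A, 0.0, 0.5)
u_new = implicit_euler_step(sinteg)   # approximately [2/3, 1/2]
```

The trade-off is that `W` (and its factorization, in a real solver) is rebuilt on every call unless some reuse logic is added back.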

devmotion commented on June 14, 2024

Since passing the integrator around in OrdinaryDiffEq is not completely straightforward (at least, it seems so to me), I started playing with something more centered on the use case in DelayDiffEq. One idea was to overload `getproperty` so that every `@unpack f = integrator` or `integrator.f` in OrdinaryDiffEq returns an ODE function whose history is built on the integrator, similar to the following simple example:

```julia
using DelayDiffEq, DiffEqBase, Test

# Wraps a DDE right-hand side and its history function h so that it can be
# called with the ODE signatures f(u, p, t) and f(du, u, p, t)
struct ODEFunctionWrapper{iip,F,H} <: DiffEqBase.AbstractODEFunction{iip}
    f::F
    h::H
end

function wrap(prob::DDEProblem)
    ODEFunctionWrapper{isinplace(prob.f),typeof(prob.f),typeof(prob.h)}(prob.f, prob.h)
end

(f::ODEFunctionWrapper{false})(u, p, t) = f.f(u, f.h, p, t)
(f::ODEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)

# Stand-in for an integrator: stores the wrapped function and a value `a`
struct TestStruct{F,A}
    f::F
    a::A
end

function buildTestStruct(prob::DDEProblem, u, p, t)
    f = wrap(prob)
    a = f(u, p, t)

    TestStruct(f, a)
end

function buildTestStruct(prob::DDEProblem, du, u, p, t)
    f = wrap(prob)
    f(du, u, p, t)

    TestStruct(f, first(du))
end

# `test.f` returns an ODE-style closure whose history (p, t) -> t * test.a is
# built from the struct itself, instead of the stored wrapper
function Base.getproperty(test::TestStruct, x::Symbol)
    if x === :f
        f = getfield(test, :f)
        if isinplace(f)
            (du, u, p, t) -> f.f(du, u, (p, t) -> [t * test.a], p, t)
        else
            (u, p, t) -> f.f(u, (p, t) -> t * test.a, p, t)
        end
    else
        getfield(test, x)
    end
end

# Code that only accesses `test.f`, as OrdinaryDiffEq does with `integrator.f`
function calc(test::TestStruct, u, p, t)
    f = test.f
    f(u, p, t)
end

function calc!(test::TestStruct, du, u, p, t)
    f = test.f
    f(du, u, p, t)
    nothing
end

# Test right-hand sides, evaluating the history h at the current time t
function f_ip(du, u, h, p, t)
    du[1] = h(p, t)[1] - u[1]
    nothing
end

f_scalar(u, h, p, t) = h(p, t) - u

function test()
    prob_ip = DDEProblem(f_ip, [1.0], (p, t) -> [0.0], (0.0, 10.0))
    prob_scalar = DDEProblem(f_scalar, 1.0, (p, t) -> 0.0, (0.0, 10.0))

    wrap_ip = wrap(prob_ip)
    wrap_scalar = wrap(prob_scalar)

    a = [0.0]
    wrap_ip(a, [5.0], nothing, 0.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 5.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 10.0)
    @test a[1] == -5.0

    @test wrap_scalar(5.0, nothing, 0.0) == -5.0
    @test wrap_scalar(5.0, nothing, 5.0) == -5.0
    @test wrap_scalar(5.0, nothing, 10.0) == -5.0

    struct_ip = buildTestStruct(prob_ip, [0.0], [5.0], nothing, 4.0)
    @test struct_ip.a == -5.0

    struct_scalar = buildTestStruct(prob_scalar, 5.0, nothing, 4.0)
    @test struct_scalar.a == -5.0

    b = [0.0]
    calc!(struct_ip, b, [5.0], nothing, 1.0)
    @test b[1] == -10.0
    calc!(struct_ip, b, [5.0], nothing, 4.0)
    @test b[1] == -25.0

    @test calc(struct_scalar, 5.0, nothing, 2.0) == -15.0
    @test calc(struct_scalar, 5.0, nothing, 6.0) == -35.0
end

test()
```

However, I'm not sure how this would affect performance, or whether it is possible at all.
