pnkraemer / probdiffeq

Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.

Home Page: https://pnkraemer.github.io/probdiffeq/

License: MIT License

Makefile 1.00% Python 99.00%
differential-equations-solvers jax scientific-computing probabilistic-numerics

probdiffeq's People

Contributors

lahramon, pnkraemer


probdiffeq's Issues

Are checkpoint-smoothers different strategies from Smoothers?

Checkpointing and smoothing do not combine well. Why? Because the solution of the smoother is the backward model,
and the backward model needs to be collapsed cleverly (this is a fixed-point smoother -- see upcoming paper).

But having this collapsing in the default solver step is a little annoying, because it does not apply to generic solves. See #86.
So I would suggest having a FixedPointSmoother() and a DynamicFixedPointSmoother() that implement steps with collapsing, and letting the other models have fun on their own -- no collapsing.
This would give the cleanest separation of concerns, especially when it comes to debugging weird checkpointing corner cases. See #76
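For reference, a minimal NumPy sketch of what "collapsing" backward models means, assuming each backward model is an affine Gaussian transition (mean A x + b, covariance C). All names are illustrative, and the actual implementation would work in square-root form:

```python
import numpy as np

def merge_backward_models(model_ab, model_bc):
    """Collapse two affine Gaussian backward transitions into one.

    model_ab parametrises p(x_a | x_b) = N(A_ab x_b + b_ab, C_ab),
    model_bc parametrises p(x_b | x_c). The merged model parametrises
    p(x_a | x_c) directly, by marginalising out x_b.
    """
    (A_ab, b_ab, C_ab) = model_ab
    (A_bc, b_bc, C_bc) = model_bc
    A = A_ab @ A_bc
    b = A_ab @ b_bc + b_ab
    C = A_ab @ C_bc @ A_ab.T + C_ab
    return (A, b, C)
```

A fixed-point smoother carries such a merged model from the current state back to the previous checkpoint; that is exactly the "clever collapsing" step that does not belong in a generic solver step.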

Tutorials

We need to get started with tutorials and benchmarks.

Autodiff

Adaptive solvers

  • Forward-mode

Non-adaptive solvers

  • Forward-mode
  • Reverse-mode

State vs. solution

What is the difference between a state and a solution?

Currently, there is no difference. This becomes problematic at the extract_fn stage, which is when unwanted quantities are discarded: sometimes this replaces a few attributes, sometimes it discards the original data structures entirely. I find this hard to remember.

One consideration might be to aim for something closer to the pattern

params = extract_fn(init_fn(params))
solution = extract_fn(init_fn(solution))

which would imply:

init_fn: solution -> state
extract_fn: state -> solution

While this is easy to remember, it would require some refactoring of what the state does and does not contain.
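A minimal sketch of the proposed inverse pair (field names are illustrative): the state is the solution plus solver-internal quantities, and extract_fn strips exactly what init_fn attached, so the round-trip is the identity:

```python
from typing import NamedTuple

class Solution(NamedTuple):
    t: float
    u: float

class State(NamedTuple):
    t: float
    u: float
    error_estimate: float  # solver-internal; illustrative

def init_fn(solution):
    # solution -> state: attach solver-internal quantities
    return State(t=solution.t, u=solution.u, error_estimate=0.0)

def extract_fn(state):
    # state -> solution: discard solver-internal quantities
    return Solution(t=state.t, u=state.u)

# Round-trip property: extract_fn(init_fn(sol)) == sol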

Rename inits to taylor_series?

This way, the init_fn() in the ODE filter could handle the actual initialisation of the state instead of computing the Taylor series.
This would be useful, because it would admit re-initialisation (after a checkpoint, for example).

Implementation naming

  • "Implementation" Is probably the least informative name
  • Maybe something like "IsotropicIBMSqrt" or something?
  • The names of the methods, especially the separation between "extrapolate_mean", "finish_extrapolation", and "revert_markov_kernel" is not good at the moment.

Initial step-size

#91 comments out some initial-step-size code to make things work. Please figure this out and put it back in.
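For reference, a plain-NumPy sketch of the classic starting-step-size heuristic (Hairer, Nørsett & Wanner, Solving ODEs I, Sec. II.4), which the commented-out code presumably resembled; the function name and the thresholds are illustrative:

```python
import numpy as np

def propose_first_dt(f, t0, y0, error_order, rtol, atol):
    # Starting-step-size heuristic from Hairer/Noersett/Wanner, Solving ODEs I.
    scale = atol + np.abs(y0) * rtol
    f0 = f(t0, y0)
    d0, d1 = np.linalg.norm(y0 / scale), np.linalg.norm(f0 / scale)
    dt0 = 1e-6 if (d0 < 1e-5 or d1 < 1e-5) else 0.01 * d0 / d1

    # Probe the second derivative with one explicit Euler step.
    f1 = f(t0 + dt0, y0 + dt0 * f0)
    d2 = np.linalg.norm((f1 - f0) / scale) / dt0

    if max(d1, d2) < 1e-15:
        dt1 = max(1e-6, dt0 * 1e-3)
    else:
        dt1 = (0.01 / max(d1, d2)) ** (1.0 / (error_order + 1.0))
    return min(100.0 * dt0, dt1)
```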

Swap vector field signature to (t, *y, *p)

Often we redefine

def vf(*ys): return vector_field(*ys, t, *p)

which does not really work.
Instead, if we do either

def vf(*ys): return vector_field(t, *ys, *p)

or

def vf(*ys): return vector_field(*ys, *p, t)

there will be much less hassle with positional-only and keyword-only arguments.
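To illustrate the benefit with a made-up two-species vector field: with t leading, freezing t is a plain functools.partial, whereas with t trailing it always needs a wrapper lambda:

```python
import functools

# Hypothetical vector field in the proposed t-first ordering:
def vector_field(t, y1, y2, a):
    return (a * y1 - y1 * y2, y1 * y2 - a * y2)

# Freezing t is a left-partial -- no lambda re-wrapping required:
vf = functools.partial(vector_field, 0.5)
assert vf(2.0, 1.0, 0.5) == vector_field(0.5, 2.0, 1.0, 0.5) == (-1.0, 1.5)
```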

Delegate solve() routines to lower-level versions

How about we implement a couple of methods like

# The wrappers that a user would expect

def simulate_terminal_value(vector_field, t0, t1, u0, taylor_diff_fn, solver):
    taylor_coefficients = taylor_diff_fn(vector_field, u0, num=solver.num_derivatives)
    return odesimulate_terminal_value(vector_field, t0, t1, taylor_coefficients, solver)

def simulate_checkpoints(vector_field, ts, u0, taylor_diff_fn, solver):
    taylor_coefficients = taylor_diff_fn(vector_field, u0, num=solver.num_derivatives)
    return odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver)

# The actual solvers I'd like to provide

def odesimulate_terminal_value(vector_field, t0, t1, taylor_coefficients, solver):

    # Creates an initial solution object from the Taylor coefficients
    # (but not the full state -- this decouples the Taylor-coefficient stuff
    # from the state initialisation and essentially
    # resolves #48, #85, and probably even more issues).
    # In the ``jax.optimizers`` world, it would be the initial PyTree of params.
    # But here, this is a little too solver-dependent to ask from the user.
    solution = solver.taylorcoefficients_to_solution(taylor_coefficients, t0, t1)

    def cond_fun(state):  # could be made an argument, no problem
        return state.accepted.t < t1

    return simulate(vector_field, solver, solution, cond_fun)

def odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver):

    # See above
    solution = solver.taylorcoefficients_to_solution(taylor_coefficients, ts[0], ts[-1])

    full_solution = []  # pseudo-init_fn()
    for t0, t1 in zip(ts[:-1], ts[1:]):  # this would be a scan, actually

        def cond_fun(state):  # could be made an argument, no problem
            return state.accepted.t < t1

        solution = simulate(vector_field, solver, solution, cond_fun)  # pseudo-apply_fn()
        full_solution.append(solution)
    return full_solution  # pseudo-extract_fn()

# The low-level init-apply-extract schemes and while-loops.
# We could even make the choice of backend function an argument of the simulation.

def simulate_no_lax(vector_field, solver: Solver[T], solution: T, cond_fun) -> T:
    state = solver.init_fn(solution)
    while cond_fun(state):
        state = solver.step_fn(vector_field, state)
    return solver.extract_fn(state)

def simulate(vector_field, solver, solution, cond_fun):
    state = solver.init_fn(solution)
    state = lax.while_loop(cond_fun, lambda s: solver.step_fn(vector_field, state=s), state)
    return solver.extract_fn(state)

def simulate_diffrax(vector_field, solver, solution, cond_fun):
    state = solver.init_fn(solution)
    state = diffrax.bounded_while_loop(cond_fun, lambda s: solver.step_fn(vector_field, state=s), state)
    return solver.extract_fn(state)

which would resolve a couple of problems:

  • init_fn and extract_fn could be inverse to each other (properly!) #85
  • The solver loses the Taylor-mode component, which is really something that extends the problem definition instead of helping to solve it (the clarity of how the ODE filter operates would improve!)
  • Code would be super readable because every function is minimal. We would not need many docs, because the code is so trivial.
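To make the scheme concrete, here is a runnable toy instance of the init/step/extract pattern, with a fixed-step explicit Euler "solver" standing in for an ODE filter (all names are illustrative):

```python
from typing import NamedTuple

class State(NamedTuple):
    t: float
    u: float
    dt: float  # solver-internal; discarded by extract_fn

class EulerSolver:
    def init_fn(self, t0, u0, dt):
        return State(t=t0, u=u0, dt=dt)

    def step_fn(self, vector_field, state):
        du = vector_field(state.t, state.u)
        return State(t=state.t + state.dt, u=state.u + state.dt * du, dt=state.dt)

    def extract_fn(self, state):
        return (state.t, state.u)

def simulate_no_lax(vector_field, solver, state, cond_fun):
    while cond_fun(state):
        state = solver.step_fn(vector_field, state)
    return solver.extract_fn(state)

solver = EulerSolver()
state = solver.init_fn(t0=0.0, u0=1.0, dt=0.01)
t, u = simulate_no_lax(lambda t, u: -u, solver, state, lambda s: s.t < 1.0 - 1e-12)
# u approximates exp(-1): Euler with dt=0.01 yields 0.99**100, roughly 0.366
```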

Extrapolation styles

  • IsotropicImplementation
  • KroneckerImplementation
  • DenseImplementation
  • BatchedImplementation

Are low-level tests redundant?

I find myself refactoring purely against the test_ivpsolve.py module and then updating the lower-level tests afterwards, once things are working.

Does this mean they should be removed?

Debug-nan-clean code

At the moment, unused variables (dummies) are created with jnp.nan * jnp.ones_like(), potentially via a tree_map.

This is useful to make sure that these values are actually never used, but it makes it impossible to debug the code with the jax_debug_nans flag. This should be resolved once all the functionality is definitely correct.
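As a reminder of the pattern in question, here is a plain-NumPy stand-in for the jnp/tree_map version (names illustrative). The NaN dummies fail loudly if ever consumed, but under jax_debug_nans the mere creation of a NaN-valued array already raises, even when the dummy is never used:

```python
import numpy as np

def nan_like(tree):
    # Current pattern: fill unused dummy variables with NaN so that any
    # accidental use contaminates downstream results loudly.
    # In JAX: jax.tree_util.tree_map(lambda x: jnp.nan * jnp.ones_like(x), tree)
    return [np.nan * np.ones_like(leaf) for leaf in tree]

dummies = nan_like([np.zeros(3), np.zeros((2, 2))])
# Every entry is NaN, so any arithmetic that touches a dummy propagates NaN.
```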

Scalar ODEs

What should the policy be for scalar ODEs? Disallow them until further notice? They mess up shapes: isotropic states would become (n,) instead of (n, 1), and I don't know whether the broadcasting is designed for this.
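A small NumPy illustration of the shape problem, assuming the isotropic state is stored as (num_derivatives + 1, ode_dimension):

```python
import numpy as np

A = np.eye(3)                           # transition acting on the derivative axis

state_vector_valued = np.ones((3, 2))   # ode_dimension d=2: shape (n, d)
state_scalar = np.ones((3,))            # a scalar ODE collapses to (n,)

# Vector-valued states broadcast as intended; scalar ones silently change rank:
assert (A @ state_vector_valued).shape == (3, 2)
assert (A @ state_scalar).shape == (3,)
# Keeping scalar problems at shape (n, 1) would preserve uniform broadcasting.
```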

Change IVP signature

From f(y, *ys, *p) to f(t, y, p), potentially with f(t, y=(y1,), p=(p1,)) parameters. A lot of the implementation of autodiff, partial, jit, etc. becomes much easier if every vector field has exactly three arguments. To this end, it might be best to enforce that y and p are always iterables/tuples.
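A sketch of the proposed convention with y and p enforced as tuples (the vector field itself is made up):

```python
import functools

def logistic(t, y, p):
    (y1,) = y   # y is always a tuple, even for scalar problems
    (r,) = p    # p is always a tuple, even for a single parameter
    return (r * y1 * (1.0 - y1),)

# Exactly three positional arguments makes partial/jit/autodiff wiring uniform:
vf_autonomous = functools.partial(logistic, 0.0)
assert vf_autonomous((0.5,), (2.0,)) == (0.5,)
```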

Create benchmark utils

It would be quite useful if we could extract some benchmark utility functions and put them in a separate file, docs/benchmarks/_benchmarkutils.py:

  • Timing a function (including jitting?)
  • relative and absolute errors
  • reading the most recent version/commit
  • tbc
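A possible starting point for the first two bullets (all names hypothetical; for JAX code, outputs should additionally be materialised with jax.block_until_ready() before the clock stops):

```python
import statistics
import time

def time_fn(fn, *args, num_runs=5):
    # Warm-up call: JIT compilation (if any) happens here and is excluded.
    fn(*args)
    runtimes = []
    for _ in range(num_runs):
        tic = time.perf_counter()
        fn(*args)
        runtimes.append(time.perf_counter() - tic)
    return min(runtimes), statistics.median(runtimes)

def relative_error(approx, reference, nugget=1e-14):
    # The nugget guards against division by zero at reference == 0.
    return abs(approx - reference) / (abs(reference) + nugget)
```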

Early rejection in ODE solver step-adaptation

At the moment, there is always exactly one Adaptive() instance, which always wraps exactly one ODEFilter() instance.
Should we simplify the implementation and construct an AdaptiveODEFilter() class?

Yes, adaptive time-stepping is a more general concept than ODE filtering. But there are clear advantages of merging them:

  • Early rejection: the ODE filter knows the error estimate by line 2 of a 30-line solver-step implementation. We don't need any matrix-matrix operations and can reject steps extremely early. No wasted computation here; this is a big factor!
  • odefilters constructs, by definition, ODE filters. No need to be more general until it is needed.
  • Some SolverState attributes are not clearly ODEFilter- or AdaptiveSolver-territory (for example, error_estimate).
  • The step itself could be ingrained a bit more deeply into the step-selection mechanism. This would mean something like:
class AdaptiveODEFilter:
    strategy: Union[DynamicFilter, DynamicSmoother, Filter, Smoother]
    information: Information
    control: Control

    # (...)
    def step_fn(self, vector_field, t0, t1, state):
        x, dt = some_init_fn(...)
        while cond_fn(x):  # i.e., while the proposed step is rejected
            x = self.strategy.extrapolate(x, dt)
            y = self.information.linearize(vector_field, x.u)
            error = self.strategy.estimate_error(x, y)
            dt = self.control.update_fn(error, dt)

        # Handles the heavy matrix-matrix lifting.
        # Can also do the end-of-time-domain interpolation, if necessary.
        return self.strategy.complete_step(x, y, error, dt)

where a lot of computation is avoided compared to wrapping ODE filters like any other ODE solver.

  • TBC

Strategies

  • Filter
  • Smoother
  • FixedPointSmoother
  • LocallyIteratedFilter / LocalGaussNewtonFilter
  • LocallyIteratedSmoother / LocalGaussNewtonSmoother
  • LocallyIteratedFixedPointSmoother / LocalGaussNewtonFixedPointSmoother

solve_ivp() variations

I think there should be different modes of solving:

Adaptive solvers

All values (meaningful dense outputs):

  • solve()
  • solve_bounded_while_loop() # optional diffrax dependency
  • solution_generator() # native python

Specific values (indicated by simulate* prefixes)

  • simulate_terminal_values()
  • simulate_checkpoints()

Non-adaptive solvers

(separate, because we can scan and, for instance, get reverse-mode autodiff for free)

  • fixed stepsize
  • fixed evaluation grid
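The solution_generator() variant could be a plain Python generator over the same init/step/extract scheme; a hypothetical sketch, paired with a toy fixed-step Euler solver (all names illustrative):

```python
def solution_generator(vector_field, solver, state, cond_fun):
    # Yield the initial value, then one solution per accepted step.
    yield solver.extract_fn(state)
    while cond_fun(state):
        state = solver.step_fn(vector_field, state)
        yield solver.extract_fn(state)

class ToyEulerSolver:
    dt = 0.1

    def step_fn(self, vector_field, state):
        t, u = state
        return (t + self.dt, u + self.dt * vector_field(t, u))

    def extract_fn(self, state):
        return state  # nothing solver-internal to strip in this toy

gen = solution_generator(lambda t, u: -u, ToyEulerSolver(), (0.0, 1.0),
                         lambda s: s[0] < 0.5 - 1e-12)
steps = list(gen)  # (t, u) pairs at t = 0.0, 0.1, ..., 0.5
```

Because the generator is native Python, it cannot be jitted, but it gives users step-by-step access without any while-loop backend.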

All Solvers are kind of adaptive

No need to have a separate (public) object. It can be baked into the solve() routines; if necessary, as a private object (because it will be reused a fair amount, I suppose).
