# The wrappers that a user would expect
def simulate_terminal_value(vector_field, t0, t1, u0, taylor_diff_fn, solver):
    """User-facing wrapper: simulate from ``t0`` to ``t1`` starting at ``u0``.

    Computes the Taylor coefficients of the initial value with the
    user-supplied ``taylor_diff_fn`` and forwards them to
    ``odesimulate_terminal_value``.

    Args:
        vector_field: The ODE right-hand side.
        t0: Initial time.
        t1: Terminal time.
        u0: Initial value.
        taylor_diff_fn: Function computing ``solver.num_derivatives`` Taylor
            coefficients of the solution at ``(u0, t0)``.
        solver: Solver object; must expose ``num_derivatives``.

    Returns:
        Whatever ``odesimulate_terminal_value`` returns.
    """
    # NOTE(review): the call signature of ``taylor_diff_fn`` is assumed to be
    # ``(vector_field, u0, t0, *, num)`` -- TODO confirm. The coefficients must
    # depend on the initial value, which the original call never passed.
    taylor_coefficients = taylor_diff_fn(
        vector_field, u0, t0, num=solver.num_derivatives
    )
    return odesimulate_terminal_value(
        vector_field, t0, t1, taylor_coefficients, solver
    )
def simulate_checkpoints(vector_field, ts, u0, taylor_diff_fn, solver):
    """User-facing wrapper: simulate through the time points ``ts`` from ``u0``.

    Computes the Taylor coefficients of the initial value at ``ts[0]`` with
    the user-supplied ``taylor_diff_fn`` and forwards them to
    ``odesimulate_checkpoints``.

    Args:
        vector_field: The ODE right-hand side.
        ts: Time points; ``ts[0]`` is used as the initial time.
        u0: Initial value (at ``ts[0]``).
        taylor_diff_fn: Function computing ``solver.num_derivatives`` Taylor
            coefficients of the solution at ``(u0, ts[0])``.
        solver: Solver object; must expose ``num_derivatives``.

    Returns:
        Whatever ``odesimulate_checkpoints`` returns.
    """
    # NOTE(review): the call signature of ``taylor_diff_fn`` is assumed to be
    # ``(vector_field, u0, t0, *, num)`` -- TODO confirm.
    t0 = ts[0]
    taylor_coefficients = taylor_diff_fn(
        vector_field, u0, t0, num=solver.num_derivatives
    )
    return odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver)
# The actual solvers I'd like to provide
def odesimulate_terminal_value(vector_field, t0, t1, taylor_coeffs, solver):
    """Simulate an ODE from ``t0`` to ``t1``, starting from Taylor coefficients.

    Args:
        vector_field: The ODE right-hand side, forwarded to ``simulate``.
        t0: Initial time.
        t1: Terminal time.
        taylor_coeffs: Taylor coefficients of the initial value.
        solver: Solver object providing ``taylorcoefficients_to_solution``
            plus the ``init_fn``/``step_fn``/``extract_fn`` protocol that
            ``simulate`` drives.

    Returns:
        Whatever ``simulate`` returns (``solver.extract_fn`` of the final state).
    """
    # Create an initial *solution* object from the Taylor coefficients, but not
    # the full state. This decouples the Taylor-coefficient computation from
    # the state initialisation, and essentially resolves #48, #85 and probably
    # more issues. In the ``jax.optimizers`` world this would be the initial
    # PyTree of params, but here it is too solver-dependent to ask from the user.
    solution = solver.taylorcoefficients_to_solution(taylor_coeffs, t0, t1)

    def cond_fun(state):  # could be made an argument, no problem
        return state.accepted.t < state.t1

    # ``simulate`` only receives the solution, so the state it builds is
    # assumed to carry ``t1`` from ``taylorcoefficients_to_solution`` above
    # (``cond_fun`` reads ``state.t1``) -- TODO confirm with the solver API.
    return simulate(vector_field, solver, solution, cond_fun)
def odesimulate_checkpoints(vector_field, ts, taylor_coefficients, solver):
    """Simulate an ODE through successive checkpoints ``ts``.

    Args:
        vector_field: The ODE right-hand side, forwarded to ``simulate``.
        ts: Time points; presumably increasing -- not checked here.
        taylor_coefficients: Taylor coefficients of the initial value at ``ts[0]``.
        solver: Solver object providing ``taylorcoefficients_to_solution``
            plus the protocol that ``simulate`` drives.

    Returns:
        A list with one solution per checkpoint in ``ts[1:]`` (the initial
        solution is not included).

    Raises:
        ValueError: If ``ts`` holds fewer than two time points.
    """
    if len(ts) < 2:
        raise ValueError("odesimulate_checkpoints needs at least two time points.")

    # Initial solution object (not the full state) from the Taylor
    # coefficients, spanning the whole horizon ts[0] .. ts[-1].
    # NOTE(review): the original passed undefined t0/t1 here; the full
    # horizon is an assumption -- confirm what the solver expects.
    solution = solver.taylorcoefficients_to_solution(
        taylor_coefficients, ts[0], ts[-1]
    )

    full_solution = []  # pseudo-init_fn()
    for t_checkpoint in ts[1:]:  # this would be a scan, actually.
        # The stopping time differs per segment, so bind it into the predicate
        # (default argument avoids late binding) instead of reading state.t1.
        def cond_fun(state, t_stop=t_checkpoint):
            return state.accepted.t < t_stop

        solution = simulate(vector_field, solver, solution, cond_fun)  # pseudo-apply_fn()
        full_solution.append(solution)
    return full_solution  # pseudo-extract_fn()
# The low-level init-apply-extract schemes (init_fn / step_fn / extract_fn) and their while-loops.
# The loop backend (plain Python, jax.lax, diffrax) could even be made an argument of the simulation.
def simulate_no_lax(
    vector_field, t0, t1, solver: "Solver[T]", solution: "T", cond_fun=None
) -> "T":
    """Run the solver loop with a plain Python ``while`` (no ``jax.lax``).

    Args:
        vector_field: The ODE right-hand side, forwarded to the solver.
        t0: Initial time, forwarded to ``init_fn`` and ``step_fn``.
        t1: Terminal time, forwarded to ``init_fn`` and ``step_fn``.
        solver: Object providing ``init_fn(vector_field, t0, t1, solution)``,
            ``step_fn(vector_field, t0, t1, state)`` and ``extract_fn(state)``.
        solution: Initial solution object used to initialise the state.
        cond_fun: Optional predicate on the state; the loop continues while it
            is true. Defaults to ``state.accepted.t < t1``.

    Returns:
        ``solver.extract_fn`` applied to the final state.
    """
    if cond_fun is None:
        # Default stopping rule: step until the accepted time reaches t1.
        def cond_fun(state):
            return state.accepted.t < t1

    problem = (vector_field, t0, t1)
    state = solver.init_fn(*problem, solution)
    while cond_fun(state):
        state = solver.step_fn(*problem, state)
    return solver.extract_fn(state)
def simulate(vector_field, solver, solution, cond_fun):
    """Run the solver loop with ``jax.lax.while_loop``.

    Args:
        vector_field: The ODE right-hand side, forwarded to ``solver.step_fn``.
        solver: Object providing ``init_fn(solution)``,
            ``step_fn(vector_field, state=...)`` and ``extract_fn(state)``.
        solution: Initial solution object used to initialise the state.
        cond_fun: Predicate on the state; the loop continues while it is true.

    Returns:
        ``solver.extract_fn`` applied to the final state.
    """
    state = solver.init_fn(solution)

    def body_fun(s):
        return solver.step_fn(vector_field, state=s)

    # lax.while_loop needs the state pytree to keep the same structure and
    # dtypes across iterations, and it does not support reverse-mode autodiff.
    state = lax.while_loop(cond_fun, body_fun, state)
    return solver.extract_fn(state)
def simulate_diffrax(vector_field, solver, solution, cond_fun):
    """Run the solver loop with diffrax's ``bounded_while_loop``.

    Args:
        vector_field: The ODE right-hand side, forwarded to ``solver.step_fn``.
        solver: Object providing ``init_fn(solution)``,
            ``step_fn(vector_field, state=...)`` and ``extract_fn(state)``.
        solution: Initial solution object used to initialise the state.
        cond_fun: Predicate on the state; the loop continues while it is true.

    Returns:
        ``solver.extract_fn`` applied to the final state.
    """
    state = solver.init_fn(solution)

    def body_fun(s):
        return solver.step_fn(vector_field, state=s)

    # NOTE(review): verify this call against the installed diffrax version --
    # ``bounded_while_loop`` may not be exported at the top level, may require
    # a ``max_steps`` argument, and may pass extra arguments to ``body_fun``.
    state = diffrax.bounded_while_loop(cond_fun, body_fun, state)
    return solver.extract_fn(state)