Source code for icomo.diffrax_wrapper

"""Wrapper of diffrax functions for convenience."""

import inspect
from collections.abc import Callable
from typing import Optional

import diffrax
import jax
from jaxtyping import Array, ArrayLike, PyTree


def diffeqsolve(
    *args,
    ts_out: Optional[ArrayLike] = None,
    ODE: Callable[[ArrayLike, PyTree, PyTree], PyTree] | None = None,
    fixed_step_size: bool = False,
    **kwargs,
) -> diffrax.Solution:
    """Solves a system of differential equations.

    Wrapper of :meth:`diffrax.diffeqsolve`. It accepts the same parameters as
    :meth:`diffrax.diffeqsolve`. Additionally, for convenience, it allows the
    specification of the output timesteps ``ts_out`` and the system of
    differential equations ``ODE`` as keyword arguments. If not specified, it
    uses the :class:`diffrax.Tsit5` solver.

    For a simple ODE, it is enough to specify the following keyword arguments:

    Parameters
    ----------
    ts_out :
        The timesteps at which the output is returned. Equivalent to
        ``diffrax.diffeqsolve(..., saveat=diffrax.SaveAt(ts=ts_out))``.
        It additionally sets the initial time ``t0``, the final time ``t1``
        and the time step ``dt0`` of the solver if they are not specified
        separately. Inferring ``dt0`` requires at least two timesteps.
    ODE :
        The function `f(t, y, args)` that returns the derivatives of the
        variables of the system of differential equations. Equivalent to
        ``diffrax.diffeqsolve(..., terms=diffrax.ODETerm(ODE))``.
    y0 :
        The initial values of the variables of the system of differential
        equations.
    args : PyTree of ArrayLike or Callable
        The arguments of the system of differential equations. Passed as the
        third argument to the ODE function.

    Other Parameters
    ----------------
    fixed_step_size :
        If True, the solver uses a fixed step size of ``dt0``, i.e. it uses
        ``stepsize_controller=diffrax.ConstantStepSize()``.

    Returns
    -------
    sol :
        The solution of the system of differential equations. sol.ys contains
        the variables of the system of differential equations at the output
        timesteps.

    Raises
    ------
    TypeError
        If conflicting arguments are given (``fixed_step_size=True`` together
        with ``stepsize_controller``, ``ts_out`` together with ``saveat``, or
        ``ODE`` together with ``terms``), or if one of ``t0``, ``t1``, ``dt0``
        or ``terms`` is neither given nor inferable.
    """
    # Bind against diffrax's own signature so that arguments given either
    # positionally or by keyword are detected in the same way.
    signature = inspect.signature(diffrax.diffeqsolve)
    given = signature.bind_partial(*args, **kwargs).arguments

    if fixed_step_size:
        if "stepsize_controller" in given:
            raise TypeError(
                "Don't simultaneously specify `icomo.diffeqsolve(..., "
                "fixed_step_size=True, "
                "stepsize_controller=...)`"
            )
        kwargs["stepsize_controller"] = diffrax.ConstantStepSize()

    if ts_out is not None:
        if "saveat" in given:
            raise TypeError(
                "Don't simultaneously specify `icomo.diffeqsolve(..., "
                "ts_out=..., saveat=...)`"
            )
        kwargs["saveat"] = diffrax.SaveAt(ts=ts_out)
        # Fill in the integration interval and step from the output grid,
        # unless the caller has specified them explicitly.
        if "t0" not in given:
            kwargs["t0"] = ts_out[0]
        if "t1" not in given:
            kwargs["t1"] = ts_out[-1]
        if "dt0" not in given:
            # The step is taken from the spacing of the first two output
            # times, so a single-point grid cannot provide it.
            if len(ts_out) < 2:
                raise TypeError(
                    "Cannot infer `dt0` from `ts_out` with fewer than two "
                    "timesteps; specify `icomo.diffeqsolve(..., dt0=...)`"
                )
            kwargs["dt0"] = ts_out[1] - ts_out[0]
    else:
        if "t0" not in given:
            raise TypeError(
                "Specify `icomo.diffeqsolve(..., t0=...)` and/or "
                "`icomo.diffeqsolve(..., ts_out=...)`"
            )
        if "t1" not in given:
            raise TypeError(
                "Specify `icomo.diffeqsolve(..., t1=...)` and/or "
                "`icomo.diffeqsolve(..., ts_out=...)`"
            )
        if "dt0" not in given:
            raise TypeError(
                "Specify `icomo.diffeqsolve(..., dt0=...)` and/or "
                "`icomo.diffeqsolve(..., ts_out=...)`"
            )

    if ODE is not None:
        if "terms" in given:
            raise TypeError(
                "Don't simultaneously specify `icomo.diffeqsolve(..., "
                "ODE=..., terms=...)`"
            )
        kwargs["terms"] = diffrax.ODETerm(ODE)
    elif "terms" not in given:
        raise TypeError(
            "Specify either `icomo.diffeqsolve(..., ODE=...)` or "
            "`icomo.diffeqsolve(..., terms=...)`"
        )

    if "solver" not in given:
        kwargs["solver"] = diffrax.Tsit5()

    return diffrax.diffeqsolve(*args, **kwargs)
def interpolate_func(
    ts_in: ArrayLike,
    values: ArrayLike,
    method: str = "cubic",
    ret_gradients: bool = False,
) -> Callable[[ArrayLike], Array]:
    """Return a diffrax interpolation function for a time-dependent variable.

    Wrapper of :class:`diffrax.CubicInterpolation` and
    :class:`diffrax.LinearInterpolation`.

    Parameters
    ----------
    ts_in :
        The timesteps at which the time-dependent variable is given.
    values :
        The time-dependent variable, given at the timesteps ``ts_in``.
    method :
        The interpolation method used. Can be "cubic" or "linear". For
        "cubic", the coefficients are computed with
        :func:`diffrax.backward_hermite_coefficients`.
    ret_gradients :
        If True, the returned function evaluates the time derivative of the
        interpolation instead of the interpolated value.

    Returns
    -------
    interp :
        The interpolation function. Call ``interp(t)`` to evaluate the
        interpolated variable (or its derivative if ``ret_gradients`` is
        True) at time `t`. `t` is a single time point (a float or a scalar
        array); use :func:`jax.vmap` to evaluate at several times.

    Raises
    ------
    RuntimeError
        If ``method`` is neither "cubic" nor "linear".
    """
    # Validate the method up front, before doing any array work.
    if method not in ("cubic", "linear"):
        raise RuntimeError(
            f'Interpolation method {method} not known, possibilities are "cubic" or '
            '"linear"'
        )

    ts_in = jax.numpy.array(ts_in)
    if method == "cubic":
        coeffs = diffrax.backward_hermite_coefficients(ts_in, values)
        interp = diffrax.CubicInterpolation(ts_in, coeffs)
    else:
        interp = diffrax.LinearInterpolation(ts_in, values)

    return interp.derivative if ret_gradients else interp.evaluate