Diffrax Wrapper


These helper functions wrap functions from the diffrax module, such as diffrax.diffeqsolve(). They exist only for convenience, to provide default values for some of the parameters. Don’t hesitate to use the functions from the diffrax module directly if you need more flexibility.

Functions

icomo.diffeqsolve(*args, ts_out=None, ODE=None, fixed_step_size=False, **kwargs)

Solves a system of differential equations.

Wrapper of diffrax.diffeqsolve(). It accepts the same parameters as diffrax.diffeqsolve(). Additionally, for convenience, the output timesteps ts_out and the system of differential equations ODE can be specified as keyword arguments. If no solver is specified, it uses diffrax.Tsit5. For a simple ODE, it is enough to specify the following keyword arguments:

Parameters:
  • ts_out (ArrayLike | None, default: None) – The timesteps at which the output is returned. Equivalent to diffrax.diffeqsolve(..., saveat=diffrax.SaveAt(ts=ts_out)). It additionally sets the initial time t0, the final time t1, and the time step dt0 of the solver, unless these are specified separately.

  • ODE (Callable[[ArrayLike, PyTree, PyTree], PyTree] | None, default: None) – The function f(t, y, args) that returns the derivatives of the variables of the system of differential equations. Equivalent to diffrax.diffeqsolve(..., terms=diffrax.ODETerm(ODE)).

  • y0 – The initial values of the variables of the system of differential equations.

  • args (PyTree of ArrayLike or Callable) – The arguments of the system of differential equations. Passed as the third argument to the ODE function.

  • fixed_step_size (bool, default: False) – If True, the solver uses a fixed step size of dt0, i.e., it uses stepsize_controller=diffrax.ConstantStepSize().

Return type:

Solution

Returns:

sol – The solution of the system of differential equations. sol.ys contains the variables of the system of differential equations at the output timesteps.

icomo.interpolate_func(ts_in, values, method='cubic', ret_gradients=False)

Returns a diffrax interpolation function that can be used to interpolate pytensors.

Wrapper of diffrax.CubicInterpolation and diffrax.LinearInterpolation.

Parameters:
  • ts_in (ArrayLike) – The timesteps at which the time-dependent variable is given.

  • values (ArrayLike) – The time-dependent variable.

  • method (str, default: 'cubic') – The interpolation method used. Can be 'cubic' or 'linear'.

  • ret_gradients (bool, default: False) – If True, the function returns the gradient of the interpolation function.

Return type:

Callable[[ArrayLike], Array]

Returns:

interp – The interpolation function. Call interp(t) to evaluate the interpolated variable at time t. t can be a float or an ArrayLike.