Source code for icomo.comp_model

"""Tools to create compartmental models and compile them into functions that can be used."""

import logging

import diffrax
import graphviz
from pytensor.tensor.type import TensorType

from icomo.pytensor_op import create_and_register_jax

logger = logging.getLogger(__name__)


def SIR(t, y, args):
    """Right-hand side of the SIR ordinary differential equations.

    Parameters
    ----------
    t : float
        Time at which the derivatives are evaluated.
    y : dict
        Current compartment values. Must provide the keys "S"
        (susceptible), "I" (infected) and "R" (recovered).
    args : tuple
        Pair ``(beta, const_arg)``. ``beta`` is a callable of ``t`` giving
        the infection rate; ``const_arg`` is a dict holding "gamma" (the
        recovery rate) and "N" (the total population size).

    Returns
    -------
    dy : dict
        Derivatives of the compartments under the keys "S", "I", "R".
    """
    infection_rate, constants = args
    recovery_rate = constants["gamma"]
    population = constants["N"]

    # The infection term is evaluated separately for the outflow of S and
    # the inflow of I; recovery moves individuals from I to R.
    return {
        "S": -infection_rate(t) * y["I"] * y["S"] / population,
        "I": (
            infection_rate(t) * y["I"] * y["S"] / population
            - recovery_rate * y["I"]
        ),
        "R": recovery_rate * y["I"],
    }
def Erlang_SEIRS(t, y, args):
    """Return the derivatives of an SEIRS model with Erlang-distributed stages.

    The latent, infectious and recovered (recovered-to-susceptible) periods
    are each modelled as a chain of sub-compartments via ``erlang_kernel``.

    Parameters
    ----------
    t : float
        Time variable.
    y : dict
        Compartments with the keys "S", "Es", "Is", "Rs". "Es", "Is" and
        "Rs" are sequences of sub-compartments; their length is the shape
        of the corresponding Erlang kernel.
    args : tuple of (callable, dict)
        The callable takes ``t`` and returns the transmission rate beta(t).
        The dict holds the constants "N", "rate_latent", "rate_infectious"
        and "rate_recovery".

    Returns
    -------
    dComp : dict
        Derivatives with the same keys as ``y``.
    """
    transmission, constants = args
    population = constants["N"]
    total_infectious = sum(y["Is"])

    # Infection removes individuals from S and feeds the first latent stage.
    dS_infection = -transmission(t) * total_infectious * y["S"] / population
    new_exposed = transmission(t) * total_infectious * y["S"] / population

    # Latent period
    dEs, leaving_latent = erlang_kernel(
        inflow=new_exposed,
        Vars=y["Es"],
        rate=constants["rate_latent"],
    )
    # Infectious period
    dIs, leaving_infectious = erlang_kernel(
        inflow=leaving_latent,
        Vars=y["Is"],
        rate=constants["rate_infectious"],
    )
    # Recovered/non-susceptible period
    dRs, back_to_susceptible = erlang_kernel(
        inflow=leaving_infectious,
        Vars=y["Rs"],
        rate=constants["rate_recovery"],
    )

    # The outflow of the last recovered stage re-enters S.
    return {
        "S": dS_infection + back_to_susceptible,
        "Es": dEs,
        "Is": dIs,
        "Rs": dRs,
    }
def erlang_kernel(inflow, Vars, rate):
    """Compute the derivatives of a chain of compartments forming an Erlang kernel.

    The shape of the kernel is the number of entries in ``Vars``. Each
    compartment drains at ``len(Vars) * rate`` times its content; the first
    one receives ``inflow`` and every later one receives the drain of its
    predecessor.

    Parameters
    ----------
    inflow : float or ndarray
        Inflow into the first compartment of the kernel.
    Vars : list of floats or list of ndarrays
        The compartments of the kernel.
    rate : float or ndarray
        Rate of the kernel; 1/rate is the mean total time spent in all
        compartments.

    Returns
    -------
    dVars : list of floats or list of ndarrays
        Derivatives of the compartments.
    out_flow : float or ndarray
        Outflow from the last compartment.
    """
    n_stages = len(Vars)
    dVars = []
    entering = inflow
    for stage in Vars:
        leaving = n_stages * rate * stage
        dVars.append(-leaving + entering)
        # What leaves this stage is what enters the next one.
        entering = leaving
    # Indexed explicitly so that an empty ``Vars`` raises, as it always did.
    out_flow = n_stages * rate * Vars[-1]
    return dVars, out_flow
class CompModel:
    """Class to help building a compartmental model.

    The model is built by adding flows between compartments. The accumulated
    derivatives are available through :attr:`dy` and can be returned by the
    function that defines the system of ODEs. A graphviz graph of the
    compartments and flows is built alongside.
    """

    def __init__(self, Comp_dict):
        """Initialize the CompModel class.

        Parameters
        ----------
        Comp_dict : dict
            Dictionary of compartments. Keys are the names of the
            compartments and values are floats or ndarrays that represent
            their value.
        """
        self.Comp = Comp_dict
        self.dComp = {}
        self.graph = graphviz.Digraph("comp_model")
        for key in self.Comp:
            self.graph.node(key)
        self.graph.attr(rankdir="LR")

    def flow(self, start_comp, end_comp, rate, label=None, end_comp_is_list=False):
        """Add a flow from start_comp to end_comp with rate flow.

        Parameters
        ----------
        start_comp : str or list
            Key of the start compartment. Can also be a list of keys, in
            which case an identical flow is added from each compartment of
            the list to end_comp.
        end_comp : str
            Key of the end compartment.
        rate : float or ndarray
            Rate of the flow to add between compartments. It is multiplied
            by the value of the start compartment, so it should be
            broadcastable with it.
        label : str, optional
            Label of the edge between the compartments that will be used
            when displaying a graph of the compartmental model.
        end_comp_is_list : bool, default: False
            If True, end_comp points to a list of compartments, and the
            outflow is added to the first of them. This is typically the
            case when the end compartment is used with an Erlang kernel.

        Returns
        -------
        None
        """
        start_keys = [start_comp] if isinstance(start_comp, str) else start_comp

        # A separate loop name avoids rebinding the ``start_comp`` parameter.
        for start_key in start_keys:
            self._add_dy_to_comp(start_key, -rate * self.Comp[start_key])
            self._add_dy_to_comp(
                end_comp, rate * self.Comp[start_key], end_comp_is_list
            )
            self.graph.edge(start_key, end_comp, label=label)

    def erlang_flow(
        self, start_comp, end_comp, rate, label=None, end_comp_is_list=False
    ):
        """Add a flow with Erlang kernel from start_comp to end_comp.

        Parameters
        ----------
        start_comp : str
            Key of the start compartments of the flow. The value stored
            under this key has to be a list or tuple of compartments; its
            length is the shape of the Erlang kernel.
        end_comp : str
            Key of the end compartment of the flow.
        rate : float or ndarray
            Rate of the flow, equal to the inverse of the mean time spent
            in the Erlang kernel.
        label : str, optional
            Label of the edge between the compartments that will be used
            when displaying a graph of the compartmental model.
        end_comp_is_list : bool, default: False
            If True, end_comp points to a list of compartments, and the
            outflow is added to the first of them. This is typically the
            case when the end compartment is used with an Erlang kernel.

        Raises
        ------
        RuntimeError
            If the value stored under start_comp is neither a list nor a
            tuple.
        """
        if not isinstance(self.Comp[start_comp], (list, tuple)):
            raise RuntimeError(
                "start_comp should refer to a list of compartments to allow "
                "the modelling of an erlang kernel across those compartments"
            )
        # The kernel itself has no external inflow; inflow into the first
        # stage is added by other flows that target ``start_comp``.
        dy, outflow = erlang_kernel(inflow=0, Vars=self.Comp[start_comp], rate=rate)
        self._add_dy_to_comp(start_comp, dy, comp_is_list=True)
        self._add_dy_to_comp(end_comp, outflow, comp_is_list=end_comp_is_list)

        self.graph.edge(start_comp, end_comp, label=label)

    def _add_dy_to_comp(self, comp_key, dy, comp_is_list=False):
        """Accumulate ``dy`` into the derivative stored for ``comp_key``.

        Parameters
        ----------
        comp_key : str
            Key of the compartment whose derivative is updated.
        dy : float, ndarray or list
            Contribution to add. A list is only accepted when
            ``comp_is_list`` is True and must have one entry per
            sub-compartment.
        comp_is_list : bool, default: False
            If True, the compartment is a sequence of sub-compartments. A
            list ``dy`` is added element-wise; any other ``dy`` is added to
            the first sub-compartment only.

        Raises
        ------
        RuntimeError
            If ``dy`` is a list whose length differs from the compartment.
        """
        if not comp_is_list:
            if comp_key not in self.dComp:
                self.dComp[comp_key] = 0
            self.dComp[comp_key] = self.dComp[comp_key] + dy
            return

        if comp_key not in self.dComp:
            self.dComp[comp_key] = [0 for _ in self.Comp[comp_key]]

        if isinstance(dy, list):
            if len(dy) != len(self.Comp[comp_key]):
                raise RuntimeError(
                    "dy should have the same length as the compartment"
                )
            for i, dy_i in enumerate(dy):
                self.dComp[comp_key][i] = self.dComp[comp_key][i] + dy_i
        else:
            # Add only to first comp
            self.dComp[comp_key][0] = self.dComp[comp_key][0] + dy

    @property
    def dy(self):
        """Return the derivative of the compartments.

        This should be returned by the function that defines the system of
        ODEs.

        Returns
        -------
        dComp : dict
        """
        return self.dComp

    def view_graph(self, on_display=True):
        """Display a graph of the compartmental model.

        Parameters
        ----------
        on_display : bool
            If True, the graph is displayed in the notebook, otherwise it is
            saved as a pdf in the current folder and opened with the default
            pdf viewer.

        Returns
        -------
        None
        """
        if on_display:
            try:
                from IPython.display import display

                display(self.graph)
            except Exception:
                # Best-effort fallback (e.g. IPython missing), without
                # swallowing KeyboardInterrupt/SystemExit as a bare
                # ``except`` would.
                self.graph.view()
        else:
            self.graph.view()
def delayed_copy(initial_var, delayed_vars, tau_delay):
    """Return the derivatives that model a delayed copy of a compartment.

    The delay has the form of an Erlang kernel whose shape parameter is
    ``len(delayed_vars)``.

    Parameters
    ----------
    initial_var : float or ndarray
        The compartment that is copied.
    delayed_vars : list of floats or list of ndarrays
        Compartments that are delayed copies of ``initial_var``. The last
        element is the compartment which has the same total content as
        ``initial_var`` over time, but is delayed by ``tau_delay``.
    tau_delay : float or ndarray
        The mean delay of the copy.

    Returns
    -------
    d_delayed_vars : list of floats or list of ndarrays
        The derivatives of the delayed compartments.
    """
    n_stages = len(delayed_vars)
    # Only the stage derivatives are needed; the outflow of the kernel's
    # last stage is discarded.
    d_delayed_vars, _ = erlang_kernel(
        initial_var / tau_delay * n_stages,
        delayed_vars[:],
        1 / tau_delay,
    )
    return d_delayed_vars
class ODEIntegrator:
    """Creates an integrator for compartmental models.

    The integrator is a function that takes as input a function that returns
    the derivatives of the variables of the system of differential equations,
    and returns a function that solves the system of differential equations.
    For the initialization of the integrator object, the timesteps of the
    solver and the output have to be specified.
    """

    def __init__(
        self,
        ts_out,
        ts_solver=None,
        ts_arg=None,
        interp="cubic",
        solver=diffrax.Tsit5(),
        t_0=None,
        t_1=None,
        **kwargs,
    ):
        """Initialize the ODEIntegrator class.

        Parameters
        ----------
        ts_out : array-like
            The timesteps at which the output is returned.
        ts_solver : array-like or None
            The timesteps at which the solver will be called. If None, it is
            set to ts_out.
        ts_arg : array-like or None
            The timesteps at which the time-dependent argument of the system
            of differential equations is given. Required when a non-callable
            ``arg_t`` is passed to the integrator.
        interp : str
            The interpolation method used to interpolate the time-dependent
            argument of the system of differential equations. Can be "cubic"
            or "linear".
        solver : :class:`diffrax.AbstractSolver`
            The solver used to integrate the system of differential
            equations. Default is diffrax.Tsit5(), a 5th order Runge-Kutta
            method.
        t_0 : float or None
            The initial time of the integration. If None, it is set to
            ts_solver[0].
        t_1 : float or None
            The final time of the integration. If None, it is set to
            ts_solver[-1].
        **kwargs
            Arguments passed to the solver, see :func:`diffrax.diffeqsolve`
            for more details.

        Raises
        ------
        ValueError
            If t_0 is larger than the first element of ts_out, or t_1 is
            smaller than the last element of ts_out.
        """
        self.ts_out = ts_out
        self.ts_solver = self.ts_out if ts_solver is None else ts_solver

        self.t_0 = float(self.ts_solver[0]) if t_0 is None else t_0
        if self.t_0 > self.ts_out[0]:
            raise ValueError("t_0 should be smaller than the first element of ts_out")

        self.t_1 = float(self.ts_solver[-1]) if t_1 is None else t_1
        if self.t_1 < self.ts_out[-1]:
            raise ValueError("t_1 should be larger than the last element of ts_out")

        self.ts_arg = ts_arg
        self.interp = interp
        self.solver = solver
        self.kwargs_solver = kwargs

    def get_func(self, ODE, list_keys_to_return=None):
        """Return a function that solves the system of differential equations.

        Parameters
        ----------
        ODE : function(t, y, args)
            A function that returns the derivatives of the variables of the
            system of differential equations. The function has to take as
            input the time `t`, the variables `y` and the arguments
            `args=(arg_t, constant_args)` of the system of differential
            equations. `t` is a float, `y` is a list or dict of floats or
            ndarrays, or in general, a pytree, see :mod:`jax.tree_util` for
            more details. The return value of the function has to be a
            pytree/list/dict with the same structure as `y`.
        list_keys_to_return : list of str or None, default is None
            The keys of the variables of the system of differential
            equations that are returned by the integrator. If set, the
            integrator returns a tuple of the variables in the order of the
            keys. If `None`, the output is returned as is.

        Returns
        -------
        integrator : function(y0, arg_t=None, constant_args=None)
            A function that solves the system of differential equations and
            returns the output at the specified timesteps. It takes as input
            `y0`, the initial values of the variables, the time-dependent
            argument `arg_t`, and the constant arguments `constant_args`.
            `t`, `y0` and `(arg_t, constant_args)` are passed to the ODE
            function as its three arguments. If `arg_t` is `None`, only
            `constant_args` are passed to the ODE function and vice versa,
            without being in a tuple. A non-callable `arg_t` is interpolated
            over `ts_arg`; a callable one is used directly.
        """

        def integrator(y0, arg_t=None, constant_args=None):
            arg_t_func = None
            if arg_t is None:
                if self.ts_arg is not None:
                    logger.warning(
                        "You did specify ts_arg, but arg_t is None. Did you mean to do this?"
                    )
            elif callable(arg_t):
                # Only warn when ts_arg was actually given; previously the
                # warning fired for every callable arg_t.
                if self.ts_arg is not None:
                    logger.warning(
                        "arg_t is callable, but ts_arg is not None. ts_arg"
                        " won't be used."
                    )
                arg_t_func = arg_t
            else:
                if self.ts_arg is None:
                    raise RuntimeError("Specify ts_arg to use a non-callable arg_t")
                arg_t_func = interpolation_func(
                    ts=self.ts_arg, x=arg_t, method=self.interp
                ).evaluate

            term = diffrax.ODETerm(ODE)

            if arg_t is None:
                args = constant_args
            elif constant_args is None:
                args = arg_t_func
            else:
                args = (
                    arg_t_func,
                    constant_args,
                )
            saveat = diffrax.SaveAt(ts=self.ts_out)

            # Work on a copy and pop the options that are also passed
            # explicitly below; leaving them in would make diffeqsolve raise
            # "got multiple values for keyword argument".
            solver_kwargs = dict(self.kwargs_solver)
            if "stepsize_controller" in solver_kwargs:
                stepsize_controller = solver_kwargs.pop("stepsize_controller")
            else:
                stepsize_controller = diffrax.StepTo(ts=self.ts_solver)

            if "dt0" in solver_kwargs:
                dt0 = solver_kwargs.pop("dt0")
            elif isinstance(stepsize_controller, diffrax.StepTo):
                # StepTo prescribes every step, so no initial step is given.
                dt0 = None
            else:
                dt0 = self.ts_solver[1] - self.ts_solver[0]

            sol = diffrax.diffeqsolve(
                term,
                self.solver,
                self.t_0,
                self.t_1,
                dt0=dt0,
                stepsize_controller=stepsize_controller,
                y0=y0,
                args=args,
                saveat=saveat,
                **solver_kwargs,
            )
            if list_keys_to_return is None:
                return sol.ys
            return tuple(sol.ys[key] for key in list_keys_to_return)

        return integrator

    def get_op(
        self,
        ODE,
        return_shapes=((),),
        list_keys_to_return=None,
        name=None,
    ):
        """Return a pytensor operator that solves the system of differential equations.

        Same as get_func, but returns a pytensor operator that can be used in
        a pymc model. Beware that for this operator the output of the
        integration of the ODE can only be a single variable or a list of
        variables. If the output is a dict, set list_keys_to_return to
        specify the keys of the variables that are returned by the
        integrator. These return values aren't allowed to be further nested.

        Parameters
        ----------
        ODE : function(t, y, args)
            A function that returns the derivatives of the variables of the
            system of differential equations. The function has to take as
            input the time `t`, the variables `y` and the arguments
            `args=(arg_t, constant_args)` of the system of differential
            equations. `t` is a float, `y` is a list or dict of floats or
            ndarrays, or in general, a pytree, see :mod:`jax.tree_util` for
            more details. The return value of the function has to be a
            pytree/list/dict with the same structure as `y`.
        return_shapes : tuple of tuples, default is ((),)
            The shapes (except the time dimension) of the variables that are
            returned by the integrator. If `list_keys_to_return` is `None`,
            the shapes have to be given in the same order as the variables
            are returned by the integrator. Otherwise they have to be given
            in the same order as the keys in `list_keys_to_return`. The
            default `((),)` means a single variable with only a time
            dimension is returned.
        list_keys_to_return : list of str or None, default is None
            The keys of the variables that will be chosen to be returned by
            the integrator. Necessary if the ODE returns a `dict`, as
            :mod:`pytensor` only accepts single outputs or a list of
            outputs. If `None`, the output is returned as is.
        name : str or None
            The name under which the operator is registered in pymc.

        Returns
        -------
        pytensor_op : :class:`pytensor.graph.op.Op`
            A :mod:`pytensor` operator that can be used in a
            :class:`pymc.Model`.
        """
        integrator = self.get_func(ODE, list_keys_to_return=list_keys_to_return)

        n_times = len(self.ts_out)
        pytensor_op = create_and_register_jax(
            integrator,
            output_types=[
                # Every output carries the time dimension first.
                TensorType(dtype="float64", shape=(n_times, *shape))
                for shape in return_shapes
            ],
            name=name,
        )
        return pytensor_op
def interpolation_func(ts, x, method="cubic"):
    """Return a diffrax interpolation object for a time-dependent variable.

    Parameters
    ----------
    ts : array-like
        The timesteps at which the time-dependent variable is given.
    x : array-like
        The time-dependent variable.
    method : str
        The interpolation method used. Can be "cubic" or "linear".

    Returns
    -------
    interp : :class:`diffrax.CubicInterpolation` or :class:`diffrax.LinearInterpolation`
        The interpolation function. Call `interp.evaluate(t)` to evaluate the
        interpolated variable at time `t`. t can be a float or an array-like.

    Raises
    ------
    RuntimeError
        If `method` is neither "cubic" nor "linear".
    """
    if method == "cubic":
        coeffs = diffrax.backward_hermite_coefficients(ts, x)
        return diffrax.CubicInterpolation(ts, coeffs)
    if method == "linear":
        return diffrax.LinearInterpolation(ts, x)
    raise RuntimeError(
        f'Interpolation method {method} not known, possibilities are "cubic" or "linear"'
    )
def interpolate_pytensor(ts_in, ts_out, y, method="cubic", ret_gradients=False):
    """Interpolate the time-dependent variable `y` at the timesteps `ts_out`.

    The interpolation is wrapped with ``create_and_register_jax`` and the
    resulting operator is applied to ``ts_out`` and ``y``.

    Parameters
    ----------
    ts_in : array-like
        The timesteps at which the time-dependent variable is given.
    ts_out : array-like
        The timesteps at which the time-dependent variable should be
        interpolated. Its length fixes the shape of the output.
    y : array-like
        The time-dependent variable.
    method : str
        The interpolation method used. Can be "cubic" or "linear".
    ret_gradients : bool
        If True, the derivative of the interpolation at ``ts_out`` is
        returned instead of its value. Default is False.

    Returns
    -------
    y_interp : array-like
        The interpolated variable (or its derivative) at the timesteps
        `ts_out`.
    """

    # NOTE(review): the inner function keeps its original name and parameter
    # names in case create_and_register_jax inspects them; confirm.
    def interpolator(ts_out, y):
        interp = interpolation_func(ts_in, y, method)
        query = interp.derivative if ret_gradients else interp.evaluate
        return query(ts_out)

    output_type = TensorType(dtype="float64", shape=(len(ts_out),))
    interpolator_op = create_and_register_jax(
        interpolator,
        output_types=[output_type],
    )
    return interpolator_op(ts_out, y)