Transform JAX to PyTensor
This function wraps a JAX function so that it can be used in a PyTensor graph. This makes it possible to build more complex models with PyMC, for instance models that include the solution of differential equations or of a nonlinear system of equations, because the many existing JAX libraries for such tasks can then be used inside the model.
API
- icomo.jax2pytensor(jaxfunc, name=None)
Return a PyTensor-compatible function from a JAX-jittable function.
This decorator transforms any JAX-jittable function into a function that accepts and returns pytensor.Variable objects. The JAX-jittable function can accept any nested Python structure (pytree) as input and return any nested Python structure.
The returned values must be convertible to PyTensor types. If the name of jaxfunc coincides with that of another node in the graph, a unique name should be passed via name. The design of this function is based on https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/
- Parameters:
jaxfunc (JAX-jittable function) – function for which the node is created; it can return multiple tensors as a tuple. All return values must be convertible to pytensor.Variable.
name (str) – Name of the created PyTensor Op; defaults to the name of the passed function. Only used internally in the PyTensor graph.
- Returns:
A function that can be used inside a pymc.Model. It is differentiable, and the resulting model can be compiled with either the default C backend or the JAX backend.
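A pytree is any nesting of Python containers such as dicts, lists and tuples with arrays as leaves. The sketch below assumes only `jax`. It shows how such a structure can be flattened into the flat list of arrays that a graph node works with, and rebuilt afterwards. `jaxfunc_flat` is a hypothetical helper that illustrates the general idea. It is not icomo's internal code.

```python
import jax
import jax.numpy as jnp

# A nested structure (pytree) of arrays, as jaxfunc may receive it
params = {"scale": jnp.array(2.0), "data": (jnp.ones(3), jnp.arange(3.0))}


def jaxfunc(p):
    # Operates on the nested input and returns a nested output
    return {"total": p["scale"] * jnp.sum(p["data"][0] + p["data"][1])}


# A graph node works with a flat list of inputs: flatten the pytree into
# its array leaves plus a treedef that records the nesting
leaves, treedef = jax.tree_util.tree_flatten(params)


def jaxfunc_flat(*flat_leaves):
    # Hypothetical helper: rebuild the nested input, call jaxfunc,
    # then flatten the nested output back into a list of arrays
    out = jaxfunc(jax.tree_util.tree_unflatten(treedef, flat_leaves))
    return jax.tree_util.tree_leaves(out)


flat_out = jaxfunc_flat(*leaves)
```

Dict entries are flattened in sorted key order, so here the leaves are the two `"data"` arrays followed by `"scale"`.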
Examples
A simple example of how the jax.numpy.sum() function can be used in a PyMC model:
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import pymc as pm
>>> import icomo
>>> import numpyro
>>> numpyro.set_host_device_count(4)
>>> with pm.Model() as model:
...     x = pm.Normal("input", mu=1, size=3)
...     sum_pt = icomo.jax2pytensor(jnp.sum)
...     sum_x = sum_pt(x)
...     obs = pm.Normal("obs", sum_x, observed=3)
>>> trace = pm.sample(model=model, nuts_sampler="numpyro")
>>> print(np.round(np.mean(trace.posterior["input"].to_numpy()), 1))
1.0
A more complex example: a model that includes the solution of a nonlinear equation, here the intersection of log(x*arg1) and 1/x, where arg1 is a model parameter:
>>> import optimistix
>>> with pm.Model() as model:
...     arg1 = pm.HalfNormal("arg1")
...     arg1 = pm.math.clip(arg1, 0.1, 10)
...     f1 = lambda x, arg1: jnp.log(x*arg1)
...     f2 = lambda x: 1/x
...     @icomo.jax2pytensor
...     def find_intersection(funcs, arg1):
...         f1, f2 = funcs["f1"], funcs["f2"]
...         loss = lambda x, _: (f1(x, arg1) - f2(x))**2
...         # Loss is squared to make it more convex
...         res = optimistix.minimise(fn=loss,
...                                   solver=optimistix.BFGS(rtol=1e-8, atol=1e-8),
...                                   y0=3.)
...         return res
...
...     intersection_res = find_intersection({"f1": f1, "f2": f2}, arg1)
...     intersection = pm.Deterministic("inters", intersection_res.value)
...     obs = pm.Normal("obs", mu=intersection, sigma=0.1, observed=3)
>>> trace = pm.sample(model=model, nuts_sampler="numpyro")
>>> print(f"Inters. = {np.round(np.mean(trace.posterior['inters'].to_numpy()), 1)}")
Inters. = 3.0
>>> print(f"Std. int. = {np.round(np.std(trace.posterior['inters'].to_numpy()), 1)}")
Std. int. = 0.1
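The observation inters ≈ 3 pins down arg1. At an intersection at x = 3, log(3·arg1) = 1/3, so arg1 = e^(1/3)/3 ≈ 0.465.

The pure-JAX sketch below checks this for that fixed arg1. It makes two substitutions. First, it swaps the BFGS minimisation of the squared loss for Newton's method on the residual log(x·arg1) − 1/x, via the hypothetical helper `newton_intersection`. Second, it differentiates the result with `jax.grad`. The derivative matches the implicit-function-theorem value −(1/arg1)/(1/x* + 1/x*²), which illustrates that the intersection is a smooth function of arg1. That smoothness is what a gradient-based sampler such as NUTS relies on.

```python
import math

import jax
import jax.numpy as jnp


def residual(x, arg1):
    # f1(x) - f2(x); zero exactly where log(x * arg1) and 1/x intersect
    return jnp.log(x * arg1) - 1.0 / x


def newton_intersection(arg1, x0=1.0, n_steps=20):
    # Hypothetical helper: a fixed number of Newton steps on the residual,
    # so the loop has a static trip count and is reverse-mode differentiable
    d_residual = jax.grad(residual, argnums=0)

    def step(_, x):
        return x - residual(x, arg1) / d_residual(x, arg1)

    x_init = jnp.full_like(arg1, x0)  # same dtype as arg1
    return jax.lax.fori_loop(0, n_steps, step, x_init)


# arg1 chosen so that the intersection lies at x = 3: log(3 * arg1) = 1/3
arg1 = jnp.array(math.exp(1.0 / 3.0) / 3.0, dtype=jnp.float32)
x_star = newton_intersection(arg1)

# Derivative of the solution with respect to arg1, taken through the Newton loop
dx_darg1 = jax.grad(newton_intersection)(arg1)

# Implicit function theorem: dx*/darg1 = -(dF/darg1) / (dF/dx) at F(x*, arg1) = 0
expected = -(1.0 / arg1) / (1.0 / x_star + 1.0 / x_star**2)
```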
jax2pytensor also wraps functions returned by transformed functions. These can then either be passed to another transformed function or evaluated directly to obtain a PyTensor result:
>>> with pm.Model() as model:
...     arg1 = pm.HalfNormal("arg1")
...     arg1 = pm.math.clip(arg1, 0.1, 10)
...     @icomo.jax2pytensor
...     def f1_creator(arg1):
...         def f_log(x):
...             return jnp.log(x*arg1)
...         return f_log
...     f1 = f1_creator(arg1)
...
...     f2 = lambda x: 1/x
...
...     @icomo.jax2pytensor
...     def find_intersection(f1, f2):
...         loss = lambda x, _: (f1(x) - f2(x))**2
...         res = optimistix.minimise(fn=loss,
...                                   solver=optimistix.BFGS(rtol=1e-8, atol=1e-8),
...                                   y0=3.)
...         return res
...
...     # Pass the wrapped f1 function to find_intersection; this creates a pytensor op
...     # that also includes arg1 as its input.
...     intersection_res = find_intersection(f1, f2)
...     intersection = pm.Deterministic("inters", intersection_res.value)
...
...     # Or evaluate the wrapped f1 directly
...     log_var = pm.Deterministic("log_var", f1(intersection))
...     obs = pm.Normal("obs", mu=intersection, sigma=0.1, observed=3)
>>> trace = pm.sample(model=model, nuts_sampler="numpyro")
>>> print(f"Inters. = {np.round(np.mean(trace.posterior['inters'].to_numpy()), 1)}")
Inters. = 3.0
>>> print(f"Std. int. = {np.round(np.std(trace.posterior['inters'].to_numpy()), 1)}")
Std. int. = 0.1
>>> print(f"log_var = {np.round(np.mean(trace.posterior['log_var'].to_numpy()), 1)}")
log_var = 0.3
>>> # It also works with the default sampler backend
>>> trc2 = pm.sample(model=model, cores=1, progressbar=False)
>>> print(f"log_var = {np.round(np.mean(trc2.posterior['log_var'].to_numpy()), 1)}")
log_var = 0.3
This feature of wrapping returned functions hasn’t been tested extensively, so please report any issues you might encounter.
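The documentation does not spell out how returned functions are wrapped. As an illustration of the underlying idea only, and explicitly an assumption rather than a description of icomo's internals, consider `jax.tree_util.Partial`. It turns a function plus a captured argument into a pytree whose leaf is that argument. A value like `arg1` can therefore travel along as an explicit input instead of hiding in a closure. `scaled_log`, `f1_creator` and `evaluate` are hypothetical names loosely mirroring the example above.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scaled_log(arg1, x):
    return jnp.log(x * arg1)


def f1_creator(arg1):
    # Unlike a plain closure, Partial stores arg1 as a pytree leaf,
    # so it stays visible as an input when f1 is passed on
    return Partial(scaled_log, arg1)


f1 = f1_creator(jnp.array(0.5))


@jax.jit
def evaluate(f, x):
    # f arrives as a pytree: its function is static, arg1 is a traced input
    return f(x)


leaves = jax.tree_util.tree_leaves(f1)
value = evaluate(f1, 4.0)  # log(4.0 * 0.5) = log(2)
```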
Notes
The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt, available at pymc-labs.io. To accept functions and non-PyTensor variables as input, the function makes use of equinox.partition() and equinox.combine() to split and combine the variables. Shapes are inferred using pytensor.compile.builders.infer_shape() and jax.eval_shape().