"""Convert a jax function to a pytensor compatible function."""
import inspect
import logging
from collections.abc import Sequence
import jax
# import jax.tree
import jax.numpy as jnp
import numpy as np
import pytensor.scalar as ps
import pytensor.tensor as pt
from jax.tree_util import tree_flatten, tree_leaves, tree_map, tree_unflatten
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
# Logger for this module; records are emitted under the module's import name.
log = logging.getLogger(name=__name__)
class _Shape(tuple):
pass
def jax2pytensor(
    jaxfunc,
    output_shape_def=None,
    args_for_graph="all",
    name=None,
    static_argnames=(),
    dtype=None,
):
    """Wrap a jax jittable function so it can be called on pytensor variables.

    The returned function has the same call signature as ``jaxfunc``. Calling it
    builds a pytensor Op node that evaluates ``jaxfunc``, together with a
    vector-Jacobian-product Op used for gradients. A unique ``name`` should be
    passed in case the name of ``jaxfunc`` is identical to that of some other
    node. The design of this function is based on
    https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/

    Parameters
    ----------
    jaxfunc : jax jittable function
        Function for which the node is created. It can return multiple tensors
        as a (nested) tuple. All return values must be convertible to
        ``pytensor.TensorVariable``.
    output_shape_def : function, optional
        Function returning the shape(s) of the output(s). It receives the shapes
        of the ``args_for_graph`` arguments as keyword arguments. Each shape is a
        tuple of ints or None, nested like the argument itself. It must return a
        shape tuple, or a tree of shape tuples for several outputs. If None, the
        output types are inferred by tracing ``jaxfunc`` with
        ``jax.make_jaxpr``; this requires every input dimension to be known.
    args_for_graph : list of str or "all"
        If "all", all arguments of the call are used for the graph, including
        extra keyword arguments that ``jaxfunc`` collects through ``**kwargs``.
        Otherwise, specify a list of argument names to use for the graph. The
        remaining arguments are passed to ``jaxfunc`` unchanged.
    name : str, optional
        Name of the created pytensor Op; defaults to ``jaxfunc.__name__``. Only
        used internally in the pytensor graph.
    static_argnames : tuple of str
        Forwarded to ``jax.jit`` as ``static_argnames``. It is also converted to
        positional indices and passed to ``jax.make_jaxpr`` as
        ``static_argnums``.
        NOTE(review): both are applied to an internal wrapper that takes a
        single list argument, so these names/indices likely do not match it;
        confirm before relying on this parameter.
    dtype : str, optional
        dtype the graph inputs are cast to. It also applies to the outputs when
        ``output_shape_def`` is given. If None, it is inferred separately for
        each call, as the upcast of all input dtypes.

    Returns
    -------
    function
        A function with the same signature as ``jaxfunc`` that returns pytensor
        variables, structured like the output of ``jaxfunc``. The resulting
        graph is differentiable. It can run on the JAX backend (through
        ``jax_funcify``) and on the default backend (through ``Op.perform``).

    Raises
    ------
    RuntimeError
        Raised by the returned function if ``output_shape_def`` is None and a
        dimension of an input is undefined.
    """

    # Construct the function to return. It is compatible with pytensor but has
    # the same signature as the jax function.
    def new_func(*args, **kwargs):
        func_signature = inspect.signature(jaxfunc)
        inputs_list, inputnames_list, other_args_dic = _split_arguments(
            func_signature, args, kwargs, args_for_graph
        )

        # Resolve the dtype for this call only. It is deliberately not written
        # back to the enclosing ``dtype``, so that the dtype inferred for one
        # call does not leak into later calls with differently typed inputs.
        call_dtype = dtype
        if call_dtype is None:
            # Inputs may be plain Python/numpy values, so convert them to
            # tensors first just to read their dtypes. All inputs are then cast
            # to the common upcast dtype. This avoids errors when
            # differentiating with respect to integer-typed inputs.
            tmp_leaves = tree_leaves(tree_map(pt.as_tensor_variable, inputs_list))
            call_dtype = ps.upcast(*[leaf.type.dtype for leaf in tmp_leaves])

        # Convert the inputs to symbolic variables of the common dtype.
        inputs_list = tree_map(
            lambda x: pt.as_tensor_variable(x, dtype=call_dtype), inputs_list
        )
        inputs_flat, input_treedef = tree_flatten(inputs_list)
        input_types_flat = [inp.type for inp in inputs_flat]

        def conv_input_to_jax(func, flatten_output=True):
            """Wrap ``func`` so that it takes the flat list of graph inputs.

            The wrapper rebuilds the input tree and converts its leaves to jax
            arrays. It then calls ``func`` with the graph inputs and the
            remaining (non-graph) arguments as keyword arguments. With
            ``flatten_output``, the result tree is flattened and returned as a
            tuple, or as the single leaf when there is only one.
            """

            def flat_func(inputs_list_flat):
                inputs_for_graph = tree_unflatten(input_treedef, inputs_list_flat)
                inputs_for_graph = tree_map(lambda x: jnp.array(x), inputs_for_graph)
                inputs_for_graph_dic = dict(zip(inputnames_list, inputs_for_graph))
                results = func(**inputs_for_graph_dic, **other_args_dic)
                if not flatten_output:
                    return results
                results_flat, _ = tree_flatten(results)
                if len(results_flat) > 1:
                    # Use a tuple: jax distinguishes tuples from lists, while
                    # pytensor does not.
                    return tuple(results_flat)
                return results_flat[0]

            return flat_func

        jitted_sol_op_jax = jax.jit(
            conv_input_to_jax(jaxfunc),
            static_argnames=static_argnames,
        )
        # Convert static_argnames to static_argnums because make_jaxpr requires it.
        static_argnums = tuple(
            i
            for i, (k, param) in enumerate(func_signature.parameters.items())
            if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
            and k in static_argnames
        )

        # Without output_shape_def, the output shapes are inferred by tracing,
        # which needs every input dimension to be known.
        if output_shape_def is None:
            for input_var, inputname in zip(inputs_list, inputnames_list):
                for i, input_elem in enumerate(tree_leaves(input_var)):
                    elem_shape = input_elem.type.shape
                    if None in elem_shape:
                        raise RuntimeError(
                            f"A dimension of input {inputname}, element {i} is "
                            f"undefined: {elem_shape} You need to provide the "
                            f"output_shape_def function to define the shape of the "
                            f"output, or set the shape of the tensor to a integer with "
                            f"input_var.type.shape = (10,) for example."
                        )

        if output_shape_def is not None:
            # Wrap the shape tuples in _Shape objects, so that a tree_map on
            # such an object does not interpret the tuple as part of the tree.
            input_shapes_flat = [_Shape(t.shape) for t in input_types_flat]
            input_shapes_list = tree_unflatten(input_treedef, input_shapes_flat)
            input_shapes_dic = dict(zip(inputnames_list, input_shapes_list))
            output_shape = output_shape_def(**input_shapes_dic)

            def _is_shape_leaf(x):
                # A shape tuple (empty, or starting with None or an integer)
                # is a leaf, so it is not flattened into its entries.
                return isinstance(x, Sequence) and (
                    len(x) == 0
                    or x[0] is None
                    or isinstance(x[0], (int, np.integer))
                )

            output_shapes_flat, output_treedef = tree_flatten(
                output_shape, is_leaf=_is_shape_leaf
            )
            if len(output_shapes_flat) == 0 or not isinstance(
                output_shapes_flat[0], Sequence
            ):
                output_shapes_flat = (output_shapes_flat,)
            output_types = [
                pt.type.TensorType(dtype=call_dtype, shape=shape)
                for shape in output_shapes_flat
            ]
        else:
            # Infer the output shapes and types by tracing the jax function.
            # Only type and shape information is needed, so the pytensor input
            # types are passed instead of concrete values.
            _, output_shape_jax = jax.make_jaxpr(
                conv_input_to_jax(jaxfunc, flatten_output=False),
                static_argnums=static_argnums,
                return_shape=True,
            )(tree_map(lambda x: x.type, inputs_flat))
            out_shape_jax_flat, output_treedef = tree_flatten(output_shape_jax)
            output_types = [
                pt.TensorType(dtype=var.dtype, shape=var.shape)
                for var in out_shape_jax_flat
            ]

        # Number of output gradients appended after the primal inputs in the
        # argument tuple of the VJP function.
        len_gz = len(output_types)

        def vjp_sol_op_jax(vjp_args):
            """Vector-Jacobian product of the forward Op.

            ``vjp_args`` holds the primal inputs followed by one gradient per
            output.
            """
            y0 = vjp_args[:-len_gz]
            gz = vjp_args[-len_gz:]
            if len(gz) == 1:
                gz = gz[0]
            # ``local_op`` is bound below, before this function is ever traced.
            primals, vjp_fn = jax.vjp(local_op.perform_jax, *y0)
            # Output gradients may be placeholders (e.g. scalar zeros for
            # disconnected outputs), so broadcast them to the primal shapes.
            gz = tree_map(
                lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)), gz, primals
            )
            if len(y0) == 1:
                return vjp_fn(gz)[0]
            return tuple(vjp_fn(gz))

        jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax)

        # Create the pytensor Op (and, lazily, its VJP Op).
        op_name = name if name is not None else jaxfunc.__name__
        SolOp, VJPSolOp = _return_pytensor_ops(op_name)
        local_op = SolOp(
            input_treedef,
            output_treedef,
            input_arg_names=inputnames_list,
            input_types=input_types_flat,
            output_types=output_types,
            jitted_sol_op_jax=jitted_sol_op_jax,
            jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
            other_args=other_args_dic,
        )

        @jax_funcify.register(SolOp)
        def sol_op_jax_funcify(op, **kwargs):
            return local_op.perform_jax

        @jax_funcify.register(VJPSolOp)
        def vjp_sol_op_jax_funcify(op, **kwargs):
            return local_op.vjp_sol_op.perform_jax

        # Evaluate the pytensor Op and return the unflattened results.
        output_flat = local_op(*inputs_flat)
        if not isinstance(output_flat, Sequence):
            output_flat = [output_flat]  # tree_unflatten expects a sequence.
        return tree_unflatten(output_treedef, output_flat)

    return new_func
def _split_arguments(func_signature, args, kwargs, args_for_graph):
for key in kwargs.keys():
if key not in func_signature.parameters and key in args_for_graph:
raise RuntimeError(
f"Keyword argument <{key}> not found in function signature. "
f"**kwargs are not supported in the definition of the function,"
f"because the order is not guaranteed."
)
arguments_bound = func_signature.bind(*args, **kwargs)
arguments_bound.apply_defaults()
# Check whether there exist an used **kwargs in the function signature
for arg_name in arguments_bound.signature.parameters:
if (
arguments_bound.signature.parameters[arg_name]
== inspect._ParameterKind.VAR_KEYWORD
):
var_keyword = arg_name
else:
var_keyword = None
arg_names = [
key for key in arguments_bound.arguments.keys() if not key == var_keyword
]
arg_names_from_kwargs = [key for key in arguments_bound.kwargs.keys()]
if args_for_graph == "all":
args_for_graph_from_args = arg_names
args_for_graph_from_kwargs = arg_names_from_kwargs
else:
args_for_graph_from_args = []
args_for_graph_from_kwargs = []
for arg in args_for_graph:
if arg in arg_names:
args_for_graph_from_args.append(arg)
elif arg in arg_names_from_kwargs:
args_for_graph_from_kwargs.append(arg)
else:
raise ValueError(f"Argument {arg} not found in the function signature.")
inputs_for_graph_list = [
arguments_bound.arguments[arg] for arg in args_for_graph_from_args
]
inputs_for_graph_list += [
arguments_bound.kwargs[arg] for arg in args_for_graph_from_kwargs
]
inputnames_for_graph_list = args_for_graph_from_args + args_for_graph_from_kwargs
other_args_dic = {
arg: arguments_bound.arguments[arg]
for arg in arg_names
if arg not in inputnames_for_graph_list
}
other_args_dic.update(
**{
arg: arguments_bound.kwargs[arg]
for arg in arg_names_from_kwargs
if arg not in inputnames_for_graph_list
}
)
return (
inputs_for_graph_list,
inputnames_for_graph_list,
other_args_dic,
)
def _return_pytensor_ops(name):
    """Create fresh pytensor ``Op`` classes for one wrapped jax function.

    New classes are created on every call. This lets each wrapped function
    register its own ``jax_funcify`` implementation for its class, and gives the
    Ops a descriptive class name in the graph.

    Parameters
    ----------
    name : str
        Class name of the forward Op; the VJP Op class is named ``"VJP_" + name``.

    Returns
    -------
    tuple of (type, type)
        ``(SolOp, VJPSolOp)``: the forward Op class and its vector-Jacobian
        product Op class.
    """

    class SolOp(Op):
        """Op evaluating the jitted jax function on the flattened graph inputs."""

        def __init__(
            self,
            input_treedef,
            output_treeedef,  # (sic) keyword name kept for backward compatibility
            input_arg_names,
            input_types,
            output_types,
            jitted_sol_op_jax,
            jitted_vjp_sol_op_jax,
            other_args,
        ):
            # The VJP Op is created in make_node.
            self.vjp_sol_op = None
            self.input_treedef = input_treedef
            self.output_treedef = output_treeedef
            self.input_arg_names = input_arg_names
            self.input_types = input_types
            self.output_types = output_types
            self.jitted_sol_op_jax = jitted_sol_op_jax
            self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
            self.other_args = other_args

        def make_node(self, *inputs):
            self.num_inputs = len(inputs)
            # One fresh output variable per declared output type.
            outputs = [
                pt.as_tensor_variable(out_type()) for out_type in self.output_types
            ]
            self.num_outputs = len(outputs)
            # The VJP Op shares this Op's input types and the jitted VJP function.
            self.vjp_sol_op = VJPSolOp(
                self.input_treedef,
                self.input_types,
                self.jitted_vjp_sol_op_jax,
                self.other_args,
            )
            return Apply(self, inputs, outputs)

        def perform(self, node, inputs, outputs):
            results = self.jitted_sol_op_jax(inputs)
            if self.num_outputs > 1:
                for i, out_type in enumerate(self.output_types):
                    outputs[i][0] = np.array(results[i], out_type.dtype)
            else:
                outputs[0][0] = np.array(results, self.output_types[0].dtype)

        def perform_jax(self, *inputs):
            return self.jitted_sol_op_jax(inputs)

        def grad(self, inputs, output_gradients):
            # Work on a copy so the list handed in by pytensor is not mutated.
            output_gradients = list(output_gradients)
            # An output that is not used downstream is disconnected and has no
            # gradient; substitute zeros so the VJP gets one gradient per output.
            for i, out_type in enumerate(self.output_types):
                if isinstance(output_gradients[i].type, DisconnectedType):
                    if None not in out_type.shape:
                        output_gradients[i] = pt.zeros(out_type.shape, out_type.dtype)
                    else:
                        # Static shape unknown: use a scalar zero. The VJP
                        # function built in jax2pytensor broadcasts it to the
                        # primal output's shape.
                        output_gradients[i] = pt.zeros((), out_type.dtype)
            result = self.vjp_sol_op(inputs, output_gradients)
            if self.num_inputs > 1:
                return result
            return (result,)  # pytensor requires a tuple here

    class VJPSolOp(Op):
        """Op computing the vector-Jacobian product of the wrapped function."""

        def __init__(
            self, input_treedef, input_types, jitted_vjp_sol_op_jax, other_args
        ):
            self.input_treedef = input_treedef
            self.input_types = input_types
            self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
            self.other_args = other_args

        def make_node(self, y0, gz):
            # Primal inputs, cast to the dtypes of the forward Op's inputs.
            y0 = [
                pt.as_tensor_variable(y).astype(self.input_types[i].dtype)
                for i, y in enumerate(y0)
            ]
            gz_connected = [
                pt.as_tensor_variable(g)
                for g in gz
                if not isinstance(g.type, DisconnectedType)
            ]
            # One gradient output per primal input, with that input's type.
            outputs = [in_type() for in_type in self.input_types]
            self.num_outputs = len(outputs)
            return Apply(self, y0 + gz_connected, outputs)

        def perform(self, node, inputs, outputs):
            results = self.jitted_vjp_sol_op_jax(tuple(inputs))
            if len(self.input_types) > 1:
                for i, result in enumerate(results):
                    outputs[i][0] = np.array(result, self.input_types[i].dtype)
            else:
                outputs[0][0] = np.array(results, self.input_types[0].dtype)

        def perform_jax(self, *inputs):
            results = self.jitted_vjp_sol_op_jax(tuple(inputs))
            if self.num_outputs == 1:
                if isinstance(results, Sequence):
                    return results[0]
                return results
            return tuple(results)

    # Rename the classes so that graphs and debug output show the wrapped function.
    SolOp.__name__ = name
    SolOp.__qualname__ = ".".join(SolOp.__qualname__.split(".")[:-1] + [name])
    VJPSolOp.__name__ = "VJP_" + name
    VJPSolOp.__qualname__ = ".".join(
        VJPSolOp.__qualname__.split(".")[:-1] + ["VJP_" + name]
    )
    return SolOp, VJPSolOp