If you’ve ever tried to write object-oriented neural networks in pure JAX, you’ve noticed that JAX only supports functional programming. But we love to organize our code into logical blocks (like Linear(weight, bias)). Equinox is a library that does precisely two main things: it registers your classes as pytrees so JAX transformations can see inside them, and it filters out the non-array parts so those transformations don’t choke on them.
Let’s walk through how these two concepts bridge the gap between object-oriented programming and JAX’s functional requirements.
JAX transformations like jax.jit, jax.grad, and jax.vmap only operate on pytrees: nested data structures such as lists, tuples, and dicts whose leaves are (typically) JAX arrays. If you pass a standard Python class instance into jax.jit, JAX won’t know how to look inside it to find the arrays, and it will crash.
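For instance, plain built-in containers are already pytrees, and jax.tree_util can enumerate their leaves without any registration:

```python
import jax.tree_util as jtu

# Built-in containers are pytrees out of the box.
params = {"w": [1.0, 2.0], "b": 3.0}

# tree_leaves walks the structure and returns the leaves
# (dict entries are visited in sorted-key order: "b" before "w").
leaves = jtu.tree_leaves(params)
print(leaves)  # [3.0, 1.0, 2.0]
```

A custom class gets none of this for free: JAX sees it as a single opaque leaf until it is registered as a pytree node.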
We can force JAX to understand our custom objects by registering them as pytrees. Equinox automates this entirely.
Instead of writing boilerplate flatten and unflatten methods for each model, Equinox uses a metaclass to automate this. By turning your classes into dataclasses, it knows exactly which fields exist; it then registers the class with jax.tree_util.register_pytree_node.
import dataclasses
import jax.tree_util as jtu

class ModuleMeta(type):
    def __new__(mcs, name, bases, namespace):
        cls = super().__new__(mcs, name, bases, namespace)
        # Automatically turn it into a frozen dataclass.
        cls = dataclasses.dataclass(frozen=True)(cls)

        # Tell JAX how to unpack (flatten) and repack (unflatten) the object.
        def flatten(obj):
            fields = dataclasses.fields(obj)
            dynamic_values = tuple(getattr(obj, f.name) for f in fields)
            aux_data = tuple(f.name for f in fields)
            return dynamic_values, aux_data

        def unflatten(aux_data, dynamic_values):
            obj = object.__new__(cls)
            for field_name, value in zip(aux_data, dynamic_values):
                object.__setattr__(obj, field_name, value)
            return obj

        # Register it with JAX.
        jtu.register_pytree_node(cls, flatten, unflatten)
        return cls

class MyBaseModule(metaclass=ModuleMeta):
    pass
Now you can pass an entire model instance (like model = Linear(…)) directly into jax.jit or jax.grad. JAX will seamlessly unpack the object, process the arrays inside it, and pack it back up.
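To see the end result without the metaclass machinery, here is a hand-registered sketch (a hypothetical Linear, not Equinox’s actual class) doing exactly what the metaclass automates:

```python
import dataclasses
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

@dataclasses.dataclass(frozen=True)
class Linear:
    weight: jax.Array
    bias: jax.Array

# Register by hand: this is the step ModuleMeta performs for you.
jtu.register_pytree_node(
    Linear,
    lambda m: ((m.weight, m.bias), None),     # flatten
    lambda aux, children: Linear(*children),  # unflatten
)

@jax.jit
def predict(model, x):
    return model.weight @ x + model.bias

model = Linear(weight=jnp.eye(2), bias=jnp.ones(2))
y = predict(model, jnp.array([1.0, 2.0]))  # works: JAX sees inside the object
print(y)  # [2. 3.]
```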
What if your model contains things that aren’t arrays? For instance, what if your linear layer stores its activation function (like jax.nn.relu) or a string like name = "layer1"?
When you pass this pytree into jax.jit, JAX tries to trace everything. It will treat your activation function as a dynamic array and crash, because functions aren’t valid XLA types. Normally you could fix this with static_argnums, but tracking which nested attribute is static in a massive model is a nightmare.
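For a single top-level argument, the stock escape hatch looks like this; the pain is that static_argnums is positional, so it cannot reach an attribute buried deep inside a model:

```python
import functools
import jax
import jax.numpy as jnp

# Mark argument 1 (the activation function) as static by position.
@functools.partial(jax.jit, static_argnums=(1,))
def apply(x, activation):
    return activation(x)

out = apply(jnp.array([-1.0, 2.0]), jax.nn.relu)
print(out)  # [0. 2.]
```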
Equinox intercepts your data right before it hits JAX. It iterates through the entire pytree and uses a filter to split it into two parallel trees: a dynamic tree holding the arrays (to be traced) and a static tree holding everything else (to be treated as compile-time constants):
import functools
import jax
import jax.tree_util as jtu

def filter_jit(fun):
    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        # Split the inputs into arrays (dynamic) and everything else (static).
        dynamic, static = partition((args, kwargs), is_array)

        # Flatten the static part to make it hashable for JAX.
        static_leaves, static_treedef = jtu.tree_flatten(static)

        # Tell JAX to trace the dynamic part but treat the static part as constants.
        @functools.partial(jax.jit, static_argnums=(1, 2))
        def compiled_fun(_dynamic, _static_leaves, _static_treedef):
            _static = jtu.tree_unflatten(_static_treedef, _static_leaves)
            _args, _kwargs = combine(_dynamic, _static)
            return fun(*_args, **_kwargs)

        return compiled_fun(dynamic, tuple(static_leaves), static_treedef)
    return wrapper
You never have to worry about static_argnums again. You can drop strings, booleans, and functions anywhere inside your model; when you apply @filter_jit, it automatically sorts out what needs to be traced on the GPU and what needs to be treated as a static constant.
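The wrapper above assumes three helpers: partition, combine, and is_array (Equinox ships real versions as eqx.partition, eqx.combine, and eqx.is_array). Here is a minimal sketch of their behavior, using None as the placeholder for filtered-out leaves (Equinox uses a dedicated sentinel instead):

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

def is_array(x):
    # True for JAX arrays, False for functions, strings, bools, etc.
    return isinstance(x, jax.Array)

def partition(tree, pred):
    # Two trees of identical shape: leaves failing the predicate are
    # replaced by None in one tree and kept in the other.
    dynamic = jtu.tree_map(lambda x: x if pred(x) else None, tree)
    static = jtu.tree_map(lambda x: None if pred(x) else x, tree)
    return dynamic, static

def combine(dynamic, static):
    # Merge the two halves back together, preferring the dynamic leaf.
    return jtu.tree_map(
        lambda d, s: s if d is None else d,
        dynamic, static,
        is_leaf=lambda x: x is None,
    )

tree = {"w": jnp.ones(2), "name": "layer1"}
dyn, stat = partition(tree, is_array)
merged = combine(dyn, stat)  # round-trips back to the original tree
```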
By combining these two ideas, Equinox lets you write Pythonic classes that remain fully compatible with JAX’s transformations.
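As a closing sketch of both ideas together, here is a hand-registered module (again hypothetical, not Equinox’s real eqx.Module) that keeps its arrays as pytree leaves while stashing its non-array field in aux_data, so plain jax.jit treats the function as static automatically:

```python
import dataclasses
from typing import Callable

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

@dataclasses.dataclass(frozen=True)
class Linear:
    weight: jax.Array
    bias: jax.Array
    activation: Callable  # static: not traced by JAX

def _flatten(m):
    # Arrays become leaves; the function rides along as static aux_data.
    return (m.weight, m.bias), m.activation

def _unflatten(activation, children):
    return Linear(*children, activation)

jtu.register_pytree_node(Linear, _flatten, _unflatten)

@jax.jit
def forward(model, x):
    return model.activation(model.weight @ x + model.bias)

model = Linear(jnp.eye(2), -jnp.ones(2), jax.nn.relu)
y = forward(model, jnp.array([2.0, 0.5]))
print(y)  # [1. 0.]
```

Equinox generalizes this pattern: eqx.Module fields marked static live in aux_data, and filtered transformations like eqx.filter_jit handle the splitting for arbitrary inputs.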