Differential equation solvers are the heart of scientific ML and neural ODEs. But if you’ve ever tried to write a solver in JAX and then call jax.grad on it, you’ve likely run into trouble. Let’s walk through the evolution of a solver, from a basic loop to more advanced versions, to understand how Diffrax really works.

1. The Simple Euler Loop

The simplest way to solve an ODE $dy/dt = f(t, y)$ is Euler’s method: just take small steps of size $dt$ along the vector field. In JAX, you can write this with jax.lax.while_loop.

def euler(y0, t0, t1, dt):
    def cond(state):
        t, y = state
        return t < t1
    def body(state):
        t, y = state
        # Euler step: y_next = y + dt * f(t, y)
        y_next = y + dt * vector_field(t, y)
        return t + dt, y_next
    final_t, final_y = jax.lax.while_loop(cond, body, (t0, y0))
    return final_y

The problem is that if you run jax.grad on this, JAX will crash. Why? Reverse-mode autodiff needs to store the state of every single step so it can replay them when calculating gradients later. Since a while_loop can run for an unknown number of steps, JAX has no way to pre-allocate that memory, so it simply refuses to backpropagate through the loop.
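You can see the failure for yourself. Here is a self-contained version of the loop above (the toy vector field $dy/dt = -y$ is my own stand-in, just for illustration) — the forward pass runs, but jax.grad raises an error:

```python
import jax

def vector_field(t, y):
    # toy vector field dy/dt = -y (a stand-in, just for illustration)
    return -y

def euler(y0, t0, t1, dt):
    def cond(state):
        t, y = state
        return t < t1

    def body(state):
        t, y = state
        return t + dt, y + dt * vector_field(t, y)

    _, final_y = jax.lax.while_loop(cond, body, (t0, y0))
    return final_y

# The forward pass works fine (dt = 0.25 is exact in binary, so exactly 4 steps):
y_final = euler(1.0, 0.0, 1.0, 0.25)  # 4 steps of y <- 0.75 * y

# ...but reverse-mode autodiff does not:
try:
    jax.grad(euler)(1.0, 0.0, 1.0, 0.25)
    grad_failed = False
except Exception:
    grad_failed = True
```

Note that forward-mode differentiation (jax.jacfwd) would actually work here; it is specifically the reverse pass that needs the pre-allocated history.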

2. The “Bounded Tree” for “fixed” steps

The bounded-tree structure solves the unknown-number-of-steps problem by introducing a fixed binary tree. Once you define a max_steps bound for the solver, the tree size is static, which lets XLA compile it successfully. The trick is that each node carries a jax.lax.cond that lets the tree dynamically “skip” execution once the final time has been reached.

Here is how the bounded tree is defined. At each node, we split the remaining timeline into two halves:

def recursive_solve(state, step_start, step_end):
    t, y = state
    def _run_tree(state):
        if step_end - step_start == 1:
            return take_euler_step(state)
        mid = (step_start + step_end) // 2

        # Save only the state at mid (O(log N) memory)!
        mid_state = jax.checkpoint(lambda s: recursive_solve(s, step_start, mid))(state)
        return recursive_solve(mid_state, mid, step_end)
    # Skip the whole subtree once we have passed t1
    return jax.lax.cond(t < t1, _run_tree, lambda s: s, state)

Note that jax.checkpoint is used on the left half. It tells JAX not to save any internal activations for that branch, only its inputs, and to recompute the internals during backpropagation (under the hood it is a customized VJP). This saves a lot of memory, especially when the vector field contains MLPs: 10 MB per step × 1000 steps = 10 GB, versus 10 MB × $\log_2(1000)$ ≈ 100 MB.
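If you haven’t used jax.checkpoint before, here is a tiny standalone demonstration of the trade-off. The mlp_like function is my own stand-in for an expensive vector field; both gradients are mathematically identical, they just differ in what gets stored versus recomputed:

```python
import jax
import jax.numpy as jnp

def mlp_like(y):
    # stand-in for an expensive vector field, e.g. an MLP (an assumption,
    # not part of the solver above)
    for _ in range(3):
        y = jnp.tanh(y)
    return y

loss_plain = lambda y: mlp_like(y).sum()
loss_ckpt = lambda y: jax.checkpoint(mlp_like)(y).sum()

y = jnp.ones(4)
g_plain = jax.grad(loss_plain)(y)  # stores every tanh activation for backprop
g_ckpt = jax.grad(loss_ckpt)(y)    # stores only y; recomputes the tanhs backward
```

The two gradients agree exactly; checkpointing trades extra forward compute for a smaller saved-activation footprint.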

The second and more important note is the jax.lax.cond(t < t1, …) in the return statement. It means that once $t \geq t_1$, the recursive solve skips the entire subtree and just returns the current state.
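Putting the pieces together, here is a minimal runnable sketch of the bounded tree. The names (solve_bounded, the toy vector field) are my own, not Diffrax’s, and real Diffrax handles pytree states and adaptive steps — but the checkpointed tree and the skip-cond are the same idea:

```python
import jax

def vector_field(t, y):
    return -y  # toy vector field (an assumption, just for illustration)

def solve_bounded(y0, t0, t1, dt, max_steps=8):
    def euler_step(state):
        t, y = state
        return t + dt, y + dt * vector_field(t, y)

    def recursive_solve(state, lo, hi):
        if hi - lo == 1:
            run = euler_step  # leaf: a single Euler step
        else:
            mid = (lo + hi) // 2
            def run(s):
                # checkpoint the left half: store only its input and
                # recompute its internals during the backward pass
                s = jax.checkpoint(lambda s_: recursive_solve(s_, lo, mid))(s)
                return recursive_solve(s, mid, hi)
        t, _ = state
        # skip the whole subtree once t has reached t1
        return jax.lax.cond(t < t1, run, lambda s: s, state)

    _, y = recursive_solve((t0, y0), 0, max_steps)
    return y

# Reverse-mode autodiff now works: 3 of the 8 budgeted steps are taken
# (dt = 0.25, t1 = 0.625), the remaining leaves are skipped by the cond.
grad_y0 = jax.grad(solve_bounded)(1.0, 0.0, 0.625, 0.25)
```

Because the tree has a static max_steps = 8 leaves, XLA can compile it, and the cond at each node makes the untaken steps act as cheap identity functions in both the forward and backward passes.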

3. The Online Milestone Strategy

But what if your ODE needs a million steps? A tree that deep would be too slow to compile. This is where Diffrax’s default adjoint method (RecursiveCheckpointAdjoint) shines: it uses online checkpointing.

Briefly: in the forward pass, it runs a genuinely dynamic while_loop and, every k steps, saves the state into a milestone buffer. In the backward pass, it manually loops backwards through the milestones: it jumps to the last milestone, re-runs just those k steps forward to get the local gradients, then jumps back to the previous milestone. This way we save a lot of memory by keeping only a few checkpoints. It’s a more advanced topic and I’ll save it for a later post!
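Still, the core milestone idea fits in a few lines of plain JAX. This is my own simplified illustration, not Diffrax’s actual implementation (Diffrax does the backward walk inside a compiled loop, and its milestone schedule is cleverer); the function names here are hypothetical:

```python
import jax
import jax.numpy as jnp

dt = 0.25  # fixed step size (exactly representable in binary)

def f(t, y):
    return -y  # toy vector field (an assumption, just for illustration)

def forward_with_milestones(y0, n_steps, k):
    """Run the solver forward, saving every k-th state as a milestone."""
    milestones = [y0]
    t, y = 0.0, y0
    for i in range(n_steps):
        y = y + dt * f(t, y)
        t = t + dt
        if (i + 1) % k == 0 and (i + 1) < n_steps:
            milestones.append(y)
    return y, milestones

def backward_from_milestones(milestones, n_steps, k, y_bar):
    """Walk the milestones in reverse: recompute each k-step segment
    forward, then pull the cotangent back through it with jax.vjp."""
    for seg in reversed(range(len(milestones))):
        start = milestones[seg]
        length = min(k, n_steps - seg * k)  # last segment may be shorter
        t0 = seg * k * dt

        def run_segment(y):
            t = t0
            for _ in range(length):
                y = y + dt * f(t, y)
                t = t + dt
            return y

        _, pullback = jax.vjp(run_segment, start)
        (y_bar,) = pullback(y_bar)
    return y_bar  # d(y_final) / d(y0)

y0 = jnp.array(1.0)
y_final, milestones = forward_with_milestones(y0, n_steps=10, k=3)
g = backward_from_milestones(milestones, n_steps=10, k=3, y_bar=jnp.array(1.0))
```

Here 10 steps are covered by only 4 stored states (y0 plus milestones at steps 3, 6, and 9), and each segment’s activations exist only transiently while its local VJP is being evaluated.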

Now you know what’s under the hood of Diffrax’s solvers! Happy modeling 😇