A 67x speedup by understanding what happens between your kernels
If you’ve ever profiled a JAX program on GPU, you might have been puzzled by something:
“My data is already on the GPU. There’s no CPU↔︎GPU data transfer. Why is it still slow?”
“The HLO while loop doesn't fuse into a single kernel. XLA generates separate kernels for each operation inside the loop body, and each kernel launch has fixed overhead regardless of how tiny the actual compute is.” (from a real-world problem encountered while using lax.scan)
The answer lies in a cost most people never think about: kernel launch overhead.
When JAX executes your Python code, it doesn’t run the math directly. Instead, it traces your function into a computation graph (an HLO/XLA program), and then the XLA compiler turns each operation into one or more GPU kernels — small programs that run on the GPU’s thousands of cores.
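You can see this tracing for yourself by asking JAX for the program it hands to XLA. A minimal sketch (the function `step` here is just an illustrative placeholder, not the article's workload):

```python
import jax
import jax.numpy as jnp

def step(x):
    # A few elementwise ops; each may become its own GPU kernel
    # unless XLA fuses them together.
    return jnp.sin(x) * 2.0 + jnp.cos(x)

x = jnp.ones((1024,))

# The jaxpr is JAX's traced representation of the Python function...
print(jax.make_jaxpr(step)(x))

# ...and lowering shows the HLO/StableHLO program XLA compiles into kernels.
print(jax.jit(step).lower(x).as_text())
```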
Here’s the catch: each kernel typically has to be launched from the CPU. Even though the data stays on the GPU, the CPU still has to prepare each kernel’s launch parameters and submit the launch through the driver before the GPU can run it.
Even when using tools like CUDA Graphs (where the CPU pre-records the sequence of kernels so it’s not involved in every single step), the GPU itself still has to transition from one kernel to the next. Each kernel must run, finish its work, synchronize globally across all its blocks, and write its results back to the GPU’s main global memory (HBM). Only then can the next kernel start.
Each of these kernel transitions takes roughly 5–10 microseconds. That may sound tiny, but it adds up fast.
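One rough way to feel this cost is to time many launches of a kernel whose compute is essentially free. This is only a sketch: the measured number includes Python dispatch as well as launch and transition overhead, so treat it as an upper bound, and the array size and iteration count are arbitrary choices.

```python
import time
import jax
import jax.numpy as jnp

x = jnp.ones((16,))                       # tiny array: the math itself is negligible
f = jax.jit(lambda v: v * 1.0001 + 1e-6)

f(x).block_until_ready()                  # compile and warm up once

t0 = time.perf_counter()
y = x
for _ in range(1000):                     # 1000 dispatches of an almost-free kernel
    y = f(y)
y.block_until_ready()
elapsed = time.perf_counter() - t0
print(f"~{elapsed / 1000 * 1e6:.1f} µs per step (dispatch + launch overhead)")
```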
lax.scan and while_loop

jax.lax.scan (and jax.lax.while_loop) is the idiomatic way to write loops in JAX. Under the hood, XLA (JAX’s compiler) turns the loop body into a sequence of kernels. Each operation in the loop body often becomes its own kernel, and the loop itself repeatedly transitions between them.
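For concreteness, here is a minimal scan whose body contains several small operations; each of those operations can end up as its own kernel, and the loop transitions between them on every iteration. The Euler-style update below is just an illustrative placeholder, not the solver discussed in this article:

```python
import jax
import jax.numpy as jnp

def body(y, t):
    # Several small ops: each may compile to a separate kernel inside the loop.
    dy = -0.5 * y + jnp.sin(t)
    y_next = y + 0.01 * dy
    return y_next, y_next          # (new carry, per-step output)

y0 = jnp.ones((1024,))
ts = jnp.linspace(0.0, 1.0, 100)   # 100 time steps

final, trajectory = jax.lax.scan(body, y0, ts)
print(trajectory.shape)            # (100, 1024)
```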
For an ODE solver doing 100 time steps, this means: