The Prospero Challenge: Artisanal JIT
23 April, 2025
I recently got effectively nerd-sniped by Matt Keeter's Prospero Challenge.
The reference implementation is a simple interpreter which walks through the instruction list, and uses NumPy to compute in bulk the values at each pixel location. Intermediate results (complete grids of values) are stored in memory for use by later instructions. NumPy takes advantage of SIMD for the actual calculations, but the end result is still fairly slow because the memory traffic is extremely high. Matt notes that this allocates 60 GB of RAM for intermediate results (but some basic optimization helps a lot with peak memory usage).
Since the data dependencies are within each pixel and not across pixels, my first thought was that we can at least improve cache usage a lot by iterating over pixels and computing the full expression, rather than iterating over instructions and computing the full grid. Of course, we can still use SIMD to compute multiple pixels at a time up to the supported width of our favorite SIMD target.
Even better, a quick analysis confirmed what I was hoping for: a significant fraction of intermediate results are used only once, and often shortly after they are computed. So in theory we can keep these in registers and never store them in memory at all!
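As a rough sketch of that analysis, assuming the file has already been parsed into a list where each instruction records the indices of the values it reads (names here are illustrative, not the actual code):

// Count how many instructions read each intermediate value.
let mut use_count = vec![0u32; instructions.len()];
for inst in &instructions {
    for &arg in &inst.args {
        use_count[arg] += 1;
    }
}
// Values read exactly once never need to outlive a register.
let single_use = use_count.iter().filter(|&&c| c == 1).count();
println!("{single_use} of {} values are used exactly once", instructions.len());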
However, computing "deep" slices ruins one advantage the reference implementation had on its side: the interpreter loop overhead really starts to matter if you need to run it hundreds of thousands of times. Although Python is notoriously slow, given that there are ~1 million pixels and ~8 thousand instructions, I suspect (but did not measure) that the interpreter overhead is negligible in the reference implementation.
If we can't interpret, what are our options? Other submissions tried all kinds of off-the-shelf high-powered JIT/compiler tools, like CUDA, LLVM, Cranelift, and others, including Matt's own Fidget system for implicit surfaces. Cranelift has a reputation for being easy to set up and use, but LLVM and CUDA... not so much. But even a friendly JIT like Cranelift seems pretty overpowered considering that we have to support 11 instructions and no control flow.
How hard can it be to just make our own JIT and output some simple AVX instructions?
The instructions
I don't currently have a machine with AVX-512, so let's look at what we can do with AVX2 for each instruction in prospero.vm.
Each 256-bit AVX register (ymm0-ymm15) can hold 8 single-precision floating-point numbers, so we'll evaluate 8 pixels at a time. For demonstration purposes we'll store the result in ymm0 in each case, with arguments in ymm1/ymm2.
var-x / var-y
Let's just assume these are already stored (aligned!) in memory, varying appropriately for the pixels we're computing:
vmovaps ymm0, [rdi] ; var-x: the 8 x coordinates for this batch of pixels
vmovaps ymm0, [rsi] ; var-y: the 8 y coordinates
const
We can't directly broadcast a constant to an AVX register; we need to first move it into an SSE register (xmm0), which occupies the bottom part of its corresponding AVX register.
Annoyingly, we also can't put a constant straight into the SSE register either.
mov eax, N              ; N = the constant's 32-bit float bit pattern
vmovd xmm0, eax         ; move it into the low lane of xmm0
vbroadcastss ymm0, xmm0 ; broadcast it to all 8 lanes of ymm0
This felt like a lot of instructions, and loading constants from memory seems to be the more typical approach. I got a minor 1.5% improvement by passing in an array with all the constants pre-broadcast and loading them directly, though it makes the compiler code somewhat more complex:
vmovaps ymm0, [r8 + OFFSET] ; load the pre-broadcast constant from the constant array
The binary operators
These are easy! Thankfully each binary operator corresponds to exactly one 3-parameter instruction:
vaddps ymm0, ymm1, ymm2
vsubps ymm0, ymm1, ymm2
vmulps ymm0, ymm1, ymm2
vmaxps ymm0, ymm1, ymm2
vminps ymm0, ymm1, ymm2
neg
I was kind of surprised that there is no native negation instruction. The Internet suggests a few different ways to implement it, but I chose the simple option of getting a zero register with XOR and then subtracting:
vxorps ymm0, ymm0, ymm0 ; zero out ymm0
vsubps ymm0, ymm0, ymm1 ; 0 - x
square
Multiply by itself, of course:
vmulps ymm0, ymm1, ymm1
sqrt
Thankfully also just one instruction:
vsqrtps ymm0, ymm1
Generating machine code
Generating the correct machine code directly for our desired instructions would not be terribly difficult but it would be incredibly tedious.
The Rust crate iced-x86 frees us from having to worry about encoding details.
use iced_x86::code_asm::*;
let mut asm = CodeAssembler::new(64).unwrap(); // 64-bit mode
asm.vmovaps(ymm0, ptr(rsi)).unwrap();
// ... emit the rest of the program ...
let code = asm.assemble(0).unwrap(); // raw machine-code bytes
Using the literal number 64 rather than an enum to indicate 64-bit mode is a bit WTF, but the library seems otherwise of decent quality.
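One detail the snippet above glosses over: the assembled bytes have to end up in executable memory before we can call them. A minimal sketch of that step, assuming the memmap2 crate; the kernel signature is hypothetical, and the real argument registers depend on the calling convention and on how the generated code expects its inputs:

use memmap2::MmapMut;

// `code` is the Vec<u8> returned by CodeAssembler::assemble above.
let mut mem = MmapMut::map_anon(code.len()).unwrap();
mem[..code.len()].copy_from_slice(&code);
let exec = mem.make_exec().unwrap(); // remap the pages as read + execute

// Hypothetical signature: pointers to the x values, y values, and constants.
type Kernel = unsafe extern "C" fn(*const f32, *const f32, *const f32);
let kernel: Kernel = unsafe { std::mem::transmute(exec.as_ptr()) };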
Register allocation
This is the main hard part, especially as there are far more intermediate values than registers, so spill choices ought to matter a lot. I wanted to see how well a really simple linear-scan allocator would do.
At a high level, the algorithm is pretty simple:
- Compute live ranges for each intermediate value
- Walk over all the instructions, and allocate parameter/output slots greedily
- If a parameter is already in a register, use that
- If a parameter has been spilled to memory, choose a register to load it to
- When we run out of registers, choose a register to spill based on some heuristic
The choice of heuristic is the interesting part. I went with the dead-simple approach of always spilling the value whose last use comes earliest. This isn't great, though, when a value is used only once far in the future: it never looks like a good spill candidate, so it holds onto a register the whole time. I also tried spilling the value with the latest next use, but it was more complex and performed slightly worse.
It's also possible to avoid using registers altogether in some cases as AVX can take some parameters from memory, but for simplicity I always loaded spilled values into a fresh register.
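In heavily simplified form, the per-operand step looks something like this (an illustrative sketch, not the actual code; it ignores spill-slot bookkeeping and the actual instruction emission):

// Pick a ymm register for `value` at instruction index `now`.
// `last_use[v]` is the index of the final instruction that reads value `v`.
fn alloc_register(
    value: usize,
    now: usize,
    last_use: &[usize],
    reg_contents: &mut [Option<usize>; 16],
) -> usize {
    // Already resident in a register: nothing to do.
    if let Some(r) = reg_contents.iter().position(|held| *held == Some(value)) {
        return r;
    }
    // Otherwise take a register that is empty or holds a dead value...
    if let Some(r) = reg_contents
        .iter()
        .position(|held| held.map_or(true, |v| last_use[v] < now))
    {
        reg_contents[r] = Some(value); // (emit a load here if `value` was spilled)
        return r;
    }
    // ...or evict the resident value whose last use comes soonest.
    let r = (0..reg_contents.len())
        .min_by_key(|&r| last_use[reg_contents[r].unwrap()])
        .unwrap();
    // (emit a spill store for the evicted value here)
    reg_contents[r] = Some(value);
    r
}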
Experiments
These things turned out to make essentially no difference:
- Aligning the memory and using vmovaps instead of vmovups
- Hash consing the expressions to remove duplicate values, which eliminates 275 out of the 7866 instructions
These made a big difference:
- Reusing spill slots when they become dead (more on this below)
- Parallelizing the rendering with rayon (see the sketch below)
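For the rayon part, the work split is just over bands of output rows. Roughly, assuming a hypothetical render_band helper that walks its band 8 pixels at a time and calls the JIT'd kernel:

use rayon::prelude::*;

// Split the 1024x1024 output into bands of rows; rayon schedules the bands
// across its thread pool. `pixels` and `render_band` are illustrative names.
let (width, rows_per_band) = (1024, 8);
pixels
    .par_chunks_mut(rows_per_band * width)
    .enumerate()
    .for_each(|(band, chunk)| render_band(band, chunk));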
Results
TLDR: on my Ryzen 7 2700X desktop (8 cores, 16 threads), the full 1024x1024 image renders in ~48 ms. For comparison on my ThinkPad with a Core i7-1355U (2 P-cores, 8 E-cores, 12 threads), it takes ~80 ms.
This seems pretty solid for a CPU-only solution that is not using any fancy expression simplification techniques.
Startup latency
Since we have a JIT, there is a startup cost to instantiating the machine code for the program.
| Phase | Ryzen 7 2700X (μs) | Core i7-1355U (μs) |
|---|---|---|
| Parsing | 702 | 519 |
| Optimization (deduplication by hash consing) | 1,207 | 885 |
| Generating machine code | 5,029 | 2,627 |
Compilation is not totally insignificant, but I didn't try to optimize it.
It's interesting to see that the laptop, while slower overall, beats the desktop by so much on startup with its newer CPU and better single core performance.
Parallelization
On the desktop scaling is fairly close to linear up to the number of physical cores:
Reusing spill slots
Naively each time we spill a value it would get a fresh "slot" in memory. However once a spilled value is loaded for the last time, a future spill can reuse that slot. It makes sense that this would help with cache locality, but the effect was much larger than I expected.
On the laptop¹ this brought rendering time from 119 ms to 54 ms. The actual number of slots needed went from 2781 (~89 kB) to 142 (~4.5 kB).
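The mechanism is nothing more than a free list of slots: when the last load of a spilled value is emitted, its slot goes back into the pool. A rough sketch (illustrative, not the actual code):

// Spill slots are handed out from a free list so dead slots get reused.
struct SpillSlots {
    free: Vec<usize>,
    total: usize, // high-water mark; determines the spill area size
}

impl SpillSlots {
    fn acquire(&mut self) -> usize {
        if let Some(slot) = self.free.pop() {
            slot // reuse a slot whose value is no longer live
        } else {
            let slot = self.total;
            self.total += 1; // grow the spill area only when nothing is free
            slot
        }
    }
    // Called right after emitting the last load of a spilled value.
    fn release(&mut self, slot: usize) {
        self.free.push(slot);
    }
}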
Number of registers
The above effect made me start to question how much register use really mattered, or whether it was all just a question of cache locality. Threads introduce quite a bit of noise, so I rendered 128x128 on the desktop with a single thread, varying from the minimum possible 3 registers up to all 16:
Not at all a shape I expected, to say the least.² (Reading the processor manuals to get to the bottom of this is a rabbit hole for another time.)
However the lack of a large monotonic improvement makes sense considering the relatively small change in memory accesses:
Conclusion
You can find the code on GitHub.
If there is any moral to this story, it's that performance at this level is very unpredictable. It's a bit of a cliché that you should measure rather than guess, but it was more true here than I expected.
Future experiments to try
- Try doubling the pixel batch size on a CPU with AVX-512 (or quadrupling with FP16?)
- Is there a topological sorting of the instructions which has better locality and substantially less required spilling?
- Keeping the simplicity of the interpreter but improving the cache locality by operating in medium sized batches of pixels; is memory bandwidth a problem if it mostly fits in L1 cache?
¹ Sadly at this point in development my aging desktop decided to start overheating when running the benchmarks...
² The shape on the laptop is somewhat closer to monotonic but still a bit strange.