PyPy: Doing the Prospero-Challenge in RPython
Recently I had a lot of fun playing with the Prospero Challenge by Matt Keeter. The challenge is to render a 1024x1024 image of a quote from The Tempest by Shakespeare. The input is a mathematical formula with 7866 operations, which is evaluated once per pixel.
What made the challenge particularly enticing for me personally was the fact that the formula is basically a trace in SSA-form – a linear sequence of operations, where every variable is assigned exactly once. The challenge is to evaluate the formula as fast as possible. I tried a number of ideas for how to speed up execution and will talk about them in this somewhat meandering post. Most of it follows Matt's implementation Fidget very closely. There are two points of difference:
- I tried to add more peephole optimizations, but they didn't end up helping much.
- I implemented a "demanded information" optimization that removes a lot of operations by only keeping the sign of the result. This optimization ended up being useful.
Most of the prototyping in this post was done in RPython (a statically typable subset of Python 2 that can be compiled to C), but I later rewrote the program in C to get better performance. All the code can be found on GitHub.
Input program
The input program is a sequence of operations, like this:
The first column is the name of the result variable, the second column is the operation, and the rest are the arguments to the operation. var-x is a special operation that returns the x-coordinate of the pixel being rendered, and var-y equivalently returns the y-coordinate. The sign of the result gives the color of the pixel; the absolute value is not important.
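A parser for this format could look like the sketch below. It makes a few assumptions:

- There is one operation per line.
- Blank lines and lines starting with '#' are skipped.
- const arguments are stored as floats.
- Every variable name is replaced by the index of the operation that defines it.

The sample program is made up for illustration (a circle of radius 0.5) and is not taken from prospero.vm.

```python
def parse(text):
    """Parse '<name> <op> <args...>' lines into a list of tuples (sketch).

    Variable names are replaced by integer indexes, so ("add", 2, 3) means
    "add the results of operations 2 and 3".
    """
    name_to_index = {}
    ops = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and (assumed) comment lines
        name, opname, *args = line.split()
        if opname == "const":
            op = (opname, float(args[0]))
        else:
            # SSA form: every argument was defined on an earlier line
            op = (opname,) + tuple(name_to_index[arg] for arg in args)
        name_to_index[name] = len(ops)
        ops.append(op)
    return ops


# Hypothetical example program computing x*x + y*y - 0.25
SAMPLE = """
_0 var-x
_1 var-y
_2 square _0
_3 square _1
_4 add _2 _3
_5 const 0.25
_6 sub _4 _5
"""
```

Here parse(SAMPLE) returns [('var-x',), ('var-y',), ('square', 0), ('square', 1), ('add', 2, 3), ('const', 0.25), ('sub', 4, 5)].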
A baseline interpreter
To run the program, I first parse it and replace the register names with indexes, to avoid any dictionary lookups at runtime. Then I implemented a simple interpreter for the SSA-form input program. The interpreter is a simple register machine, where every operation is executed in order: the result of each operation is stored into a list of results, and then the next operation is executed. This is the slow baseline implementation, but it's very useful to compare against the optimized versions.
This is roughly what the code looks like:
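The sketch below is a baseline interpreter in plain Python, written under a few assumptions:

- The program has already been parsed into tuples like ("add", 2, 3), whose integer arguments are indexes of earlier operations.
- The last operation is the result.
- The opcode set (add, sub, mul, neg, square, sqrt, min, max) is a guess at the usual arithmetic operations.

The render helper and its mapping of pixel centers to [-1, 1] are also hypothetical.

```python
import math


def run(ops, x, y):
    """Evaluate the program for one pixel; only the sign of the result matters."""
    results = [0.0] * len(ops)
    for i, op in enumerate(ops):
        name = op[0]
        if name == "var-x":
            res = x
        elif name == "var-y":
            res = y
        elif name == "const":
            res = op[1]
        elif name == "neg":
            res = -results[op[1]]
        elif name == "square":
            res = results[op[1]] * results[op[1]]
        elif name == "sqrt":
            res = math.sqrt(results[op[1]])
        else:
            a = results[op[1]]
            b = results[op[2]]
            if name == "add":
                res = a + b
            elif name == "sub":
                res = a - b
            elif name == "mul":
                res = a * b
            elif name == "min":
                res = min(a, b)
            elif name == "max":
                res = max(a, b)
            else:
                raise ValueError("unknown operation: %s" % name)
        results[i] = res
    return results[-1]


def render(ops, size):
    """Hypothetical brute-force renderer: one full evaluation per pixel.
    Pixel centers are mapped to [-1, 1] (an assumed coordinate convention)."""
    coords = [-1.0 + 2.0 * (i + 0.5) / size for i in range(size)]
    return [[run(ops, x, y) < 0.0 for x in coords] for y in coords]
```

Every pixel pays for the full program, which is exactly the cost that the following sections try to cut down.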
Running the naive interpreter on the prospero.vm input is super slow, since it performs 7866 * 1024 * 1024 (about 8.2 billion) float operations, plus the interpretation overhead.
Using Quadtrees to render the picture
The approach that Matt describes in his really excellent talk is to use quadtrees: recursively subdivide the image into quadrants, and evaluate the formula in each quadrant. For every quadrant you can simplify the formula by doing a range analysis. After a few recursion steps, the formula becomes significantly smaller, often only a few hundred or a few dozen operations.
At the bottom of the recursion, you either reach a square where the range analysis reveals that the sign for all pixels is determined, in which case you can fill in all the pixels of the quadrant, or you evaluate the (now much simpler) formula for every pixel of the quadrant.
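The recursion can be sketched as follows. To keep the sketch self-contained, the range analysis and the per-pixel evaluation are passed in as functions. Both names, analyze and evaluate, are hypothetical. A hard-coded circle stands in for the real interval optimizer. The minimum square size of 8 pixels is an arbitrary choice, and the image size is assumed to be a power of 2.

```python
def render_quadtree(image, program, analyze, evaluate, x0, y0, size, min_size=8):
    """Fill image[row][col] (True = inside, i.e. formula < 0) for the square of
    pixels with top-left corner (x0, y0) and side length size.

    analyze(program, x0, y0, size) -> (status, simplified_program), where
    status is "inside", "outside" or "unknown" for the whole square.
    evaluate(program, col, row) -> value of the formula at one pixel.
    """
    status, program = analyze(program, x0, y0, size)
    if status != "unknown":
        # The sign is determined for the whole square: fill it in.
        for row in range(y0, y0 + size):
            for col in range(x0, x0 + size):
                image[row][col] = status == "inside"
    elif size <= min_size:
        # Bottom of the recursion: run the simplified program per pixel.
        for row in range(y0, y0 + size):
            for col in range(x0, x0 + size):
                image[row][col] = evaluate(program, col, row) < 0.0
    else:
        # Recurse into the four quadrants, passing the simplified program on.
        half = size // 2
        for dy in (0, half):
            for dx in (0, half):
                render_quadtree(image, program, analyze, evaluate,
                                x0 + dx, y0 + dy, half, min_size)


# Hypothetical example: the circle x*x + y*y - 0.25 on a 64x64 image.
SIZE = 64


def to_coord(p):
    """Map a pixel index to the coordinate of its center in [-1, 1]."""
    return -1.0 + 2.0 * (p + 0.5) / SIZE


def square_interval(lo, hi):
    if lo >= 0.0:
        return lo * lo, hi * hi
    if hi <= 0.0:
        return hi * hi, lo * lo
    return 0.0, max(lo * lo, hi * hi)


def circle_analyze(program, x0, y0, size):
    sx_lo, sx_hi = square_interval(to_coord(x0), to_coord(x0 + size - 1))
    sy_lo, sy_hi = square_interval(to_coord(y0), to_coord(y0 + size - 1))
    lo, hi = sx_lo + sy_lo - 0.25, sx_hi + sy_hi - 0.25
    if hi < 0.0:
        return "inside", program
    if lo >= 0.0:
        return "outside", program
    return "unknown", program  # a real optimizer would simplify the program here


def circle_eval(program, col, row):
    x, y = to_coord(col), to_coord(row)
    return x * x + y * y - 0.25


image = [[None] * SIZE for _ in range(SIZE)]
render_quadtree(image, None, circle_analyze, circle_eval, 0, 0, SIZE)
```

Because the intervals cover exactly the pixel centers that get evaluated, the result matches brute-force rendering pixel for pixel. Only the squares along the circle's edge reach the per-pixel loop.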
This is an interesting use case of JIT compiler/optimization techniques, requiring the optimizer itself to execute really quickly since it is an essential part of the performance of the algorithm. The optimizer runs literally hundreds of times to render a single image. If the algorithm is used for 3D models it becomes even more crucial.
Writing a simple optimizer
Implementing the quadtree recursion is straightforward. Since the program has no control flow, the optimizer is very simple to write. I've written a couple of blog posts on how to easily write optimizers for linear sequences of operations, and I'm using the approach described in these Toy Optimizer posts. The interval analysis is basically an abstract interpretation of the operations. The optimizer does a sequential forward pass over the input program. For every operation, the output interval is computed. The optimizer also performs optimizations based on the computed intervals, which helps in reducing the number of operations executed (I'll talk about this further down).
Here's a sketch of the Python code that does the optimization:
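The following is a condensed Python sketch of such an optimizer, using the same assumed tuple representation as above.

- It computes an interval for every operation in a single forward pass.
- It applies the min/max rule described in the peephole section below: if the argument intervals don't overlap, the min or max is replaced by the argument that must win.
- The function name, its signature, and the sqrt handling are assumptions.
- The other rewrites are left out.

```python
import math


def optimize(ops, x_interval, y_interval):
    """Forward pass over an SSA program (sketch).

    Returns (new_ops, intervals, result): the simplified program, the
    interval (lo, hi) of every new operation, and the index of the result.
    """
    new_ops, intervals = [], []
    remap = []  # old operation index -> new operation index
    for op in ops:
        name = op[0]
        if name == "var-x":
            iv, new_op = x_interval, op
        elif name == "var-y":
            iv, new_op = y_interval, op
        elif name == "const":
            iv, new_op = (op[1], op[1]), op
        else:
            args = [remap[a] for a in op[1:]]
            new_op = (name,) + tuple(args)
            alo, ahi = intervals[args[0]]
            if name == "neg":
                iv = (-ahi, -alo)
            elif name == "square":
                if alo >= 0.0:
                    iv = (alo * alo, ahi * ahi)
                elif ahi <= 0.0:
                    iv = (ahi * ahi, alo * alo)
                else:
                    iv = (0.0, max(alo * alo, ahi * ahi))
            elif name == "sqrt":
                # assumes the argument is never negative at runtime
                iv = (math.sqrt(max(alo, 0.0)), math.sqrt(max(ahi, 0.0)))
            else:
                blo, bhi = intervals[args[1]]
                if name == "add":
                    iv = (alo + blo, ahi + bhi)
                elif name == "sub":
                    iv = (alo - bhi, ahi - blo)
                elif name == "mul":
                    p = [alo * blo, alo * bhi, ahi * blo, ahi * bhi]
                    iv = (min(p), max(p))
                elif name in ("min", "max"):
                    # Peephole rule: non-overlapping intervals decide the winner,
                    # so the operation is dropped and its uses are redirected.
                    if (name == "min" and ahi <= blo) or (name == "max" and alo >= bhi):
                        remap.append(args[0])
                        continue
                    if (name == "min" and bhi <= alo) or (name == "max" and blo >= ahi):
                        remap.append(args[1])
                        continue
                    f = min if name == "min" else max
                    iv = (f(alo, blo), f(ahi, bhi))
                else:
                    raise ValueError("unknown operation: %s" % name)
        remap.append(len(new_ops))
        new_ops.append(new_op)
        intervals.append(iv)
    return new_ops, intervals, remap[-1]
```

Operations that become unused after a rewrite, like the constant in a removed min, are left for a separate dead code elimination pass.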
The resulting optimized traces are then simply interpreted at the bottom of the quadtree recursion. Matt talks about also generating machine code from them, but when I tried to use PyPy's JIT for that it was way too slow at producing machine code.
Testing soundness of the interval abstract domain
To make sure that my interval computation in the optimizer is correct, I implemented a property-based test using Hypothesis. It checks the abstract transfer functions of the interval domain for soundness. It does so by generating random concrete input values for an operation and random intervals that surround those values, then performing the concrete operation to get the concrete output, and finally checking that the abstract transfer function applied to the input intervals gives an interval that contains the concrete output.
For example, the random test for the square operation would look like this:
This test generates a random float b, and two other floats a and c such that the interval [a, c] contains b. The test then checks that the result of the square operation on b is contained in the interval [rmin, rmax] returned by the abstract transfer function for the square operation.
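Below is a dependency-free approximation of that property test. It uses Python's random module instead of Hypothesis, so failing cases are not shrunk. square_interval is an assumed implementation of the abstract transfer function for square. The test draws three random floats, sorts them so that a <= b <= c, and checks that b * b lies in the interval computed from [a, c]. With Hypothesis, the floats would instead come from hypothesis.strategies.floats() inside a function decorated with @given.

```python
import random


def square_interval(lo, hi):
    """Abstract transfer function (assumed implementation): an interval
    containing x * x for every x in [lo, hi]."""
    if lo >= 0.0:
        return lo * lo, hi * hi
    if hi <= 0.0:
        return hi * hi, lo * lo
    return 0.0, max(lo * lo, hi * hi)


def random_float(rng):
    """Random float of either sign with a magnitude between 0 and 1e5."""
    return rng.choice((-1.0, 1.0)) * rng.random() * 10.0 ** rng.randint(-5, 5)


def test_square_sound(iterations=10000, seed=0):
    rng = random.Random(seed)
    for _ in range(iterations):
        # b is the concrete input, [a, c] a random interval surrounding it
        a, b, c = sorted(random_float(rng) for _ in range(3))
        rmin, rmax = square_interval(a, c)
        assert rmin <= b * b <= rmax, (a, b, c, rmin, rmax)


test_square_sound()
```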
Peephole rewrites
The only optimization that Matt does in his implementation is a peephole optimization rule that removes min and max operations where the intervals of the arguments don't overlap. In that case, the optimizer can statically know which of the arguments will be the result of the operation. I implemented this peephole optimization in my implementation as well, but I also added a few more peephole optimizations that I thought would be useful.
However, it turns out that all my attempts at adding other peephole optimization rules were not very useful. Most rules never fired, and the ones that did only had a small effect on the performance of the program. The only peephole optimization that I found to be useful was the one that Matt describes in his talk. Matt's min/max optimization accounted for 96% of all rewrites that my peephole optimizer applied for the prospero.vm input. The remaining 4% of rewrites were (the percentages are of that 4%):
In the end it turned out that having these extra optimization rules made the total runtime of the system go up. Checking for the rewrites isn't free, and since they apply so rarely they don't pay for their own cost in terms of improved performance.
There are some further rules that I tried that never fired at all:
This investigation is clearly way too focused on a single program; if this were an actually serious implementation, it would need to be redone with a larger set of example inputs.
Demanded Information Optimization
LLVM has a static analysis pass called 'demanded bits'. It is a backwards analysis that determines which bits of a value are actually used in the final result. This information can then be used in peephole optimizations. For example, if you have an expression that computes a value, but only the last byte of that value is used in the final result, you can optimize the expression to only compute the last byte.
Here's an example. Let's say we first byte-swap a 64-bit int, and then mask out everything except the last byte:
In this case, the "demanded bits" of the byteswap(a) expression are 0b0...011111111, which means that we don't care about its upper 56 bits. Since the lowest byte of byteswap(a) is the highest byte of a, the whole expression can be optimized to a >> 56 (for an unsigned a).
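The equivalence can be spot-checked on random inputs in Python, where integers are unbounded. A 64-bit unsigned value is modelled as a Python int in range(2**64). byteswap64 is a helper written for this sketch, not an existing library function.

```python
import random


def byteswap64(a):
    """Hypothetical helper: reverse the byte order of an unsigned 64-bit int."""
    return int.from_bytes(a.to_bytes(8, "little"), "big")


def original(a):
    # only the lowest 8 bits of byteswap64(a) are demanded ...
    return byteswap64(a) & 0xFF


def optimized(a):
    # ... and those are exactly the top 8 bits of a
    return a >> 56


rng = random.Random(42)
for _ in range(10000):
    a = rng.getrandbits(64)
    assert original(a) == optimized(a)
```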
For the Prospero challenge, we can observe that for the resulting pixel values, the value of the result is not used at all, only its sign. Essentially, every program ends implicitly with a sign operation that returns 0.0 for negative values and 1.0 for positive values. For clarity, I will show this sign operation in the rest of the section, even though it's not actually in the real code.
This makes it possible to simplify certain min/max operations further. Here is an example of a program, together with the intervals of the variables:
This program can be optimized to:
This is correct because the optimized expression always has the same sign as the original one: if x > 0.1, then min(x, y) can only be negative if y is negative.
Another, more complex, example is this:
Which can be optimized to this:
This is because the sign of min(x, y) is the same as the sign of y if x > 0, and the sign of max(z, min(x, y)) is thus the same as the sign of max(z, y).
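Both sign identities can be spot-checked with random inputs. The sketch assumes that sign maps negative values to 0.0 and everything else, including zero, to 1.0. It draws x from [0.1, 1.0] so that x is known to be positive. The function names are hypothetical.

```python
import random


def sign(v):
    # 0.0 for negative values, 1.0 otherwise (assumption: zero counts as positive)
    return 0.0 if v < 0.0 else 1.0


def check_sign_identities(trials=100000, seed=1):
    rng = random.Random(seed)
    for _ in range(trials):
        x = rng.uniform(0.1, 1.0)  # known to be positive
        y = rng.uniform(-1.0, 1.0)
        z = rng.uniform(-1.0, 1.0)
        # min(x, y) with x > 0 has the same sign as y ...
        assert sign(min(x, y)) == sign(y)
        # ... so y can also replace min(x, y) underneath a max
        assert sign(max(z, min(x, y))) == sign(max(z, y))
    return True
```

The second identity holds because max is monotone: if y < x, then min(x, y) is y anyway, and otherwise both min(x, y) and y are positive, which makes both maxima positive.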
To implement this optimization, I do a backwards pass over the program after the peephole optimization forward pass. For every min call I encounter where one of the arguments is known to be positive, I can optimize the min call away and replace it with the other argument. For max calls I simplify their arguments recursively.
The code looks roughly like this:
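Here is one way to write this pass in Python, as a single linear backwards pass over the assumed tuple representation. The real code may be structured differently, for example recursively. The sketch assumes intervals holds the (lo, hi) interval of every operation from the forward pass.

- Each operation gets a demand: unused, sign-only, or full value.
- A min whose value is only needed for its sign, and that has an argument with a lower bound > 0, is replaced by its other argument.
- Since the sign of a min or max depends only on the signs of its arguments, sign-only demand is propagated through both. This goes slightly further than the max-only recursion described above.
- Replaced operations stay in the list as dead code.

```python
SIGN, VALUE = 1, 2  # demand levels; 0 means "not used (yet)"


def is_leaf(op):
    return op[0] in ("const", "var-x", "var-y")


def demanded_sign(ops, intervals, result):
    """Hypothetical backwards pass removing min operations that are only needed
    for their sign and have an argument known to be positive.

    Returns (new_ops, new_result). Replaced operations are kept in place (as
    dead code) so that all indexes stay valid.
    """
    demand = [0] * len(ops)
    demand[result] = SIGN  # only the sign of the result is used
    replace = {}  # operation index -> operation with the same sign
    for i in range(len(ops) - 1, -1, -1):
        op = ops[i]
        if demand[i] == 0 or is_leaf(op):
            continue
        if demand[i] == SIGN and op[0] == "min":
            for pos, other in ((op[1], op[2]), (op[2], op[1])):
                if intervals[pos][0] > 0.0:
                    # pos > 0, so sign(min(pos, other)) == sign(other)
                    replace[i] = other
                    demand[other] = max(demand[other], SIGN)
                    break
            if i in replace:
                continue
        # The sign of a min/max only depends on the signs of its arguments;
        # every other operation needs the full values of its arguments.
        if demand[i] == SIGN and op[0] in ("min", "max"):
            arg_demand = SIGN
        else:
            arg_demand = VALUE
        for arg in op[1:]:
            demand[arg] = max(demand[arg], arg_demand)

    def resolve(index):
        while index in replace:
            index = replace[index]
        return index

    new_ops = [op if is_leaf(op) else (op[0],) + tuple(resolve(a) for a in op[1:])
               for op in ops]
    return new_ops, resolve(result)
```

The demand of an operation is the strongest demand of all its uses. That keeps the rewrite safe when a min is also used somewhere that needs its full value: in that case its demand is VALUE and it is left alone.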
In my experiment, this optimization lets me remove 25% of all operations in prospero, at the various levels of my quadtree. I'll briefly look at performance results further down.
Further ideas about the demanded sign simplification
There is another idea for short-circuiting the evaluation of expressions that I tried briefly but didn't pursue to the end. Let's go back to the first example of the previous subsection, but with different intervals:
Now we can't use the "demanded sign" trick in the optimizer, because neither x nor y is known to be positive. However, during execution of the program, if x turns out to be negative, we can end the execution of this trace immediately, since we know that the result must be negative.
So I experimented with adding return_early_if_neg flags to all operations with this property. The interpreter then checks whether the flag is set on an operation and if the result is negative, it stops the execution of the program early:
This looked pretty promising, but it's also a trade-off, because checking the flag and the value isn't free. Here's a sketch of the change in the interpreter:
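Both parts can be sketched in Python; early_exit_flags and run_with_early_exit are hypothetical names.

- A backwards pass marks every operation whose negative value already implies a negative result: the result itself, and recursively both arguments of a marked min.
- The interpreter, reduced here to a handful of opcodes, checks the flag after each operation and returns early.
- The early return value has the right sign but is not the final value, which is fine since only the sign is used.

```python
def early_exit_flags(ops, result):
    """flags[i] is True if a negative value of operation i guarantees a
    negative program result. Only min propagates this property: if either
    argument of a min is negative, so is the min."""
    flags = [False] * len(ops)
    flags[result] = True
    for i in range(result, -1, -1):
        if flags[i] and ops[i][0] == "min":
            flags[ops[i][1]] = True
            flags[ops[i][2]] = True
    return flags


def run_with_early_exit(ops, flags, x, y):
    """Interpreter with the early exit (assumes the last op is the result).
    Only the sign of the return value is meaningful."""
    results = [0.0] * len(ops)
    for i, op in enumerate(ops):
        name = op[0]
        if name == "var-x":
            res = x
        elif name == "var-y":
            res = y
        elif name == "const":
            res = op[1]
        elif name == "neg":
            res = -results[op[1]]
        else:
            a, b = results[op[1]], results[op[2]]
            if name == "add":
                res = a + b
            elif name == "sub":
                res = a - b
            elif name == "mul":
                res = a * b
            elif name == "min":
                res = min(a, b)
            elif name == "max":
                res = max(a, b)
            else:
                raise ValueError("unknown operation: %s" % name)
        results[i] = res
        if flags[i] and res < 0.0:
            return res  # the final result is known to be negative as well
    return results[-1]
```

The extra check runs after every operation, which is the cost mentioned above. It also makes the control flow depend on each pixel's values, which does not fit well with evaluating several pixels in lockstep.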
I implemented this in the RPython version, but didn't end up porting it to C, because it interferes with SIMD.
Dead code elimination
Matt performs dead code elimination in his implementation by doing a single backwards pass over the program. This is a very simple and effective optimization, and I implemented it as well. The pass works as follows: it starts by marking the result operation as used. Then it goes backwards over the program. If the current operation is used, its arguments are marked as used as well. Afterwards, all the operations that are not marked as used are removed from the program. The PyPy JIT actually performs dead code elimination on traces in exactly the same way (and I don't think we ever explained how this works on the blog), so I thought it was worth mentioning.
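The pass fits in a few lines of Python. This sketch assumes the same tuple representation as before: const operations carry a float, and all other arguments are operation indexes. It also renumbers the surviving operations, since removing operations shifts the indexes. The function name is hypothetical.

```python
def dead_code_elimination(ops, result):
    """Remove operations that don't contribute to ops[result].
    Returns (new_ops, new_result) with argument indexes renumbered."""
    used = [False] * len(ops)
    used[result] = True
    # Backwards pass: arguments of used operations are used as well.
    for i in range(result, -1, -1):
        op = ops[i]
        if used[i] and op[0] != "const":
            for arg in op[1:]:
                used[arg] = True
    # Forward pass: keep the used operations and renumber their arguments.
    new_index = {}
    new_ops = []
    for i, op in enumerate(ops):
        if not used[i]:
            continue
        if op[0] != "const":
            op = (op[0],) + tuple(new_index[a] for a in op[1:])
        new_index[i] = len(new_ops)
        new_ops.append(op)
    return new_ops, new_index[result]
```

For example, with ops = [("var-x",), ("var-y",), ("const", 3.0), ("mul", 0, 2), ("add", 1, 1), ("neg", 3)] and result 5, the y-related operations disappear. The result is [("var-x",), ("const", 3.0), ("mul", 0, 1), ("neg", 2)] with result index 3.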
Matt also performs register allocation as part of the backwards pass, but I didn't implement it because I wasn't too interested in that aspect.
Random testing of the optimizer
To make sure I didn't break anything in the optimizer, I implemented a test that generates random input programs and checks that the output of the optimizer is equivalent to the input program. The test generates random operations, random intervals for the operations and a random input value within that interval. It then runs the optimizer on the input program and checks that the output program has the same result as the input program. This is again implemented with hypothesis. Hypothesis' test case minimization feature is super useful for finding optimizer bugs. It's just not fun to analyze a problem on a many-thousand-operation input file, but Hypothesis often generated reduced test cases that were only a few operations long.
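A stripped-down version of such a random test needs only Python's random module, though unlike Hypothesis it does not shrink failing cases. random_program, evaluate and find_counterexample are hypothetical helpers. The optimizer under test receives the input intervals and returns (new_ops, result). Because the demanded-sign optimization only preserves signs, the check compares signs rather than exact values. The usage example runs it against a deliberately broken optimizer that always replaces a final min(a, b) with a.

```python
import random

BINARY_OPS = ["add", "sub", "mul", "min", "max"]


def random_program(rng, length):
    """Random SSA program: var-x, var-y, then random operations whose
    arguments refer to earlier operations. The last operation is the result."""
    ops = [("var-x",), ("var-y",)]
    while len(ops) < length:
        r = rng.random()
        if r < 0.2:
            ops.append(("const", rng.uniform(-2.0, 2.0)))
        elif r < 0.3:
            ops.append(("neg", rng.randrange(len(ops))))
        else:
            ops.append((rng.choice(BINARY_OPS),
                        rng.randrange(len(ops)), rng.randrange(len(ops))))
    return ops


def evaluate(ops, result, x, y):
    values = []
    for op in ops:
        name = op[0]
        if name == "var-x":
            v = x
        elif name == "var-y":
            v = y
        elif name == "const":
            v = op[1]
        elif name == "neg":
            v = -values[op[1]]
        else:
            a, b = values[op[1]], values[op[2]]
            if name == "add":
                v = a + b
            elif name == "sub":
                v = a - b
            elif name == "mul":
                v = a * b
            elif name == "min":
                v = min(a, b)
            else:
                v = max(a, b)
        values.append(v)
    return values[result]


def find_counterexample(optimize, trials=2000, seed=0):
    """Return (ops, x, y) where optimize changes the sign of the result, or None.
    The concrete inputs are drawn from within the intervals given to optimize."""
    rng = random.Random(seed)
    for _ in range(trials):
        ops = random_program(rng, rng.randrange(3, 10))
        xlo, ylo = rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)
        x_interval = (xlo, xlo + rng.uniform(0.0, 1.0))
        y_interval = (ylo, ylo + rng.uniform(0.0, 1.0))
        x, y = rng.uniform(*x_interval), rng.uniform(*y_interval)
        new_ops, new_result = optimize(ops, x_interval, y_interval)
        expected = evaluate(ops, len(ops) - 1, x, y)
        got = evaluate(new_ops, new_result, x, y)
        if (expected < 0.0) != (got < 0.0):
            return ops, x, y
    return None


def identity(ops, x_interval, y_interval):
    return ops, len(ops) - 1


def broken(ops, x_interval, y_interval):
    # Deliberately wrong: assumes min(a, b) == a without looking at intervals.
    if ops[-1][0] == "min":
        return ops, ops[-1][1]
    return ops, len(ops) - 1
```

find_counterexample(identity) returns None, while for broken it reports a small program where the sign of the result changes.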
Visualizing programs
It's actually surprisingly annoying to visualize prospero.vm well, because it's far too large to just feed into Graphviz. I made the problem slightly easier by grouping several operations together, such that only the first operation in a group is used as the argument for more than one operation further in the program. This made it slightly more manageable for Graphviz, but it still wasn't a big enough improvement to visualize all of prospero.vm in its unoptimized form at the top of the quadtree.
Here's a visualization of the optimized prospero.vm at one of the quadtree levels:
The result is on top, and every node points to its arguments. The min and max operations form a kind of "spine" of the expression tree, because they are unions (min) and intersections (max) in the constructive solid geometry sense.
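A basic Graphviz export can be sketched as follows, without the grouping of operations described above. Edges point from each operation to its arguments, so with Graphviz's default top-to-bottom layout the result ends up on top. Filling in the min and max nodes makes the "spine" visible. The function name and the styling are arbitrary choices for this sketch.

```python
def to_dot(ops, result):
    """Return Graphviz DOT source with one node per operation and an edge
    from every operation to each of its arguments (hypothetical helper)."""
    lines = ["digraph program {"]
    for i, op in enumerate(ops):
        name = op[0]
        label = repr(op[1]) if name == "const" else name
        attrs = ['label="%s"' % label]
        if name in ("min", "max"):
            # highlight the min/max "spine" of the expression tree
            attrs.append("style=filled")
            attrs.append("fillcolor=lightblue")
        if i == result:
            attrs.append("shape=box")
        lines.append("  n%d [%s];" % (i, ", ".join(attrs)))
        if name != "const":
            for arg in op[1:]:
                lines.append("  n%d -> n%d;" % (i, arg))
    lines.append("}")
    return "\n".join(lines)
```

Writing the output to program.dot and running dot -Tsvg program.dot -o program.svg produces an image. For a program of prospero.vm's size, the layout step is where things get slow.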
I also wrote a function to visualize the quadtree recursion itself; the output looks like this:
Green nodes are where the interval analysis determined that the output must be entirely outside the shape. Yellow nodes are where the quadtree recursion bottomed out.
C implementation
To achieve even faster performance, I decided to rewrite the implementation in C. While RPython is great for prototyping, it can be challenging to control low-level aspects of the code. The rewrite in C allowed me to experiment with several techniques I had been curious about:
- musttail optimization for the interpreter.
- SIMD (Single Instruction, Multiple Data): Using Clang's ext_vector_type, I process eight pixels at once using AVX (or some other SIMD magic that I don't properly understand).
- Efficient struct packing: I packed the operations struct into just 8 bytes by limiting the maximum number of operations to 65,536 (so that operation indexes fit into 16 bits), with the idea of making the optimizer faster.
I didn't rigorously study the performance impact of each of these techniques individually, so it's possible that some of them might not have contributed significantly. However, the rewrite was a fun exercise for me to explore these techniques. The code can be found here.
Testing the C implementation
At various points I had bugs in the C implementation, leading to a fun glitchy version of prospero:
To find these bugs, I used the same random testing approach as in the RPython version. I generated random input programs as strings in Python and checked that the output of the C implementation was equivalent to the output of the RPython implementation (simply by calling out to the shell and reading the generated image, then comparing pixels). This helped ensure that the C implementation was correct. It was surprisingly tricky to get this right, for reasons that I didn't expect. A lot of them were related to the fact that in C I used float, while Python uses double for its float type. This made the random tester find weird floating point corner cases where the rounding behaviour differed between the two widths.
I solved those by using double in C when running the random tests, by means of an #ifdef.
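Here is a small example of the kind of corner case involved. The expression (a + b) - a - c is evaluated twice with a = 1.0, b = 1e-8, c = 5e-9:

- once with Python's double-precision floats;
- once with every intermediate result rounded to single precision, emulated with the struct module.

The two results have different signs. In single precision, 1.0 + 1e-8 rounds back to 1.0, so the small positive difference is lost. The expression is made up for illustration and is not one of the generated test programs.

```python
import struct


def f32(v):
    """Round a Python float (a C double) to the nearest C float."""
    return struct.unpack("f", struct.pack("f", v))[0]


def program(a, b, c, rnd=lambda v: v):
    """Hypothetical test expression (a + b) - a - c, rounding every
    intermediate result with rnd (identity = double precision)."""
    a, b, c = rnd(a), rnd(b), rnd(c)
    t = rnd(a + b)
    t = rnd(t - a)
    return rnd(t - c)


as_double = program(1.0, 1e-8, 5e-9)            # small positive value
as_float = program(1.0, 1e-8, 5e-9, rnd=f32)    # small negative value
```

For a renderer, a sign flip like this shows up as a single differing pixel, which is exactly what the pixel-comparing random test reports.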
It's super fun to watch the random program generator produce random images; here are a few:
Performance
Some very rough performance results on my laptop (an AMD Ryzen 7 PRO 7840U with 32 GiB RAM running Ubuntu 24.04), comparing the RPython version, the C version (with and without demanded info), and Fidget (in vm mode; its JIT made things worse for me), both for 1024x1024 and 4096x4096 images:
The demanded information optimization seems to help quite a bit, which was nice to see.
Conclusion
That's it! I had lots of fun with the challenge and have a whole bunch of other ideas I want to try out. Thanks, Matt, for this interesting puzzle!