Program Synthesis by Simulated Annealing With Thompson Sampling


problem statement (learning programs)

The problem we're trying to solve here is program synthesis: finding a program that minimizes a loss function. That loss function could come from a typical machine learning task where you're trying to reproduce some training data, from reinforcement learning, or from something else. We want this to be a black-box algorithm that makes no assumptions about the loss function.

learning

One simple, classic, and widely used algorithm for black-box optimization like this is simulated annealing. It's a random walk where at each step you evaluate the loss function, and if the loss is worse than it was before you took the step, you usually revert and try again from the old position; a worse step is kept only with a probability that shrinks as the increase in loss grows relative to a "temperature". That temperature follows an "annealing schedule", which is essentially the same idea as a learning rate schedule (like you might see in a neural net if you've seen one from before everyone started using Adam).
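A minimal sketch of that loop in Python. The Metropolis acceptance rule (accept a worse step with probability exp(-Δ/T)) and the geometric cooling schedule are standard textbook choices assumed here, not taken from the post's source; `propose` stands in for the random walk distribution.

```python
from __future__ import annotations

import math
import random
from typing import Callable, TypeVar

State = TypeVar("State")


def simulated_annealing(
    loss: Callable[[State], float],
    propose: Callable[[State, random.Random], State],
    start: State,
    steps: int = 10_000,
    t0: float = 1.0,
    cooling: float = 0.999,
    seed: int = 0,
) -> tuple[State, float]:
    """Minimize `loss` with Metropolis acceptance and a geometric cooling schedule."""
    rng = random.Random(seed)
    current, current_loss = start, loss(start)
    best, best_loss = current, current_loss
    temperature = t0
    for _ in range(steps):
        candidate = propose(current, rng)
        candidate_loss = loss(candidate)
        delta = candidate_loss - current_loss
        # Always accept improvements; accept worse steps with probability exp(-delta / T).
        if delta <= 0 or rng.random() < math.exp(-delta / temperature):
            current, current_loss = candidate, candidate_loss
            if current_loss < best_loss:
                best, best_loss = current, current_loss
        # Otherwise we "revert": `current` is unchanged and the next step starts from it.
        temperature *= cooling  # the annealing schedule
    return best, best_loss
```

For example, minimizing `(x - 3) ** 2` over the integers with `propose=lambda x, rng: x + rng.choice([-1, 1])` from `start=50` walks down to `x = 3`.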

The problem with simulated annealing is that while it will almost always converge eventually, it's very slow; slow enough that you wonder whether it would be faster to just try points at random. What we're doing to address this is to have the random walk distribution (the distribution we sample the random steps from) learn to take steps in good directions as we go.

We frame this learning as a bandit problem where our reward is an improvement in the simulated annealing loss. A convenient way to think about solving this bandit problem is as Thompson sampling, where the distribution we're both sampling from and updating is our random walk distribution.
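To make the bandit framing concrete, here is a sketch of Thompson sampling over a small, fixed set of move types, with a Beta-Bernoulli posterior per move and a reward of 1 when a step improves the loss. This is a swapped-in simplification: the post's actual random walk distribution is the snippet sketch described in the next section, not a Beta posterior, and the arm names used below are hypothetical.

```python
import random


class BetaThompsonProposer:
    """Thompson sampling over a fixed set of move types (arms).

    Each arm keeps a Beta(successes + 1, failures + 1) posterior over
    P(loss improves | we took this kind of step).
    """

    def __init__(self, arms: list[str], seed: int = 0) -> None:
        self.arms = list(arms)
        self.successes = {arm: 0 for arm in self.arms}
        self.failures = {arm: 0 for arm in self.arms}
        self.rng = random.Random(seed)

    def choose(self) -> str:
        # Draw one sample from each arm's posterior and play the arm with the largest draw.
        draws = {
            arm: self.rng.betavariate(self.successes[arm] + 1, self.failures[arm] + 1)
            for arm in self.arms
        }
        return max(draws, key=draws.get)

    def update(self, arm: str, improved: bool) -> None:
        # Reward is 1 if the step improved the annealing loss, else 0.
        if improved:
            self.successes[arm] += 1
        else:
            self.failures[arm] += 1
```

Sampling from the posterior (rather than always playing the best mean) is what keeps some exploration of move types that currently look worse.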

programs

The points in the state space we're optimizing over are programs. The programs we're learning run on a register machine with just about the simplest bytecode format you can have: a single accumulator, a single operand, and fixed-width instructions. That is to say, every instruction either takes its input from or writes its output to a single scratch register (which we call the X register). The ISA is described in more detail here (though there really isn't that much to describe). The feature representation for a program is literally just the instructions flattened out.
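The post's actual instruction set isn't reproduced here, so the opcodes below (`LOAD`, `STORE`, `ADD`, `MUL`, `CONST`) and the memory layout are hypothetical. The sketch only shows what "single accumulator, single operand, fixed width" means in practice: an interpreter where every instruction reads or writes X, and a feature vector that is just the instruction list flattened out.

```python
from enum import IntEnum


class Op(IntEnum):
    # Hypothetical opcodes, not the post's real ISA.
    LOAD = 0   # X = mem[arg]
    STORE = 1  # mem[arg] = X
    ADD = 2    # X = X + mem[arg]
    MUL = 3    # X = X * mem[arg]
    CONST = 4  # X = arg


Instruction = tuple[Op, int]  # fixed width: one opcode, one operand


def run(program: list[Instruction], inputs: list[float], n_regs: int = 8) -> float:
    """Interpret a single-accumulator program; every op reads from or writes to X."""
    mem = [0.0] * n_regs
    mem[: len(inputs)] = inputs
    x = 0.0
    for op, arg in program:
        if op == Op.LOAD:
            x = mem[arg]
        elif op == Op.STORE:
            mem[arg] = x
        elif op == Op.ADD:
            x = x + mem[arg]
        elif op == Op.MUL:
            x = x * mem[arg]
        elif op == Op.CONST:
            x = float(arg)
    return x  # the result is whatever is left in X


def features(program: list[Instruction]) -> list[int]:
    """Flatten the instructions into one integer vector: [op0, arg0, op1, arg1, ...]."""
    return [int(v) for instr in program for v in instr]
```

With a = 2 and b = 3 in memory slots 0 and 1, `[(Op.LOAD, 0), (Op.ADD, 1), (Op.MUL, 0)]` computes (a + b) * a = 10.0, and its feature vector is `[0, 0, 2, 1, 3, 0]`.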

To evaluate the loss function for a given program, we need to execute the program on some data, and we need to do that fast, so it pays to JIT-compile. This is done with a copy-and-patch-style JIT compiler, except that instead of doing the "stencil" application (as the paper calls it) with C++ templates, we have a Python program that generates a C source file with our x86 instruction sequences encoded as literal values. The source code for that is here.
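A rough sketch of the generator side of that idea: a Python script that writes each stencil's x86-64 bytes into a C source file as array literals, along with the offset of a 4-byte hole to be patched at JIT time. The stencils are hypothetical and assume the X register lives in `eax`; the real generator's stencils and file layout aren't shown in the post. `patch` is written in Python only to show the byte layout; in the real system that step presumably happens in the C runtime.

```python
import struct
from typing import Optional

# Hypothetical stencils: x86-64 byte sequences with a 4-byte hole for a 32-bit immediate.
# Assumes X is kept in eax.
STENCILS: dict[str, tuple[bytes, Optional[int]]] = {
    "mov_x_imm": (bytes([0xB8, 0, 0, 0, 0]), 1),  # mov eax, imm32
    "add_x_imm": (bytes([0x05, 0, 0, 0, 0]), 1),  # add eax, imm32
    "ret": (bytes([0xC3]), None),                 # ret (no hole)
}


def emit_c_source(stencils: dict[str, tuple[bytes, Optional[int]]]) -> str:
    """Generate C source that stores each stencil as a byte-array literal plus its hole offset."""
    lines = []
    for name, (code, hole) in stencils.items():
        body = ", ".join(f"0x{b:02x}" for b in code)
        lines.append(f"static const unsigned char stencil_{name}[{len(code)}] = {{{body}}};")
        hole_value = -1 if hole is None else hole  # -1 marks "nothing to patch"
        lines.append(f"static const int stencil_{name}_hole = {hole_value};")
    return "\n".join(lines) + "\n"


def patch(code: bytes, hole: Optional[int], imm: int) -> bytes:
    """Copy a stencil and write a little-endian signed 32-bit immediate into its hole."""
    if hole is None:
        return code
    out = bytearray(code)
    out[hole:hole + 4] = struct.pack("<i", imm)
    return bytes(out)
```

Concatenating `patch(*STENCILS["mov_x_imm"], 5)`, `patch(*STENCILS["add_x_imm"], 7)` and the `ret` stencil gives `B8 05 00 00 00 05 07 00 00 00 C3`, i.e. `mov eax, 5; add eax, 7; ret`.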

el camino de monte carlo (the Monte Carlo road)

What distribution are we using for the programs? Programs are discrete, so classic options like a mixture of Gaussians won't work. One might think that a kernel density estimator won't work either, because we can't fit all possible programs in memory, but there's a way around this. We break programs into "snippets", which are short instruction sequences, and assume that the distribution of P(score goes up when you add this snippet) follows a power law, so we can use a heavy-hitter data structure to store it. Specifically, we're using a count-min sketch plus a top-k heap.
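A minimal Python sketch of that heavy-hitter structure: a count-min sketch holding approximate counts of how often adding a snippet improved the score, plus a bounded set of the current top-k snippets. The snippet definition (every contiguous run of 1 to `max_len` instructions), the hash construction, and the sizes are assumptions. To keep it short, the top-k set is a dict with linear-scan eviction rather than a real heap, and its stored counts can go slightly stale.

```python
import hashlib


def snippets(program, max_len=3):
    """Hypothetical snippet definition: all contiguous runs of 1..max_len instructions."""
    for n in range(1, max_len + 1):
        for i in range(len(program) - n + 1):
            yield tuple(program[i:i + n])


class CountMinTopK:
    """Count-min sketch plus a bounded set of heavy hitters (snippet -> estimated count)."""

    def __init__(self, width: int = 2048, depth: int = 4, k: int = 32) -> None:
        self.width, self.depth, self.k = width, depth, k
        self.table = [[0] * width for _ in range(depth)]
        self.top: dict = {}  # at most k entries

    def _buckets(self, key):
        # One hash per row, derived by salting blake2b with the row index.
        data = repr(key).encode()
        for row in range(self.depth):
            h = hashlib.blake2b(data, digest_size=8, salt=row.to_bytes(8, "little"))
            yield row, int.from_bytes(h.digest(), "little") % self.width

    def estimate(self, key) -> int:
        # Count-min never underestimates; the min over rows limits the overestimate.
        return min(self.table[row][col] for row, col in self._buckets(key))

    def add(self, key, count: int = 1) -> None:
        for row, col in self._buckets(key):
            self.table[row][col] += count
        est = self.estimate(key)
        if key in self.top or len(self.top) < self.k:
            self.top[key] = est
        else:
            # Evict the smallest tracked heavy hitter if this key now beats it.
            smallest = min(self.top, key=self.top.get)
            if est > self.top[smallest]:
                del self.top[smallest]
                self.top[key] = est

    def heavy_hitters(self):
        """Tracked snippets sorted by estimated count, largest first."""
        return sorted(self.top.items(), key=lambda kv: kv[1], reverse=True)
```

Memory stays fixed at `width * depth` counters plus `k` entries no matter how many distinct snippets are seen, which is why the power-law assumption matters: it is what makes the few tracked snippets carry most of the mass.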

One way to think about this is as a statistical reframing of genetic programming, since picking snippets is a lot like mutation and crossover in genetic programming. I think the advantage of the stats perspective is that we have the opportunity to be more rigorous about how we use randomness, since it's baked into the cake from the beginning rather than layered on top. I haven't taken this approach far enough to really take advantage of that, though.
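To make the genetic programming analogy concrete, here is a hypothetical proposal step: draw a snippet from the heavy hitters in proportion to its count and splice it into the current program at a random position, overwriting a few instructions. This is one plausible way to turn the snippet distribution into a random walk step, not necessarily how the post's code does it; the function name and `max_replace` parameter are made up for illustration.

```python
import random


def splice_mutation(program, heavy_hitters, rng, max_replace=2):
    """Propose a new program by splicing in a snippet drawn in proportion to its count.

    `heavy_hitters` is a list of (snippet, count) pairs with at least one positive count.
    Reusing stored snippets resembles crossover (material from earlier good programs);
    the random position and overwrite length resemble mutation.
    """
    snippet_list = [s for s, _ in heavy_hitters]
    weights = [c for _, c in heavy_hitters]
    snippet = rng.choices(snippet_list, weights=weights, k=1)[0]
    pos = rng.randint(0, len(program))  # insertion point; may be the end of the program
    n_replace = rng.randint(0, min(max_replace, len(program) - pos))
    # Build a new list so the caller's program is left untouched (needed to "revert").
    return program[:pos] + list(snippet) + program[pos + n_replace:]
```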

so how does it perform?

I haven't rigorously investigated this, so the honest answer is that I don't know, but it can at least learn to memorize data, to do arithmetic, and to approximate arithmetic over trig functions. The main reason I stopped working on this, though, is that I'm not sure what application to benchmark it against next. If anyone has suggestions, please send me an email: notes@hella.cheap. You can find the source code for all this here.