I have started using Genetic Algorithms (GAs) in my PhD research. They are a very effective and simple tool for solving optimisation problems. Using a “natural selection”-like process to explore the solution space, they can quickly arrive at a good candidate solution. The nice thing about GAs is that they are very easy to implement. However, a basic implementation can be a bit slow.
In deep learning (DL), we make heavy use of GPUs (and in some cases TPUs) to parallelise and speed up the computation.
A lot of DL work has been done with JAX.
If you don’t know, JAX is an autograd and XLA library for Python.
It allows you to write Python code with the NumPy API that can be compiled, vectorised, and run on GPUs.
Inspired by a blog post from Will Whitney outlining a cool way to use jax.vmap to train multiple models at once, I asked myself: could I use the same technique to speed up my GA and run it on the GPU?
It turns out, with a few minor caveats, you can! Check out the code on my GitHub.
- What is a GA?
- A quick overview of what a GA is and how it works
- SIMD?
- How it works?
What is a GA?
A quick overview of what a GA is and how it works.
The GA is inspired by the biological process of evolution through natural selection, where fitter individuals are more likely to pass on their genetic material. There is an initial population of \(N\) candidates. During each iteration of the GA, a new population is created by combining and mutating the existing members based on their fitness. An iteration of the algorithm is analogous to a generation.
Each candidate has a genetic representation, a set of “chromosomes” that make up a solution. A basic encoding is simply a binary string made up of 1s and 0s, one for each solution variable. We also need a fitness function that can take a solution representation and evaluate how good it is. It returns a scalar value (a single number) representing how good the solution is; in the examples here, lower is better.
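To make the encoding and the fitness function concrete, here is a minimal sketch. The toy problem (count the zero bits) and the name `count_zeros` are made up for illustration; any function that maps a candidate to a single number would do.

```python
Candidate = list[bool]


def count_zeros(candidate: Candidate) -> int:
    """Toy fitness: the number of False bits, so lower is better."""
    return sum(1 for bit in candidate if not bit)


print(count_zeros([True, False, True, True]))  # 1
print(count_zeros([True, True, True, True]))   # 0
```

With this fitness function the best possible candidate is the string of all 1s, which scores 0.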
With search-based optimisation methods there is a trade-off between exploration and exploitation. For the same “budget,” we can either explore more areas of the solution space in the hope of finding improvement; or we can exploit the local optima (goodness) we have found and try and make it better. GAs can strike a good balance in the trade-off between exploration and exploitation.
Crossover and Mutation
This is how new members (candidate solutions) are added to the population. The process mimics that of reproduction: two members are selected, exchange genetic material, and create a new child solution. We then apply a mutate operator to the new child. The crossover is how we exploit the search space, while mutation drives the exploration.
In the case of the basic binary string encoding, the crossover operator just randomly chooses the value of each chromosome from one of the two parents.
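from random import uniform

Candidate = list[bool]

def crossover(p1: Candidate, p2: Candidate, threshold: float = 0.5) -> Candidate:
    c = p1[:]  # Copy, so the parent is not modified
    for i in range(len(p1)):
        # Take this gene from the second parent with probability `threshold`
        if uniform(0, 1) < threshold:
            c[i] = p2[i]
    return c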
The mutate operator, meanwhile, can simply flip bits at random.
def mutate(s: Candidate, mutate_rate: float) -> Candidate:
    for i in range(len(s)):
        # Flip each bit with probability `mutate_rate`
        if uniform(0, 1) < mutate_rate:
            s[i] = not s[i]
    return s
Having taught undergrads GAs for a few years now, it’s always amusing how invested some of them get in the parent-child analogy. On more than one occasion I have seen code along the lines of crossover(mum, dad).
Parent Selection
This is where the selection part of the “natural selection” bit happens. In a typical GA, each candidate is given a selection probability proportional to its fitness. That is to say, the fitter an individual, the more likely it is to be selected as a parent. Note that the two parents could be the same.
However, there are alternative methods. For example, in each generation only the best \(n\%\) of solutions in the population can have children, and parents are then sampled uniformly from that group. In fact, there are a lot of variations that can be made to the GA. You can play with the parent selection criteria, scale the gene selection based on fitness, have the mutation rate drop with each generation, etc.
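As an illustration of fitness-proportional selection, here is a minimal sketch. It assumes lower cost is better, so costs are first turned into positive weights; that conversion and the name `select_parents` are illustrative choices, not a standard.

```python
import random


def select_parents(population, costs, rng=random):
    """Pick two parents with probability proportional to fitness.

    Assumes lower cost is better, so each cost is converted into a
    positive weight first (one possible conversion among many).
    """
    worst = max(costs)
    # Small offset so that the worst candidate keeps a non-zero chance.
    weights = [worst - c + 1e-9 for c in costs]
    # Sampling with replacement: the two parents may be the same.
    p1, p2 = rng.choices(population, weights=weights, k=2)
    return p1, p2


population = ["a", "b", "c"]
costs = [0.0, 5.0, 5.0]  # "a" is by far the fittest
print(select_parents(population, costs))
```

Because the sampling is done with replacement, nothing stops the same individual from being picked as both parents.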
Below is a basic implementation that just takes the best half of the population and randomly samples from that to create a new generation. It also keeps the best half of the population from the previous generation.
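import numpy as np
from random import randrange

population = init_pop()  # a list of Candidates
pop_size = len(population)
mutate_rate = 0.3
for n in range(NUM_GENERATIONS):
    fitness = np.array([fit_fn(p) for p in population])
    # Keep the best half of the population (lower fitness is better)
    best = fitness.argsort()[: pop_size // 2]
    candidates = [population[i] for i in best]
    new_sols = []
    for _ in range(pop_size // 2):
        # Sample two parents uniformly from the surviving half
        p1, p2 = randrange(pop_size // 2), randrange(pop_size // 2)
        child = mutate(crossover(candidates[p1], candidates[p2]), mutate_rate)
        new_sols.append(child)
    population = candidates + new_sols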
SIMD
Have you ever wondered why deep learning runs on GPUs? It’s all thanks to SIMD. SIMD stands for “single instruction, multiple data” and is one of the four types of parallel processing. SIMD is distinct from the typical way we think about parallel code, i.e. a multi-processor architecture where each core runs its own set of instructions on its own data: “multiple instruction, multiple data” (MIMD). In the case of SIMD we want to apply the same operation to lots of data. This is normally done through vectorisation of the program: using specific instructions that leverage specialised hardware, we can apply the same instruction to a vector (an array) of data.
Imagine we have a vector, and we want to add 1 to every element. In a classic process we would have to loop over every element e.g:
for i in range(len(vector)):
    vector[i] += 1
In this case the processor would update each element in the vector one after the other. However, using jax.vmap we can vectorise this operation:
jax.vmap(lambda x: x+1)(vector)
With SIMD, rather than the operation taking len(vector) cycles, the correct instructions and hardware can, in the ideal case where the whole vector fits in the hardware’s vector width, do it in a single cycle.
Modern CPUs have support for vector operations with their AVX instruction set, but GPUs are much better.
And thanks to libraries like JAX it’s, at times, trivially easy to convert code to SIMD.
How does this relate to deep learning? Well, all those fancy models are mostly just made up of a lot of operations that look something like \(X \cdot W + b\) over very big matrices. This is something that GPUs can parallelise very well.
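To tie the two ideas together, here is a small sketch that writes \(x \cdot W + b\) for a single input vector and lets jax.vmap apply it across a whole batch. The function name `affine` and the shapes are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp


def affine(x, W, b):
    """Affine layer for ONE input vector x of shape (d_in,)."""
    return x @ W + b


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (8, 4))  # a batch of 8 inputs
W = jax.random.normal(key_w, (4, 3))
b = jnp.zeros(3)

# Map over the leading axis of X only; W and b are shared by every row.
batched = jax.vmap(affine, in_axes=(0, None, None))(X, W, b)

print(batched.shape)  # (8, 3)
```

The result should match the usual batched expression `X @ W + b` up to floating point error; vmap just saves us from writing the batch dimension by hand.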
Speeding up GAs
So what does all this have to do with GAs?
In short, we can use JAX to offload the work to the GPU, and use jax.vmap to vectorise it so that we maximise the GPU’s usage.
My first step was to jax.jit compile the GA so that it ran on the GPU.
Doing so gave only a minimal improvement in performance.
After running some performance analysis, I could see that overall GPU usage was low.
Based on Will Whitney’s blog post and my previous experience training multiple models at once, the obvious solution was to use jax.vmap.
Looking at the code for the GA above, there are 2 key places where we can apply vectorisation: the call to the fitness function and the crossover-mutate loop.
Applying vmap to the fitness function was an easy update: fitness = jax.vmap(fit_fn)(population).
For the crossover and mutate, it’s a little more involved because of how JAX handles random numbers.
I wrote a new function that does both crossover and mutate and that can be vectorised:
def mix_and_mutate(mix_key: PRNGKeyArray, mutate_key: PRNGKeyArray, twist: Array, parent):
    # Some jax tricks to deal with randomness
    mix_key = jax.random.fold_in(mix_key, twist)
    mutate_key = jax.random.fold_in(mutate_key, twist)
    # This is the mix and mutate step
    child = mix(mix_key, parent[0], parent[1])
    child = mutate(mutate_key, child, rate=self.mutate_rate)
    return child
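The mix and mutate helpers are not shown here, so the following is only a guess at what JAX versions of them could look like for a boolean encoding, not the code from the repo. It also shows why the fold_in trick matters: folding a different integer into a shared key gives each child its own random stream.

```python
import jax
import jax.numpy as jnp


def mix(key, p1, p2, threshold=0.5):
    # One Bernoulli draw per gene decides which parent it comes from.
    take_p2 = jax.random.bernoulli(key, threshold, shape=p1.shape)
    return jnp.where(take_p2, p2, p1)


def mutate(key, child, rate):
    # Flip each bit independently with probability `rate`.
    flip = jax.random.bernoulli(key, rate, shape=child.shape)
    return jnp.logical_xor(child, flip)


mix_key, mutate_key = jax.random.split(jax.random.PRNGKey(0))
p1 = jnp.zeros(16, dtype=bool)
p2 = jnp.ones(16, dtype=bool)

# Folding in a different "twist" derives a distinct key for each child.
k0 = jax.random.fold_in(mix_key, 0)
k1 = jax.random.fold_in(mix_key, 1)
child0 = mutate(jax.random.fold_in(mutate_key, 0), mix(k0, p1, p2), rate=0.1)
child1 = mutate(jax.random.fold_in(mutate_key, 1), mix(k1, p1, p2), rate=0.1)
print(child0.shape, child1.shape)  # (16,) (16,)
```

Without the twist, every child in the vmapped batch would receive the same key and therefore make identical random choices.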
Putting it all together, the main loop of the GA now looks something like:
for n in range(NUM_GENERATIONS):
    # NOTE: in practice, derive fresh keys each generation (e.g. with jax.random.split),
    # otherwise every generation reuses the same random draws
    # Parallel call to fitness function
    fitness = jax.vmap(fit_fn)(population)
    # "Select" and Sample
    candidates = population[fitness.argsort()[:pop_size // 2]]
    parents = jax.random.choice(key, candidates, replace=True, shape=(pop_size // 2, 2))
    # Parallel call to mix and mutate
    twists = jnp.arange(pop_size // 2, dtype=int)  # This is needed for randomness
    new_pop = jax.vmap(mix_and_mutate, in_axes=(None, None, 0, 0))(mix_key, mutate_key, twists, parents)
    # Make the new generation
    population = jnp.concatenate([candidates, new_pop])
With these changes, the GA runs on GPU and makes good use of it.
There is still a bit of overhead with the jax.jit compilation.
This is in part because JAX will unroll the outer generation loop.
A way to get around this is to use the jax.lax.scan function, which compiles the loop into a single XLA operation.
scan works similarly to Haskell’s fold or Python’s reduce.
Per the docs, its Haskell type signature would be (c -> a -> (c, b)) -> c -> [a] -> (c, [b]).
In English, scan takes as input:
- A function that takes as input a carry object, c, and an a, and returns a new carry and a b.
- A carry object
- A list of as
And returns a tuple of the final carry and a list of bs.
scan applies the function it takes as input sequentially to every element in the list of as to make a list of bs; each time it calls the function, it also passes in the carry object returned by the previous call.
This lets the function maintain some state, so that it acts like it’s looping over the list.
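The semantics can be written out in plain Python. This is only a reference sketch of the behaviour (the real jax.lax.scan works on arrays and is compiled); `running_sum` is a made-up step function.

```python
def scan(f, init, xs):
    """Pure-Python reference for what scan computes (no JAX involved)."""
    carry = init
    ys = []
    for x in xs:
        # f takes the current carry and one element, and returns
        # the new carry plus one output for this step.
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def running_sum(total, x):
    total = total + x
    return total, total  # (new carry, output for this step)


final, partials = scan(running_sum, 0, [1, 2, 3, 4])
print(final)     # 10
print(partials)  # [1, 3, 6, 10]
```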
To use it in our GA, we can update the loop to:
def run_generation(population, _):
    # The same loop body for the GA
    ...
    # scan expects a (carry, output) pair; there is no per-step output here
    return new_population, None

population, _ = jax.lax.scan(run_generation, init=population, length=self.generations, xs=None)
Since we are using scan to replace the loop, we don’t care about the per-step outputs (the bs), so we just ignore them.
The return value (in the carry object) will be the final population, which should contain the best solution found, i.e. the fittest.
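For a version that actually runs end to end, here is a minimal toy GA built on jax.lax.scan. Everything specific in it is an assumption for the sake of the example: the constants, the toy cost (number of zero bits), the helper `make_child`, and threading the PRNG key through the carry alongside the population. It is a sketch of the approach, not the code from the repo.

```python
import jax
import jax.numpy as jnp

POP_SIZE, GENOME_LEN, GENERATIONS, MUTATE_RATE = 64, 32, 50, 0.02


def fit_fn(candidate):
    # Toy cost: number of zero bits (lower is better).
    return jnp.sum(~candidate)


def make_child(key, parents):
    # parents has shape (2, GENOME_LEN): uniform crossover, then bit flips.
    mix_key, mutate_key = jax.random.split(key)
    take_p2 = jax.random.bernoulli(mix_key, 0.5, shape=parents[0].shape)
    child = jnp.where(take_p2, parents[1], parents[0])
    flip = jax.random.bernoulli(mutate_key, MUTATE_RATE, shape=child.shape)
    return jnp.logical_xor(child, flip)


def run_generation(carry, _):
    population, key = carry
    # Fresh keys every generation, so the random draws differ each time.
    key, pick_key, child_key = jax.random.split(key, 3)
    fitness = jax.vmap(fit_fn)(population)
    # Keep the best half, then sample parent pairs from it.
    candidates = population[jnp.argsort(fitness)[: POP_SIZE // 2]]
    parents = jax.random.choice(pick_key, candidates, shape=(POP_SIZE // 2, 2))
    child_keys = jax.random.split(child_key, POP_SIZE // 2)
    children = jax.vmap(make_child)(child_keys, parents)
    population = jnp.concatenate([candidates, children])
    # Carry: (population, key). Per-step output: best cost seen this generation.
    return (population, key), jnp.min(fitness)


@jax.jit
def run_ga(key):
    init_key, loop_key = jax.random.split(key)
    population = jax.random.bernoulli(init_key, 0.5, (POP_SIZE, GENOME_LEN))
    (population, _), best_per_gen = jax.lax.scan(
        run_generation, (population, loop_key), xs=None, length=GENERATIONS
    )
    return population, best_per_gen


population, best_per_gen = run_ga(jax.random.PRNGKey(0))
print(population.shape)  # (64, 32)
print(best_per_gen[0], best_per_gen[-1])
```

Here the per-step output is used to record the best cost in each generation, which is handy for plotting progress. Because the best half always survives, that value can never get worse from one generation to the next.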
The Caveats
As I alluded to, there are a few limitations to this approach:
- The objective function must be jax.jit compilable. This is the biggest limitation. JAX is very expressive, and anything that can be written using NumPy should work fine, but I can see there being issues with existing code / complex simulations.
- Because of limitations with how jax.vmap and jax.lax.scan use memory, there are limits on the population size/number of generations. However, with some changes, it should be possible to get around these issues.
For my purposes, I was able to write a fitness function that works with JAX.
In fact, for my implementation, the fitness function is also conditioned on some input values.
Using another layer of vmaps, I was able to get it to run 10,000 samples, for 1000 generations with a population of 256, in about 5 mins.
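To give a feel for what “another layer of vmaps” means, here is a small sketch with a hypothetical conditioned fitness function (the number of bits that differ from a per-sample target). The real fitness function and its inputs are different; only the nesting pattern is the point.

```python
import jax
import jax.numpy as jnp


def fit_fn(candidate, target):
    # Hypothetical conditioned cost: bits that differ from `target`.
    return jnp.sum(candidate != target)


# Inner vmap: every candidate in one population, all sharing one target.
pop_fitness = jax.vmap(fit_fn, in_axes=(0, None))
# Outer vmap: one population per sample, each with its own target.
all_fitness = jax.vmap(pop_fitness, in_axes=(0, 0))

key_p, key_t = jax.random.split(jax.random.PRNGKey(0))
populations = jax.random.bernoulli(key_p, 0.5, (5, 8, 16))  # (samples, pop, genes)
targets = jax.random.bernoulli(key_t, 0.5, (5, 16))         # (samples, genes)

print(all_fitness(populations, targets).shape)  # (5, 8)
```

The same nesting can be applied to the whole GA, so that many independent runs, one per input sample, execute as a single batched computation.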