I have started using Genetic Algorithms (GAs) in my PhD research. They are a very effective and simple tool for solving optimisation problems. Using a “natural selection”-like process to explore the solution space, they can quickly arrive at a good candidate solution. The nice thing about GAs is that they are very easy to implement. However, a basic implementation can be a bit slow.

In deep learning (DL), we make heavy use of GPUs (and in some cases TPUs) to speed up the computation. A lot of work has been done on DL with JAX. If you don’t know, JAX is an autograd and XLA library for Python. It allows you to write Python code with the numpy API that can be compiled, vectorised, and run on GPUs. I was inspired by a blog post from Will Whitney outlining a cool way to use jax.vmap to train multiple models at once, and asked myself: could I use the same technique to speed up my GA and run it on the GPU? It turns out, with a few minor caveats, you can! Check out the code on my github.

  • What is a GA?
    • A quick overview of what a GA is and how it works
  • SIMD
  • Speeding up GAs
  • The Caveats

What is a GA?

A quick overview of what a GA is and how it works.

The GA is inspired by the biological process of evolution through natural selection, where fitter individuals are more likely to pass on their genetic material. There is an initial population of \(N\) candidates. During each iteration of the GA, a new population is created by combining and mutating the existing members based on their fitness. An iteration of the algorithm is analogous to a generation.

Each candidate has a genetic representation, a set of “chromosomes” that make up a solution. A basic encoding is simply a binary string comprised of 1s and 0s for each solution variable. We also need a fitness function that can take a solution representation and evaluate how good it is. It returns a scalar value (a single number) representing how good the solution is; lower is normally better.
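As a toy example (not from my actual research), a “one-max”-style fitness function over a binary encoding might look like:

Candidate = list[bool]

# Toy fitness function: count the bits that are still 0. Lower is better,
# so the all-ones string is the optimal solution.
def fit_fn(s: Candidate) -> int:
    return sum(1 for bit in s if not bit)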

With search-based optimisation methods there is a trade-off between exploration and exploitation. For the same “budget,” we can either explore more areas of the solution space in the hope of finding improvement, or we can exploit the local optimum (goodness) we have found and try to make it better. GAs can strike a good balance in the trade-off between exploration and exploitation.

Crossover and Mutation

This is how new members (candidate solutions) are added to the population. The process mimics reproduction: two members are selected, exchange genetic material, and create a new child solution. We then apply a mutation operator to the new child. Crossover is how we exploit the search space, while mutation drives the exploration.

In the case of the basic binary string encoding, the crossover operator just randomly chooses the value of each chromosome from one of the two parents.

from random import uniform

Candidate = list[bool]

def crossover(p1: Candidate, p2: Candidate, threshold: float = 0.5) -> Candidate:
    c = p1[:]  # Copy the first parent
    for i in range(len(p1)):
        # Take each chromosome from the second parent with probability `threshold`
        if uniform(0, 1) < threshold:
            c[i] = p2[i]
    return c

While the mutate operator can simply randomly flip bits.

def mutate(s: Candidate, mutate_rate: float) -> Candidate:
    for i in range(len(s)):
        # Flip each bit independently with probability `mutate_rate`
        if uniform(0, 1) < mutate_rate:
            s[i] = not s[i]
    return s

Having taught undergrads GAs for a few years now, it’s always amusing how invested some of them get in the parent-child analogy; on more than one occasion I have seen code along the lines of crossover(mum, dad).

Parent Selection

This is where the selection part of “natural selection” happens. In a typical GA, parent selection is done by giving each candidate a selection probability proportional to its fitness. That is to say, the fitter an individual, the more likely it is to be selected as a parent. Note that the two parents could be the same.
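My implementation below doesn’t use this, but a minimal sketch of fitness-proportional (“roulette wheel”) selection, assuming the “lower is better” fitness convention from above, might look like:

import numpy as np

def select_parents(population, fitness):
    # Convert "lower is better" fitness into positive selection weights
    # (the small epsilon avoids a zero sum when all fitnesses are equal)
    weights = np.max(fitness) - np.asarray(fitness) + 1e-9
    probs = weights / weights.sum()
    # Sample two parents (possibly the same one) proportional to weight
    i, j = np.random.choice(len(population), size=2, p=probs)
    return population[i], population[j]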

However, there are alternative methods; for example, in each generation only the best \(n\%\) of solutions in the population can have children, and parents are then sampled uniformly from them. In fact, there are a lot of variations that can be made to the GA. You can play with the parent selection criteria, scale the gene selection based on fitness, have the mutation rate drop with each generation, etc.

Below is a basic implementation that just takes the best half of the population and randomly samples from it to create a new generation. It also keeps that best half of the population from the previous generation.

import numpy as np
from random import randrange

population = init_pop()
pop_size = len(population)
mutate_rate = 0.3

for n in range(NUM_GENERATIONS):
    fitness = np.array([fit_fn(p) for p in population])
    # Keep the best half of the population (lower fitness is better)
    candidates = [population[i] for i in fitness.argsort()[:pop_size // 2]]

    new_sols = []
    for _ in range(pop_size // 2):
        # Pick two parents uniformly at random from the candidates
        p1, p2 = randrange(pop_size // 2), randrange(pop_size // 2)
        child = mutate(crossover(candidates[p1], candidates[p2]), mutate_rate)
        new_sols.append(child)

    population = candidates + new_sols

SIMD

Have you ever wondered why deep learning runs on GPUs? It’s all thanks to SIMD. SIMD stands for “single instruction, multiple data” and is one of the four types of parallel processing. SIMD is distinct from the typical way we think about parallel code, i.e. a multi-processor architecture where each core runs its own set of instructions on its own data, “multiple instruction, multiple data” (MIMD). In the case of SIMD we want to apply the same operation to lots of data. This is normally done through vectorisation of the program: using specific instructions that leverage specialised hardware, we can apply the same instruction to a vector (an array) of data.

Imagine we have a vector, and we want to add 1 to every element. In a classic process we would have to loop over every element e.g:

for i in range(len(vector)):
    vector[i] += 1

In this case the processor would update each element in the vector one after the other. However using jax.vmap we can vectorise this operation:

jax.vmap(lambda x: x+1)(vector)

By using SIMD, rather than the operation taking len(vector) cycles, with the correct instructions and hardware it can be done in a single cycle. Modern CPUs have support for vector operations with their AVX instruction set, but GPUs are much better. And thanks to libraries like JAX it’s, at times, trivially easy to convert code to SIMD.

How does this relate to deep learning? Well, all those fancy models are mostly just made up of a lot of operations that look something like \(X \cdot W + b\) over very big matrices. This is something that GPUs can parallelise very well.
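To make that concrete, here’s a small illustrative sketch (the shapes are made up) of vectorising a single dense-layer step over a whole batch with jax.vmap:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (512, 256))  # Weight matrix
b = jnp.zeros(256)                      # Bias vector

# Map a single-example dense layer over a batch of 1024 inputs at once
batch = jax.random.normal(key, (1024, 512))
out = jax.vmap(lambda x: x @ W + b)(batch)  # Shape (1024, 256)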

Speeding up GAs

So, what does all this have to do with GAs? In short, we can use JAX to offload the work to the GPU, and the jax.vmap SIMD framework to maximise the GPU’s usage.

My first step was to jax.jit compile the GA so that it ran on the GPU. Doing so had a minimal impact on performance. After running some performance analysis, I could see that overall GPU usage was low. Based on Will Whitney’s blog post and my previous experience training multiple models at once, the solution was obviously to use jax.vmap. Looking at the code for the GA above, there are two key places where we can apply vectorisation: the call to the fitness function, and the crossover-mutate loop. Applying vmap to the fitness function was an easy update: fitness = jax.vmap(fit_fn)(population). For the crossover and mutate it’s a little more involved because of how JAX handles random numbers. I wrote a new function that does both crossover and mutate and can be vectorised:

from jaxtyping import Array, PRNGKeyArray

def mix_and_mutate(mix_key: PRNGKeyArray, mutate_key: PRNGKeyArray, twist: Array, parent: Array):
    # Some JAX tricks to deal with randomness: fold the per-pair "twist"
    # index into the keys so each child gets its own random stream
    mix_key = jax.random.fold_in(mix_key, twist)
    mutate_key = jax.random.fold_in(mutate_key, twist)
    # This is the mix (crossover) and mutate step
    child = mix(mix_key, parent[0], parent[1])
    child = mutate(mutate_key, child, rate=mutate_rate)
    return child
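The JAX versions of mix and mutate aren’t shown above; a minimal sketch, assuming candidates are boolean jnp arrays, could be:

import jax
import jax.numpy as jnp

def mix(key, p1, p2, threshold=0.5):
    # For each gene, take the value from p2 where the draw is below the threshold
    mask = jax.random.uniform(key, shape=p1.shape) < threshold
    return jnp.where(mask, p2, p1)

def mutate(key, s, rate):
    # Flip each bit independently with probability `rate`
    flips = jax.random.uniform(key, shape=s.shape) < rate
    return jnp.logical_xor(s, flips)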

Putting it all together, the main loop of the GA now looks something like:

for n in range(NUM_GENERATIONS):
    # Parallel call to fitness function
    fitness = jax.vmap(fit_fn)(population)

    # "Select" and sample: keep the best half, then draw random parent pairs
    candidates = population[fitness.argsort()[:pop_size // 2]]
    parents = jax.random.choice(key, candidates, replace=True, shape=(pop_size // 2, 2))

    # Parallel call to mix and mutate
    twists = jnp.arange(pop_size // 2, dtype=int)  # This is needed for randomness
    new_pop = jax.vmap(mix_and_mutate, in_axes=(None, None, 0, 0))(mix_key, mutate_key, twists, parents)

    # Make the new generation
    population = jnp.concatenate([candidates, new_pop])

With these changes, the GA runs on the GPU and makes good use of it. There is still a bit of overhead from the jax.jit compilation. This is in part because JAX will unroll the outer generation loop. A way to get around this is to use the jax.lax.scan function, which produces a single XLA step. Scan functions similarly to Haskell’s fold or Python’s reduce operator. Per the docs, its Haskell type signature would be (c -> a -> (c, b)) -> c -> [a] -> (c, [b]). In English, scan takes as input:

  • A function that takes as input a carry object, c, and an a, and returns a new carry and a b.
  • A carry object
  • A list of as

And returns a tuple of the final carry and a list of bs.

scan applies the function it takes as input sequentially to every element in the list of as to make a list of bs; each time it calls the function it also passes in the carry object returned by the previous call. This allows the function to maintain some state so that it acts like it’s looping over the list. To use it in our GA, we can update it to:

def run_generation(population, _):
    # The same loop body as the GA above
    ...
    return new_population, None  # (new carry, per-step output)

population, _ = jax.lax.scan(run_generation, init=population, xs=None, length=NUM_GENERATIONS)

Since we are using scan to replace the loop, we don’t care about the actual outputs (the bs), so we just ignore them. The return value (in the carry object) will be the final population, which should contain the best solution found, i.e. the fittest.
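To make the fold semantics concrete, here’s a toy standalone scan (not part of the GA) that computes a running sum:

import jax
import jax.numpy as jnp

def step(total, x):
    # `total` is the carry (c), `x` is an element of xs (an a),
    # and we also emit the running total as this step's output (a b)
    total = total + x
    return total, total

final, running = jax.lax.scan(step, init=0.0, xs=jnp.arange(5.0))
# final == 10.0, running == [0., 1., 3., 6., 10.]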

The Caveats

As I alluded to, there are a few limitations to this approach:

  1. The objective function must be jax.jit compilable. This is the biggest limitation. While JAX is very expressive, and anything that can be written using numpy should work fine, I can see there being issues with existing code / complex simulations (see the sketch below).
  2. Because of limitations with how jax.vmap and jax.lax.scan use memory, there are limits on the population size/number of generations. However, with some changes, it should be possible to get around these issues.
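As a hypothetical illustration of the first caveat, a fitness function with Python-level control flow on traced values won’t compile and needs rewriting with JAX primitives:

import jax
import jax.numpy as jnp

@jax.jit
def bad_fit_fn(s):
    # Branching on a traced value raises a tracer/concretisation error under jit
    if s.sum() > 3:
        return 0.0
    return 1.0

@jax.jit
def good_fit_fn(s):
    # The jit-friendly version expresses the branch with jnp.where
    return jnp.where(s.sum() > 3, 0.0, 1.0)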

For my purposes, I was able to write a fitness function that works with JAX. In fact, for my implementation, the fitness function is also conditioned on some input values. Using another layer of vmaps, I was able to get it to run 10,000 samples, for 1,000 generations with a population of 256, in about 5 mins.
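I haven’t shown that code here, but the nested-vmap idea is roughly as follows (assuming, hypothetically, that fit_fn takes a candidate and one input sample, and that scores are aggregated by a mean):

# Outer vmap runs over candidates, inner vmap over the input samples
fitness = jax.vmap(
    lambda p: jax.vmap(lambda x: fit_fn(p, x))(samples).mean()
)(population)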