How to build a GAN in Python
Learn how to easily build a working Generative Adversarial Network (GAN) in Python, using machine learning to let an AI “create” realistic content!
Table Of Contents
- Introduction
- What is a Generative Adversarial Network?
- A home made GAN
- An artificial dataset
- Our GAN in small pieces
- How to train it?
- Conclusions
Introduction
Generative Adversarial Networks (GANs) are a hot topic in machine learning, for several good reasons. Here are just three of them:
- They can provide astonishing results in creating new things (images, texts, sounds, etc.), succeeding in imitating the samples they were previously exposed to.
- They provide a new paradigm of machine learning, the generative one, which combines pre-existing techniques to produce brand new ideas and results.
- They are a recent (2014) creation of Ian Goodfellow, the former Google and now Apple researcher (who also co-authored a standard reference on deep learning, along with Yoshua Bengio and Aaron Courville).
It is quite probable that you have already seen some of the impressive results of GANs, especially in the realm of image processing. Such networks are able, upon request, to draw a picture of a red flower, a black bird or even a violet cat. Moreover, that flower, that bird and that cat do not exist at all: they are somehow “imagined” by the network.
How is this possible? And, can we share the fun? I will try to address both questions in this article, with some working Python code which can be run on your laptop. Of course, you’ll need to install a few packages if they are still missing from your Python installation, but that’s what pip is for…
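For instance, a single command along the lines of pip install numpy matplotlib tensorflow keras (the exact package names depend on your setup: in recent distributions Keras ships together with TensorFlow) should install everything used below.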
What is a Generative Adversarial Network?
Neural networks (NNs) were devised as prediction and classification models. They are powerful non-linear models whose inner parameters (neuron weights) are trained to fit the training data. This makes the NN able to predict and classify unseen data of the same kind.
We all know how impressive neural network approximations of data may be, where “data” can mean almost anything. However, the very features of such algorithms also imply some drawbacks, such as:
- Neural networks need labelled data to be trained properly
- Worse, they need a lot of labelled data
- Worse still, we have no idea about the meaning of the neurons’ contents, except in some special cases
In particular, neural networks are intrinsically supervised algorithms. Nonetheless, some of their variants work fine as unsupervised algorithms: they can be trained on any kind of data, with no need for the “labels” we usually attach to known examples so that the net can later discriminate unknown ones.
I have already written about some examples of unsupervised networks, for instance when dealing with time series. The trick there is that any time series may be considered as a labelled training set by looking at one of its points as the value to predict, while the remaining ones are the input data (see this article for more details).
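As a minimal sketch of that trick (the toy series and the window length below are purely illustrative, not taken from that article), each window of consecutive points becomes an input vector and the point that follows it becomes the label:
import numpy as np
# A toy time series and an illustrative window length
series = np.sin(np.linspace(0, 10, 200))
WINDOW = 8
# Each row of X is a window of consecutive values; y is the value that follows it
X = np.array([series[i:i + WINDOW] for i in range(len(series) - WINDOW)])
y = np.array([series[i + WINDOW] for i in range(len(series) - WINDOW)])
print(X.shape, y.shape) # (192, 8) (192,)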
Another interesting unsupervised setting for neural networks to play is exactly the GAN paradigm, which I’ll briefly recall hereinafter.
Let us begin with the words the acronym GAN stands for: generative, adversarial, networks. The last one is the most obvious: GANs are built out of neural networks, usually deep neural networks. So, we’ll have an input layer with a certain number of parallel input neurons (one for each number representing the input data), some hidden layers and an output layer, connected in a directed graph and trained by a variant of the gradient-descent backpropagation algorithm.
Next, we come to the word generative, which denotes the aim of this class of algorithms. They produce data rather than consuming them. More precisely, the data they produce contain new information of the same “class” as the input data used to generate them. Indeed, the generation process is not spontaneous: data are generated from other data, via a mechanism that I’ll describe in a moment.
Finally, the word adversarial, the most mysterious one, explains how the generation is done: by a competition between two adversaries. In the GAN case, the adversaries are neural networks.
Therefore, a GAN aims at generating new data via networks which are put into competition to do that. More specifically, a GAN is always split into two components, which are two (usually deep) neural networks. The first one is the discriminator, and it is trained to discriminate a set of data from pure noise. For example, we could have a collection of flower photos, and a huge amount of other images which have nothing to do with flowers. We do not have explicit labels for each photo; we just know which collection contains the flower photos, and which contains the rest.
Then we could train a network which discriminates flowers from non-flowers or, for that matter, a network which discriminates photos from pictures made of random pixels. This first component of the GAN is a standard network trained to classify things. The input of the discriminator is an example of the data we want to generate (a flower photo, if we want to generate flowers), while the output is a yes/no flag.
The other network is the generator: this one produces as output the kind of data the discriminator is able to discriminate. To do that, the generator uses a random input. At first it will produce random output, but it is trained by backpropagating the information on whether its output is similar or not to the data we want to produce (e.g. flower photos).
To do that, we feed its predictions to the discriminator. The latter is trained to recognize genuine flowers (or whatever), and if the generator can counterfeit a flower well enough to trick the discriminator, then our GAN can produce fake flower photos that a well-trained observer (the discriminator) takes for genuine ones.
At last, we have our generating task accomplished.
Thus, a GAN is like a room where a forger and an art critic meet: the former submits his fake paintings claiming they are authentic; the latter tries to find out whether they actually are. If the forger is so good at counterfeiting that the critic mistakes the fakes for genuine works, then we can sell the fakes at an auction, hoping that someone will buy them…
Notice that, at first glance, GANs seem analogous to reinforcement learning, but the analogy is only apparent. In the former case two networks are put into competition to sharpen their opposite skills and obtain fake data which looks genuine, while in the latter case a single agent is tested against an environment and somehow rewarded or punished to correct its behaviour. There is no actual competition there but, so to speak, a survival pattern to be discovered.
Rather, one may think of GANs as a generalization of the Turing test principle: the discriminator is the tester and the generator is the machine trying to pass the test. The difference is that in this case both are machines (we already explained why Turing’s ideas were seminal for machine learning).
A home made GAN
Usually GANs find their most spectacular applications in counterfeiting images, as we have said. However, even videos, texts and sounds may be generated, although some technical issues make it complicated to implement such “time series generators”.
Most tutorials show the classical image generation task, typically using the MNIST dataset to teach the GAN how to write digits. But convolutional networks are required for that, and much of the effort goes not into the GAN part but into setting up the convolutional and “deconvolutional” networks which implement the discriminator and the generator. In addition, training them takes quite a long time without appropriate equipment (a description of such GANs can be found in another contribution to Codemotion magazine).
Being lazier and less ambitious, here I’ll show a simple GAN programmed in Python using the Keras library, which can be run on any laptop and taught how to draw a specific class of curves. I’ve chosen sinusoids, but we could use any other pattern.
Thus, in the following I will:
- Generate a dataset of sinusoids.
- Set up the discriminator and generator networks.
- Use them to build up the GAN.
- Train the GAN, showing how to combine the training of its components.
- Contemplate a somewhat skewed and distorted sinusoid drawn by the program from pure noise.
Let’s start.
An artificial dataset
Instead of picking a bunch of images, I’ll produce a description of the curves I am interested in: sinusoids may be mathematically described as graphs of the functions
a sin(bx + c)
where a, b, c are parameters which determine the amplitude, frequency and phase of the curve. Some examples of such curves are plotted in the following picture, which of course we produce via a Python snippet.
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import randint, uniform
SAMPLE_LEN = 64 # number of points where each curve is sampled
X_MIN = -5.0
X_MAX = 5.0
X_COORDS = np.linspace(X_MIN, X_MAX, SAMPLE_LEN)
# Plot four random curves a*sin(b*x + c) with random amplitude, frequency and phase
fig, axis = plt.subplots(1, 1)
for i in range(4):
    axis.plot(X_COORDS, uniform(0.1, 2.0)*np.sin(uniform(0.2, 2.0)*X_COORDS + uniform(0.0, 2.0)))
plt.show()
We want our GAN to generate curves of this form. To keep things simple we just fix a = 1 and let b ∈ [1/2, 2] and c ∈ [0, π].
First of all, we define some constants and produce a dataset of such curves. To describe a curve, we do not use its symbolic form in terms of the sine function; rather, we choose some points on the curve, sampled over the same x values, and represent the curve y = f(x) by the vector (y1, …, yN), where yi = f(xi) for the fixed xi.
So we generate those y values by using the previous formula for random values of b and c within the prescribed intervals. After defining the training set, some of these curves are plotted.
import numpy as np
from numpy.random import uniform
import matplotlib.pyplot as plt
SAMPLE_LEN = 64 # number N of points where a curve is sampled
SAMPLE_SIZE = 32768 # number of curves in the training set
X_MIN = -5.0 # smallest x value where to sample
X_MAX = 5.0 # largest x value where to sample
# The set of x coordinates over which curves are sampled
X_COORDS = np.linspace(X_MIN, X_MAX, SAMPLE_LEN)
# The training set: one sampled sinusoid sin(b*x + c) per row
SAMPLE = np.zeros((SAMPLE_SIZE, SAMPLE_LEN))
for i in range(0, SAMPLE_SIZE):
    b = uniform(0.5, 2.0)   # frequency in [1/2, 2]
    c = uniform(0.0, np.pi) # phase in [0, pi]
    SAMPLE[i] = np.array([np.sin(b*x + c) for x in X_COORDS])
# We plot the first 8 curves
fig, axis = plt.subplots(1, 1)
for i in range(8):
    axis.plot(X_COORDS, SAMPLE[i])
plt.show()
Our GAN in small pieces
Now we define our discriminator, which is the neural network used to distinguish a sinusoidal curve from any other set of sampled points. Thus it accepts an input vector (y1, …, yN) and returns 1 if it corresponds to a sinusoidal curve, otherwise 0.
I use the Keras library to create a Sequential object in which to stack the different layers of the network. We arrange this discriminator as a simple shallow multilayer perceptron, with three layers: the input one with N neurons, N being the size of the input vectors, the second with the same number of hidden neurons and the third with just one neuron, the output one.
The input and hidden layers have their outputs filtered by a “relu” activation (which cuts the negative values of its argument to zero) and by a “dropout” layer (which randomly sets units to 0 with the prescribed frequency at each training step, to prevent overfitting).
The output neuron is activated via a sigmoid function, which smoothly ranges from 0 to 1, the two possible answers.
from keras.models import Sequential
from keras.layers import Dense, Dropout, LeakyReLU
DROPOUT = Dropout(0.4) # Empirical hyperparameter
# The discriminator maps a sampled curve (SAMPLE_LEN numbers) to the probability of it being a genuine sinusoid
discriminator = Sequential()
discriminator.add(Dense(SAMPLE_LEN, input_dim = SAMPLE_LEN, activation = "relu"))
discriminator.add(DROPOUT)
discriminator.add(Dense(SAMPLE_LEN, activation = "relu"))
discriminator.add(DROPOUT)
discriminator.add(Dense(1, activation = "sigmoid"))
discriminator.compile(optimizer = "adam", loss = "binary_crossentropy", metrics = ["accuracy"])
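Just to fix ideas (this quick check is not part of the original listing), the still untrained discriminator can already be queried: it returns one probability per input curve, meaningless for now but of the expected shape.
# Not in the original listing: query the untrained discriminator on two genuine curves
print(discriminator.predict(SAMPLE[:2])) # two probabilities in (0, 1), one per curve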
Next we come to the generator network. This is in some sense a mirror image of the discriminator: we still have three layers, where the first one accepts a noisy input of the same size as the output (a vector with N elements). The layers apply a “leaky relu” activation (which maps the negative values of its argument x to a small multiple of x itself) but perform no dropout, and the result is output through a hyperbolic tangent function. Since we are not classifying, we use the “mean square error” loss function instead of “binary cross entropy” when compiling this network.
LEAKY_RELU = LeakyReLU(0.2) # Empirical hyperparameter
# The generator maps a noise vector (SAMPLE_LEN numbers) to a sampled curve (SAMPLE_LEN numbers)
generator = Sequential()
generator.add(Dense(SAMPLE_LEN, input_dim = SAMPLE_LEN))
generator.add(LEAKY_RELU)
generator.add(Dense(512))
generator.add(LEAKY_RELU)
generator.add(Dense(SAMPLE_LEN, activation = "tanh"))
generator.compile(optimizer = "adam", loss = "mse", metrics = ["accuracy"])
Next, we plug the output of the generator into the discriminator as its input, so as to get the whole GAN network ready to be trained.
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
gan.compile(optimizer = "adam", loss = "binary_crossentropy", metrics = ["accuracy"])
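Before training, an optional sanity check on shapes may help (again, not part of the original script): the generator should map a noise vector of N elements to a curve of N points, while the whole GAN should map the same noise to a single probability.
# Optional shape check, not in the original script
noise = uniform(X_MIN, X_MAX, size = (1, SAMPLE_LEN))
print(generator.predict(noise).shape) # expected: (1, SAMPLE_LEN)
print(gan.predict(noise).shape) # expected: (1, 1)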
How to train it?
Now we have our GAN ready to be trained. Instead of launching the fit Keras method on the gan object we just instantiated, let us take a pause and reflect on the concept of a GAN, to understand how to train it properly.
Indeed, we said that the discriminator needs to learn how to distinguish a sinusoid from any other curve, and this could be done by just training it on our SAMPLE dataset plus a noisy dataset, labelling the elements of the former as sinusoids and the elements of the latter as non-sinusoids.
However, the aim of the discriminator is not to learn our dataset but to intercept the fakes produced by the generator. For this reason, we organize the training as follows:
- For each epoch we perform a batch training of both the discriminator and the generator.
- This batch training starts by asking the generator to generate a batch of curves.
- This output is coupled with a batch of sinusoids from the SAMPLE dataset we produced, and a dataset with labels 1 (= genuine sinusoid) and 0 (= sinusoid produced by the generator) is provided to batch-train the discriminator, which is therefore trained to recognize the generated sinusoids among the genuine ones.
- The generator is batch-trained on random data: this training backpropagates along the whole GAN network, but the weights of the discriminator are left untouched.
Thus, the discriminator is not trained to recognize sinusoids in general, but to distinguish between sinusoids in our dataset and sinusoids produced by the generator. The latter, in turn, is trained to produce sinusoids from random data so as to fool the discriminator.
When this fooling succeeds often enough (from the discriminator’s point of view), the GAN is able to generate fake sinusoids. Since we want our code to run without straining our laptops (which I assume have no GPUs, etc.), we choose relatively small parameters to produce our dataset and train the GAN. Therefore, we do not expect the network to draw a smooth sinusoid, but a somewhat flickering line, which should however display the sinusoidal pattern.
To appreciate how the GAN, during its apprenticeship, starts by drawing randomly and gradually improves its skill at drawing a sinusoid, I plot some of its outputs during the training (a plot every 10 epochs, since we stick to just 64 epochs).
BATCH = 256 # batch size (not defined in the original listing: 256 is a plausible value, consistent with the log below)
EPOCHS = 64
NOISE = uniform(X_MIN, X_MAX, size = (SAMPLE_SIZE, SAMPLE_LEN))
ONES = np.ones((SAMPLE_SIZE))
ZEROS = np.zeros((SAMPLE_SIZE))
print("epoch | dis. loss | dis. acc | gen. loss | gen. acc")
print("------+-----------+----------+-----------+----------")
fig = plt.figure(figsize = (8, 12))
ax_index = 1
for e in range(EPOCHS):
    for k in range(SAMPLE_SIZE//BATCH):
        # Train the discriminator to tell genuine sinusoids from those produced by the generator:
        # prepare a batch made of genuine curves (labelled 1) and generated curves (labelled 0)
        n = randint(0, SAMPLE_SIZE, size = BATCH)
        p = generator.predict(NOISE[n])
        x = np.concatenate((SAMPLE[n], p))
        y = np.concatenate((ONES[n], ZEROS[n]))
        d_result = discriminator.train_on_batch(x, y)
        # Train the generator (through the whole GAN) on noise labelled as genuine;
        # the discriminator is flagged as non-trainable during this step
        discriminator.trainable = False
        g_result = gan.train_on_batch(NOISE[n], ONES[n])
        discriminator.trainable = True
    # Note: the original listing printed d_result[1] in the last column,
    # which is why the "gen. acc" column in the log below coincides with "dis. acc"
    print(f" {e:04n} | {d_result[0]:.5f} | {d_result[1]:.5f} | {g_result[0]:.5f} | {g_result[1]:.5f}")
    # At epochs 3, 13, 23, ... plot the last generator prediction
    if e % 10 == 3:
        ax = fig.add_subplot(8, 1, ax_index)
        plt.plot(X_COORDS, p[-1])
        ax.xaxis.set_visible(False)
        plt.ylabel(f"Epoch: {e}")
        ax_index += 1
# Plot a final curve generated by the GAN from pure noise
y = generator.predict(uniform(X_MIN, X_MAX, size = (1, SAMPLE_LEN)))[0]
ax = fig.add_subplot(8, 1, ax_index)
plt.plot(X_COORDS, y)
plt.show()
The output is as follows:
epoch | dis. loss | dis. acc | gen. loss | gen. acc
------+-----------+----------+-----------+----------
0000 | 0.10589 | 0.96484 | 7.93257 | 0.96484
0001 | 0.03285 | 1.00000 | 8.62279 | 1.00000
0002 | 0.01879 | 1.00000 | 9.54678 | 1.00000
0003 | 0.01875 | 1.00000 | 11.18307 | 1.00000
0004 | 0.00816 | 1.00000 | 13.98673 | 1.00000
0005 | 0.01707 | 0.99609 | 16.46034 | 0.99609
0006 | 0.00579 | 1.00000 | 13.86913 | 1.00000
0007 | 0.00189 | 1.00000 | 17.36512 | 1.00000
0008 | 0.00688 | 1.00000 | 17.61729 | 1.00000
0009 | 0.00306 | 1.00000 | 18.18118 | 1.00000
0010 | 0.00045 | 1.00000 | 24.42766 | 1.00000
0011 | 0.00137 | 1.00000 | 18.18817 | 1.00000
0012 | 0.06852 | 0.98438 | 7.04744 | 0.98438
0013 | 0.20359 | 0.91797 | 4.13820 | 0.91797
0014 | 0.17984 | 0.93750 | 3.62651 | 0.93750
0015 | 0.18223 | 0.91797 | 3.20522 | 0.91797
0016 | 0.20050 | 0.91797 | 2.61011 | 0.91797
0017 | 0.24295 | 0.90625 | 2.62364 | 0.90625
0018 | 0.34922 | 0.83203 | 1.88428 | 0.83203
0019 | 0.25503 | 0.88281 | 2.24889 | 0.88281
0020 | 0.28527 | 0.88281 | 1.84421 | 0.88281
0021 | 0.27210 | 0.88672 | 1.92973 | 0.88672
0022 | 0.30241 | 0.88672 | 2.13511 | 0.88672
0023 | 0.33156 | 0.82422 | 2.02396 | 0.82422
0024 | 0.26693 | 0.86328 | 2.46276 | 0.86328
0025 | 0.39710 | 0.82422 | 1.64815 | 0.82422
0026 | 0.34780 | 0.83984 | 2.34444 | 0.83984
0027 | 0.26145 | 0.90625 | 2.20919 | 0.90625
0028 | 0.28858 | 0.86328 | 2.15237 | 0.86328
0029 | 0.34291 | 0.83984 | 2.15610 | 0.83984
0030 | 0.31965 | 0.86719 | 2.10919 | 0.86719
0031 | 0.27913 | 0.89844 | 1.92525 | 0.89844
0032 | 0.31357 | 0.87500 | 2.10098 | 0.87500
0033 | 0.38449 | 0.83984 | 2.03964 | 0.83984
0034 | 0.34802 | 0.81641 | 1.73214 | 0.81641
0035 | 0.28982 | 0.87500 | 1.74905 | 0.87500
0036 | 0.33509 | 0.85156 | 1.83760 | 0.85156
0037 | 0.29839 | 0.86719 | 1.90305 | 0.86719
0038 | 0.34962 | 0.83594 | 1.86196 | 0.83594
0039 | 0.32271 | 0.84766 | 2.21418 | 0.84766
0040 | 0.31684 | 0.84766 | 2.22909 | 0.84766
0041 | 0.37983 | 0.83984 | 1.79734 | 0.83984
0042 | 0.31909 | 0.83984 | 2.10337 | 0.83984
0043 | 0.30426 | 0.86719 | 1.98194 | 0.86719
0044 | 0.30465 | 0.86328 | 2.31558 | 0.86328
0045 | 0.35478 | 0.84766 | 2.40368 | 0.84766
0046 | 0.30423 | 0.86328 | 1.93115 | 0.86328
0047 | 0.30887 | 0.83984 | 2.17885 | 0.83984
0048 | 0.35123 | 0.86719 | 2.00351 | 0.86719
0049 | 0.24366 | 0.90234 | 2.21016 | 0.90234
0050 | 0.33797 | 0.84375 | 1.99375 | 0.84375
0051 | 0.35846 | 0.84375 | 2.17887 | 0.84375
0052 | 0.35476 | 0.83203 | 2.15312 | 0.83203
0053 | 0.28164 | 0.87109 | 2.60571 | 0.87109
0054 | 0.25782 | 0.89844 | 1.87386 | 0.89844
0055 | 0.28027 | 0.87500 | 2.30517 | 0.87500
0056 | 0.31118 | 0.84375 | 2.00939 | 0.84375
0057 | 0.32034 | 0.85547 | 2.22501 | 0.85547
0058 | 0.34665 | 0.84375 | 2.11842 | 0.84375
0059 | 0.32069 | 0.85547 | 1.79891 | 0.85547
0060 | 0.32578 | 0.87500 | 1.85051 | 0.87500
0061 | 0.32067 | 0.87109 | 1.70326 | 0.87109
0062 | 0.31929 | 0.85938 | 1.99901 | 0.85938
0063 | 0.38814 | 0.83984 | 1.55212 | 0.83984
Notice that the first picture, after three epochs, is more or less random, while the following ones converge toward smoother curves (even if our 64 epochs are not enough for a really nice curve) and, more importantly, toward a curve which displays a sinusoidal trend.
We have also shown the progress of loss and accuracy for both the discriminator and the whole generative network during the training. Examining this log, we can check that the GAN loss is the value to watch: the lower it is, the better the generated curve approximates a sinusoid. Also, looking at the values for the discriminator, some adjustments to the hyper-parameters (or even to the architecture of the networks) are in order, as sketched below.
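As a purely illustrative example of such an adjustment (the value is a guess, not a tuned one), one could recompile the discriminator with a slower optimizer, so that it does not overpower the generator during the first epochs; note that, depending on your Keras version, the Adam argument may be called lr rather than learning_rate.
from keras.optimizers import Adam
# Illustrative tweak, not part of the original code: a smaller learning rate for the discriminator
discriminator.compile(optimizer = Adam(learning_rate = 0.0002), loss = "binary_crossentropy", metrics = ["accuracy"])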
Conclusions
Our toy example may not seem that impressive, but actually it should. Indeed, we just assembled two shallow networks which (dropout and leaky relu aside) could have been programmed in the late 80s. The idea of putting them in competition has produced a generating network which “draws” curves resembling the ones we fed it with.
Moreover, the network learns the model to imitate from just a small sampled description, and I bet that running the programs on your computer took at most a handful of minutes.
By combining more sophisticated networks along the same lines, one can obtain a GAN able to generate digits and letters, but also more complex figures. Some modifications in the training techniques and in the representation of the data make it possible to generate speech, videos and, in the near future, any kind of stuff for which there are plenty of examples on the Web. That is to say, almost everything!
You can read the original version of this article at Codemotion.com, where you will find more related content. https://www.codemotion.com/magazine/dev-hub/machine-learning-dev/how-to-build-a-gan-in-python/