How to Design and Train Generative Adversarial Networks (GANs)
Get an overview of generative adversarial networks (GANs) and walk through how to design and train one using MATLAB®.
GANs are composed of two deep neural networks, a generator and a discriminator, which are adversaries of each other (thus the term “adversarial”). The generator creates new data instances, while the discriminator evaluates them for authenticity (i.e., it decides whether each instance of data that it reviews belongs to the actual training dataset or not).
The video outlines that process and shows, step by step, how to use MATLAB to design and train a GAN.
Hi, my name is Joe Hicklin. I'm a senior developer at The MathWorks. I'm going to show you what generative adversarial networks are. I'm going to show you how they work and how to make one using MathWorks deep learning tools.
GANs are generative because they generate something. In this example, we'll be generating images. They're called adversarial because we use two separate networks fighting against each other. Each one improves in sort of an arms race as the other one improves. And it's this improvement that will allow us to train the generator to generate the images we want to make.
Training a GAN is pretty interesting. We're going to use a setup like this. We start with two networks-- here's our generator and here's our discriminator. We'll start out just by feeding random noise into the generator. Now the generator hasn't been trained yet, so it will produce images, but those images are just random data. It will produce a whole bunch of images.
We'll take these images-- we're going to call these fake data-- and send them to the discriminator, which is supposed to say whether each one is real or fake. But it hasn't been trained yet either, so its output is garbage: it will call some real and some fake. Next we'll take some real data and feed that to the discriminator. Still, the discriminator hasn't been trained, so its output is garbage there too. And that's our first step.
But now we're ready to train them. We want to train the generator to fool the discriminator. So here, every time the discriminator has called the input real, the generator has succeeded, and every time the discriminator calls it fake, the generator has failed. So that gives us an error signal we can feed back to the generator and update its parameters to do better.
At the same time, we want the discriminator to label all the fake data as fake and all the real data as real. So everywhere the discriminator called the fake data real, that's an error. Every time it called the fake data fake, that's getting the right answer. And the opposite for real data-- everywhere it calls it real, that's a right answer, and where it calls it fake, that's a wrong answer. So that gives us error signals that we can feed back to the discriminator to update it. All right.
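The error signals just described are typically implemented as binary cross-entropy losses. Here is an illustrative sketch (in Python rather than the video's MATLAB, with made-up prediction values) of how the two losses score the discriminator's outputs:

```python
import math

def bce(p, target):
    # Binary cross-entropy for one predicted probability p against a 0/1 target.
    eps = 1e-12  # guard against log(0)
    return -(target * math.log(p + eps) + (1 - target) * math.log(1 - p + eps))

# Discriminator outputs: probability that each image is real (values invented).
d_on_fake = [0.9, 0.2, 0.6]   # predictions on generated (fake) images
d_on_real = [0.8, 0.4, 0.95]  # predictions on real images

# The generator succeeds when the discriminator calls its fakes "real" (target 1).
gen_loss = sum(bce(p, 1) for p in d_on_fake) / len(d_on_fake)

# The discriminator succeeds by calling fakes "fake" (target 0)
# and reals "real" (target 1); its loss adds both parts.
disc_loss = (sum(bce(p, 0) for p in d_on_fake) +
             sum(bce(p, 1) for p in d_on_real)) / (len(d_on_fake) + len(d_on_real))
```

A confidently fooled discriminator (say, 0.9 on a fake) yields a small generator loss, while a correctly suspicious one (0.2 on a fake) yields a large one; those gaps are the error signals fed back during training.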
So we go back, we update each network, and repeat the process. And at first, the networks do a very bad job. But slowly, they improve. And each time we go around the loop, passing in the data and doing the training, the networks get better and better and better. And after a great many times around the loop, we wind up with a generator that's doing a pretty good job generating our artificial images and a discriminator that is getting pretty good at telling them apart.
Usually we'll throw the discriminator away when we're done and just use the generator to generate our images.
Now that we've seen how GANs work, let's create one in MATLAB. I'm going to create and train a GAN that generates artificial images of sunflowers. This process has five steps: organize our data, create our two networks, write a gradients function, write a training script, and finally train the two networks.
The gradients function is the GAN-specific part of this program, and I'm going to start with that. Training a neural net is an optimization problem: we're trying to minimize the errors in the output of the neural net, usually through gradient descent. If we can find the rate of change of the errors with respect to the learnable parameters, we can change the parameters a little bit in the direction opposite the gradient and hopefully improve a little. We repeat that many times and hopefully converge to a good solution.
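As a minimal illustration of that descent loop (a Python sketch on a one-parameter toy function, not code from the video):

```python
# Gradient descent on f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
def grad(w):
    return 2 * (w - 3)

w = 0.0    # starting guess
lr = 0.1   # learning rate: how big a step to take

for _ in range(100):
    w -= lr * grad(w)   # step opposite the gradient

# After many small steps, w converges to the minimizer, 3.
```

Training a network is the same idea, just with millions of parameters and a loss measured over batches of data.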
The core of this is finding the gradients: we need a function that calculates the gradients used to train both networks. We need to write that function, and here it is, already written.
The gradients function will take in the generator network, the discriminator network, the real images, and a noise vector. And it's going to return the gradients for the generator and the gradients for the discriminator that we'll use later in the training loop.
Now this function works just like the GAN training process I described earlier. First we pass the noise into the generator and get some fake images out of it. We pass those fake images into the discriminator and get its predictions on the fake images. Then we pass the real images into the discriminator and get its predictions on the real data.
Once we've got that, we're ready to calculate the errors. The generator's error is the discriminator's prediction error on the fake images. The discriminator's error has two parts: its errors predicting on the fake images and its errors predicting on the real images. We just add those two together.
Now that we've got the errors, we can call dlgradient, which uses a form of automatic differentiation to do the heavy lifting for us: it calculates the gradient of the loss with respect to the learnable parameters in the generator. We do the same thing for the discriminator and its loss. That gives us the gradients we're looking for from this function. So there's our gradients function.
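dlgradient computes exact gradients via automatic differentiation; as a rough illustration of what a gradient even is, here is a Python sketch that approximates one numerically with central differences (the loss function and values are invented for the example):

```python
def numerical_gradient(loss_fn, params, h=1e-6):
    # Central-difference approximation of d(loss)/d(params[i]) for each i.
    grads = []
    for i in range(len(params)):
        up = params[:];   up[i] += h
        down = params[:]; down[i] -= h
        grads.append((loss_fn(up) - loss_fn(down)) / (2 * h))
    return grads

# Toy loss: L(w) = w0^2 + 3*w1, whose true gradient at (2, 5) is (4, 3).
g = numerical_gradient(lambda w: w[0] ** 2 + 3 * w[1], [2.0, 5.0])
```

Finite differences need one loss evaluation per parameter, which is hopeless for a real network; that is exactly why frameworks use reverse-mode automatic differentiation instead.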
Now that we have our gradients function, we're ready to do the other four steps. And here I've got a script that does all of that. First I'm going to get my data together into an image data set like we usually do, then I'm going to load the two networks from disk. I built these networks earlier using Deep Network Designer. Then I'm going to specify training options, and these are the same kind of training options you've seen in a lot of examples.
Finally, we're ready to train the networks. Normally you would call trainNetwork to do that; it has a training loop inside and does all that for you. But you can't in this case, because we're using a custom gradients function that needs to be treated specially. So we're going to write the loop ourselves, but it's the same sort of thing that's usually happening.
The first thing we'll do is generate some noise to use later for validation and initialize our iteration count. Then for each epoch-- each pass through our data-- we reset our image datastore and move through the data. As long as the datastore has data, we take the next batch of images. We skip the last batch because it sometimes isn't the right size. And then we're ready to go.
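The batching logic just described-- take full batches, drop a final partial one-- can be sketched like this in Python (a hypothetical minibatches helper, not the MATLAB datastore API):

```python
def minibatches(data, batch_size):
    # Yield consecutive full batches only; a trailing partial batch is dropped,
    # mirroring the "skip the last batch" step in the training script.
    for i in range(0, len(data) - batch_size + 1, batch_size):
        yield data[i:i + batch_size]

batches = list(minibatches(list(range(10)), 4))
```

Ten items with a batch size of 4 produce two batches; the leftover two items are skipped rather than passed to the networks at the wrong size.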
Here we take our batch of sunflower images, put them into dlarray objects, and move them onto the GPU. Then we do the same thing with the noise: generate some random noise, put it in dlarray objects, and put it on the GPU. And now we're ready to calculate our gradients. We call dlfeval, passing in the gradients function we wrote earlier along with all the other arguments-- the generator, the discriminator, and the images. That calculates our gradients for us.
Now we're ready to update the networks, and adamupdate does that. We give it the learnables for the network and the gradients, and it produces new learnables. Same for the discriminator: we pass in its learnables and gradients, and it gives us new learnables back. And that's it-- we've taken one step. This little section of code just draws some random output every so often for our entertainment. Then we repeat the loop. That's all it takes. If I start it running, it moves through all these steps and we see the output changing every few seconds as it calculates.
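adamupdate is MATLAB's built-in optimizer step. As a hedged sketch of what one Adam update does, here is a minimal Python version (standard Adam with conventional hyperparameters, driven on a toy quadratic rather than network learnables):

```python
import math

def adam_update(w, g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    # One Adam step: update the moving averages of the gradient (m) and its
    # square (v), correct their startup bias, then take a scaled step.
    m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(v, g)]
    m_hat = [mi / (1 - b1 ** t) for mi in m]
    v_hat = [vi / (1 - b2 ** t) for vi in v]
    w = [wi - lr * mh / (math.sqrt(vh) + eps)
         for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v

# Drive w toward the minimum of f(w) = w^2 (gradient 2w).
w, m, v = [1.0], [0.0], [0.0]
for t in range(1, 3001):
    g = [2 * wi for wi in w]
    w, m, v = adam_update(w, g, m, v, t, lr=0.01)
```

In the GAN script the same kind of step is applied twice per iteration, once to the generator's learnables with its gradients and once to the discriminator's, with Adam's moving averages carried between iterations.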
Now it'll take about an hour to do a really good job here. And we'll zoom through that through the magic of video.
So we let this go for an hour, and you can see it's done a pretty convincing job of producing images of sunflowers. We've seen how GANs work and how to implement them using MathWorks deep learning tools. If you want to learn more about how to do this, check out the links below.