Demystifying GANs

The aim of this article is to illustrate the structure and training method for Generative Adversarial Networks (GANs), highlighting the key ideas behind GANs and elucidating the topic with a working example, using TensorFlow and Keras to train a GAN on the MNIST dataset to produce handwritten digits. We will then explore methods of adding levels of control over the output of the generator.

When I began recently to look for tutorials on training GANs I found many overly complicated scripts that offered little intuition regarding what was going on behind the scenes during training. It is my hope that the reader will have a clear understanding of why the code provided here works, and appreciate its simplicity. Familiarity with the Keras functional API, as well as some general knowledge of deep learning, will be useful.

GAN Overview

First, let us discuss what we would like to accomplish in this tutorial. A GAN comprises two separate neural networks, a Generator and a Discriminator. Both of these networks could have any structure you care to imagine; it depends on the task at hand.

The objective of the generator is to produce objects that look as though they belong to a given dataset, starting from some random input, canonically 100 numbers drawn from the standard normal distribution. For example, we will be considering the MNIST dataset in this tutorial. This dataset comprises 28 x 28 arrays (making it a 784-dimensional dataset), and each array represents an image of a handwritten digit, 0–9. Each class of digit can be thought of as a manifold in 784-dimensional space, with an image lying on a manifold if and only if it is recognisable to a human as the corresponding digit. A good generator must therefore be good at mapping random inputs onto these manifolds, so that it only generates images that look as if they belong to the true dataset.

The second network, the discriminator, has the opposite objective. It must learn to discriminate between real examples from the dataset and the ‘fakes’ created by the Generator. The combined structure is as follows:

[Figure: the combined generator-discriminator structure. Image by Author]

The ‘Adversarial’ part of the name refers to the method of training GANs. During training they compete (as adversaries), each trying to beat the other. The networks are trained alternately on each batch of data. First the discriminator will be trained to classify real data as ‘real’, and the fake images created by the generator as ‘fake’. Next, the generator will be trained to produce output that the discriminator classifies as ‘real’. This process is repeated for each batch in the data. It really is a simple idea, but it is a powerful one.

It is interesting to note that the generator will never see any real data — it will simply learn how to fool the discriminator by learning from the gradients propagated through the discriminator via the backpropagation algorithm. For this reason GANs are particularly susceptible to the vanishing gradients problem. After all, if the gradients vanish before reaching the generator there is no way for it to learn! This is particularly important to consider when using very deep GANs, but it should not be a worry for us here as the networks we use are relatively small.

Another common issue is that of ‘Mode Collapse’. In a way, this is where the generator starts thinking outside the box in order to be lazy. The generator can simply learn to generate the exact same output regardless of the input! If this output is convincing to the discriminator, then the generator has completed its task despite being totally useless to us.

We will consider overcoming the issue of mode collapse in this tutorial, but first we shall consider a very simple model in order to focus on understanding the training algorithm. Without further ado, let’s get stuck in.

Building the Models with Keras

There are a few parameters that are used throughout the code; they are presented here for clarity.

Introducing some variables.
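A minimal sketch of these parameters; the names are illustrative assumptions rather than the author's exact code, and the imports are shared by the later snippets:

```python
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, Model

CODING_SIZE = 100     # length of the random input vector fed to the generator
BATCH_SIZE = 128      # examples per training batch
BUFFER_SIZE = 60_000  # the full length of MNIST, so shuffling is complete
EPOCHS = 100          # number of passes over the dataset
```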

We will be using the MNIST dataset provided by the tensorflow_datasets module. The data must be loaded, and the pixel values scaled to range between -1 and 1, the range of the tanh activation function which will be used in the last layer of the generator. The data is then shuffled and batched according to the BATCH_SIZE and BUFFER_SIZE parameters. We use a buffer size of 60,000 (the length of the dataset) so that the data is fully shuffled, and a batch size of 128. The prefetch(1) call means that while one batch is being used to train the network, the next batch is being loaded into memory, which can help to prevent bottlenecking. This may not be a problem for MNIST as each image has relatively little data, but it can make a difference for high-resolution images.

Loading, scaling, shuffling, batching and fetching the dataset.
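A sketch of the pipeline, assuming the constants above (tfds.load yields dictionaries with 'image' and 'label' keys):

```python
def scale(example):
    # Cast to float and rescale pixels from [0, 255] to [-1, 1],
    # the output range of the generator's final tanh layer.
    image = tf.cast(example['image'], tf.float32)
    return (image - 127.5) / 127.5

dataset = (tfds.load('mnist', split='train')
           .map(scale)
           .shuffle(BUFFER_SIZE)
           .batch(BATCH_SIZE, drop_remainder=True)
           .prefetch(1))  # load the next batch while this one trains
```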

The simplicity of MNIST will allow us to get to grips with GANs using a simple dense structure. We will use a generator and discriminator with 3 dense hidden layers. Since we need to generate 28 x 28 images, the final layer will have 784 units, which can then be reshaped into the desired format. The other parameters can be played with. We will also use the standard tanh activation in the last layer, with the other layers having ReLU activation for simplicity.

Function to create a generator.
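One way such a generator might look with the functional API; the hidden-layer widths are assumptions:

```python
def make_generator():
    coding = layers.Input(shape=(CODING_SIZE,))
    x = layers.Dense(256, activation='relu')(coding)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dense(1024, activation='relu')(x)
    # 784 tanh units, reshaped into a 28 x 28 greyscale image.
    x = layers.Dense(28 * 28, activation='tanh')(x)
    output = layers.Reshape((28, 28, 1))(x)
    return Model(inputs=coding, outputs=output)
```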

The discriminator is the mirror image of the generator. This makes sense intuitively, as it is trying to undo what the generator has done. It will predict 1 for ‘Real’ and 0 for ‘Fake’, so the sigmoid activation is used in the final layer. We use ReLU in the other layers again.

Function to create a discriminator.
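A matching sketch, mirroring the generator's widths:

```python
def make_discriminator():
    # Dense ReLU layers shrinking down to a single sigmoid unit
    # that predicts 1 for 'real' and 0 for 'fake'.
    image = layers.Input(shape=(28, 28, 1))
    x = layers.Flatten()(image)
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dense(256, activation='relu')(x)
    output = layers.Dense(1, activation='sigmoid')(x)
    return Model(inputs=image, outputs=output)
```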

Now these two networks can be combined into a GAN as if they are just layers. It’s that simple!

Function to create a GAN from a discriminator and generator.
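A sketch of the combination:

```python
def make_gan(generator, discriminator):
    # Chain the networks: coding -> generator -> discriminator -> score.
    gan_input = layers.Input(shape=(CODING_SIZE,))
    gan_output = discriminator(generator(gan_input))
    return Model(inputs=gan_input, outputs=gan_output)
```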

Training the GAN with Keras

As outlined previously, GANs can’t be trained using the model.fit() method that is used for simpler deep learning models in Keras. This is because we have two different networks that must be trained concurrently, but with opposite objectives. Therefore we must create our own training loop to iterate over the batched data and train each model separately. There are a few subtleties in the following code, so I would encourage the reader to take in the comments carefully, but the key points are these:

  • The discriminator and the GAN must be compiled separately, and we make the discriminator untrainable when we compile the GAN. This allows us to train only the generator by calling gan.train_on_batch(), and only the discriminator when we call discriminator.train_on_batch(). This is a handy feature of Keras that is usually overlooked, leading to less elegant code.

  • We train the discriminator to map ‘real’ images to 1 and ‘fake’ images to 0. Softening these values, that is, using random numbers close to 1 and 0, is a standard trick that helps GANs learn.

  • We train the generator by training the GAN (with the discriminator's weights frozen), but with the labels reversed, that is, mapping the ‘fake’ images to 1. In this way, we ask the generator to learn how to trick the discriminator.

A function to create and train a GAN for a number of epochs.
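A minimal sketch of that training loop, assuming the constants and model-building functions above (the label-softening ranges are illustrative):

```python
def train_gan(epochs=EPOCHS):
    generator = make_generator()
    discriminator = make_discriminator()
    gan = make_gan(generator, discriminator)

    # Compile the discriminator alone, then freeze it inside the combined
    # model: gan.train_on_batch() updates only the generator, while
    # discriminator.train_on_batch() still updates the discriminator.
    discriminator.compile(loss='binary_crossentropy', optimizer='adam')
    discriminator.trainable = False
    gan.compile(loss='binary_crossentropy', optimizer='adam')

    for epoch in range(epochs):
        for real_images in dataset:
            # Train the discriminator with softened labels:
            # real -> near 1, fake -> near 0.
            noise = tf.random.normal((BATCH_SIZE, CODING_SIZE))
            fake_images = generator(noise)
            x = tf.concat([real_images, fake_images], axis=0)
            y = tf.concat([tf.random.uniform((BATCH_SIZE, 1), 0.9, 1.0),
                           tf.random.uniform((BATCH_SIZE, 1), 0.0, 0.1)],
                          axis=0)
            discriminator.train_on_batch(x, y)

            # Train the generator through the frozen discriminator with
            # reversed labels: fakes are labelled 1 ('real').
            noise = tf.random.normal((BATCH_SIZE, CODING_SIZE))
            gan.train_on_batch(noise, tf.ones((BATCH_SIZE, 1)))
    return generator, discriminator
```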

We observe the following convergence through 100 epochs of training. Although we see different digits appearing throughout training, it seems to have mostly settled on drawing 1s, which could be an example of mode collapse. This is where the generator learns to generate just one class from the dataset. If left to train longer, the discriminator would likely learn to classify that class as fake, at which point the generator would learn to generate another class, with this cycle continuing. Since we made no effort to prevent such behaviour, I believe we can count this result as a success: we clearly have a working training method for this class of neural network.

[Figure: generator output over 100 epochs of training. Images by Author]

Improving the Performance of our Model: Adding Functionality

While this first attempt has been successful to a certain degree, we can certainly do much better. After all, we have only used half of the dataset! We have neglected to make use of the labels that are available to us. With just a few tweaks to the above code, we can create a GAN which uses the labelled data, so that the generator will produce a digit from a given class chosen by us. As fun as generating random digits is, having control over the generated images adds a beautiful layer of sophistication to our model. This will also help the GAN to learn faster, and help to prevent mode collapse. The new structure of the model will be as follows:

[Figure: the new model structure, with class labels as additional inputs. Image by Author]

The generator must take random noise as input, as before, but also a randomly selected class that it should generate. This means that, once trained, the generator should produce a digit from any class we choose. The discriminator will also take two inputs, the first being images, the second being the corresponding class labels. This will allow the discriminator to decide whether an image is real or fake not just based on how convincing the image is as a digit, but also based on whether the digit belongs to the given class, which will force the generator to draw images from the correct class in order to perform well. We need only make some very simple tweaks to our code.

First, we will need a variable telling us how many features to expect for our data. Since we will be one-hot encoding the labels, each label will be a vector of length 10. We must also create the labelled dataset.

Declaring a new variable and creating a new dataset.
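A sketch, reusing the earlier pipeline (N_CLASSES is an illustrative name):

```python
N_CLASSES = 10  # length of a one-hot label vector

def scale_with_label(example):
    image = (tf.cast(example['image'], tf.float32) - 127.5) / 127.5
    label = tf.one_hot(example['label'], N_CLASSES)
    return image, label

labelled_dataset = (tfds.load('mnist', split='train')
                    .map(scale_with_label)
                    .shuffle(BUFFER_SIZE)
                    .batch(BATCH_SIZE, drop_remainder=True)
                    .prefetch(1))
```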

The following is where we see the advantage of using the Keras functional API compared to the sequential API. It is a triviality to add multiple inputs and concatenate them in this framework as follows:

A function to create a generator with the new features.
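A sketch of the two-input generator; the label is simply concatenated onto the random coding:

```python
def make_labelled_generator():
    coding = layers.Input(shape=(CODING_SIZE,))
    label = layers.Input(shape=(N_CLASSES,))  # one-hot class to generate
    x = layers.Concatenate()([coding, label])
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dense(28 * 28, activation='tanh')(x)
    output = layers.Reshape((28, 28, 1))(x)
    return Model(inputs=[coding, label], outputs=output)
```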

The discriminator is tweaked similarly, with another hidden layer added after the class labels are input into the model. The extra layer is necessary because, if the concatenated layer connected directly to the output neuron, the discriminator would lack the flexibility to process the information given by the class label.

A function to create a discriminator with the new features.
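A sketch of the two-input discriminator; the width of the extra hidden layer is an assumption:

```python
def make_labelled_discriminator():
    image = layers.Input(shape=(28, 28, 1))
    label = layers.Input(shape=(N_CLASSES,))
    x = layers.Flatten()(image)
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.Dense(256, activation='relu')(x)
    # Concatenate the label late, then add one extra hidden layer so
    # the network has room to process the label information.
    x = layers.Concatenate()([x, label])
    x = layers.Dense(128, activation='relu')(x)
    output = layers.Dense(1, activation='sigmoid')(x)
    return Model(inputs=[image, label], outputs=output)
```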

Creating the GAN follows just as intuitively as before.

A function to create a GAN with the new features.
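A sketch of the combined model; note that the same label feeds both networks:

```python
def make_labelled_gan(generator, discriminator):
    coding = layers.Input(shape=(CODING_SIZE,))
    label = layers.Input(shape=(N_CLASSES,))
    fake_image = generator([coding, label])
    score = discriminator([fake_image, label])
    return Model(inputs=[coding, label], outputs=score)
```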

Finally, training is precisely the same idea as before, but with the additional inputs generated and added in. Because of the way we generated the dataset, the code looks much the same.

A function to create and train a GAN with the new features.
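A sketch of the updated loop, with randomly drawn classes threaded through both networks (names remain illustrative):

```python
def train_labelled_gan(epochs=EPOCHS):
    generator = make_labelled_generator()
    discriminator = make_labelled_discriminator()
    gan = make_labelled_gan(generator, discriminator)

    discriminator.compile(loss='binary_crossentropy', optimizer='adam')
    discriminator.trainable = False
    gan.compile(loss='binary_crossentropy', optimizer='adam')

    for epoch in range(epochs):
        for real_images, real_labels in labelled_dataset:
            # Draw random classes for the generated images.
            noise = tf.random.normal((BATCH_SIZE, CODING_SIZE))
            fake_classes = tf.random.uniform((BATCH_SIZE,), 0, N_CLASSES,
                                             dtype=tf.int32)
            fake_labels = tf.one_hot(fake_classes, N_CLASSES)
            fake_images = generator([noise, fake_labels])

            # Train the discriminator on (image, label) pairs.
            x_images = tf.concat([real_images, fake_images], axis=0)
            x_labels = tf.concat([real_labels, fake_labels], axis=0)
            y = tf.concat([tf.random.uniform((BATCH_SIZE, 1), 0.9, 1.0),
                           tf.random.uniform((BATCH_SIZE, 1), 0.0, 0.1)],
                          axis=0)
            discriminator.train_on_batch([x_images, x_labels], y)

            # Train the generator with reversed labels, as before.
            noise = tf.random.normal((BATCH_SIZE, CODING_SIZE))
            gan.train_on_batch([noise, fake_labels],
                               tf.ones((BATCH_SIZE, 1)))
    return generator, discriminator
```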

This model was then trained for 100 epochs as before, with the progress shown here. The generator should be producing a full sequence of digits, 0–9. It seems we have yet more success! With the exception of the digit 9, we have a convincing sequence of digits by the end of the training.

[Figure: generator output over 100 epochs, one digit per class, 0–9. Images by Author]

Exploring Feature Mapping

We are also able to observe a fascinating characteristic of GANs. They actually learn to encode certain random inputs as features of the data they generate. This is best seen through an example:

[Figure: digits generated from the same random inputs but different class labels. Image by Author]

The images in each column are generated from the same random input, but with different class labels. It is clear that encoded within the random input are features of the resulting digits. For example, in the fifth column all strokes are quite thin, whereas all digits in column 8 have thick line strokes. Similarly all digits in column 2 appear to slant to the right, while those in column 3 slant slightly to the left. While these encodings would not be clear to any human trying to interpret them, it is possible to control the style of the generated digits to a certain extent by studying the behaviour of the generator on a sample of random inputs.

Averaging Features

In order to demonstrate this, I have inspected the output of the generator on a large number of inputs, and sorted the inputs based on the type of output the generator produced. The categories I considered were ‘thick stroke, straight’, ‘thick stroke, slanted right’, ‘thin stroke, straight’ and ‘thin stroke, slanted right’. Each category was then averaged, so that I had four input vectors, one for each category. When these inputs are passed through the generator, they produce new images from that category, as shown below. Here I’m showing 1s, 4s, and 8s as these digits show the most exaggerated slanting.


[Images: (left to right) Thick-Straight, Thick-Right, Thin-Straight, Thin-Right. Image by Author]
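A sketch of the averaging step described above; sorted_codings is a hypothetical dict of coding vectors, hand-grouped by the style they produced:

```python
import numpy as np

def average_codings(sorted_codings):
    # e.g. {'thick_straight': [coding, ...], 'thick_slanted': [...],
    #       'thin_straight': [...],  'thin_slanted': [...]}
    # Returns one averaged 100-dimensional coding per style.
    return {style: np.mean(np.stack(codings), axis=0)
            for style, codings in sorted_codings.items()}
```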

Scaling Features

We can actually do even better than that! If we multiply these inputs by a number larger than 1, we can exaggerate the features of that input, and if we use a number less than one, we can reduce that feature. In fact, it seems that if we multiply by a negative number we can generate digits with the reversed features! See the GIFs below: they represent the output of our generator when the inputs for each category are multiplied by scalars ranging between 2 and -2. See how they start with exaggerated versions of the features above, and end as the reverse!


[GIFs: (left to right) Thick-Straight, Thick-Right, Thin-Straight, Thin-Right. Images by Author]
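A sketch of the sweep, assuming avg holds the averaged codings from above and generator is the trained labelled generator (drawing 4s as an example):

```python
# Scale one averaged coding from 2 down to -2: the style is exaggerated
# at one end and reversed at the other.
for scale_factor in np.linspace(2.0, -2.0, 9):
    coding = (scale_factor * avg['thick_straight'])[np.newaxis, :]
    label = tf.one_hot([4], N_CLASSES)
    image = generator([coding, label])  # one 28 x 28 image per step
```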

What we are seeing here is a common feature of GANs. While the inputs may be random, the generator's interpretation of them is anything but. It has learned to use certain vectors in the latent space to represent the features it draws.

Adding and Subtracting Features

You can even add and subtract meaningfully in this latent space. What might you expect to get if you added the input for thick, straight digits to the input for thin, slanted digits?

[Image by Author]

You get an input for thick, slanted digits! And what about subtracting the input for thick, straight digits from the input for thick, slanted digits?

[Image by Author]

You get an input for thin, slanted digits! How cool is that? If one had the time and resources, the output of the GAN could be categorised more fully and one could gain absolute control of the output being generated using this method.

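In code, this is plain vector arithmetic on the averaged codings (the keys are the hypothetical style names from earlier):

```python
# 'thick, straight' + 'thin, slanted' behaves like a 'thick, slanted' coding.
thick_slanted_coding = avg['thick_straight'] + avg['thin_slanted']

# 'thick, slanted' - 'thick, straight' behaves like a 'thin, slanted' coding.
thin_slanted_coding = avg['thick_slanted'] - avg['thick_straight']
```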

In future, I aim to apply the principles laid out in this article to train GANs on more complex datasets, and explore the feature mappings I obtain. The full code that I have used is available in this GitHub repository:

Translated from: https://towardsdatascience.com/demystifying-gans-cc1ac011355
