赞
踩
ACGAN与CGAN的区别如下
1 与CGAN一样的是,在生成网络的输入都混入label;
2 不一样的是在鉴别网络输入时,ACGAN不再混入label,而是在鉴别网络的输出时,把label作为target进行反馈来提交给鉴别网络的学习能力。
3 另一个不一样的是,生成网络和鉴别网络的网络层不再是CGAN的全连接,而是ACGAN的深层卷积网络(这是在DCGAN开始引入的改变),卷积能够更好的提取图片的特征值,所有ACGAN生成的图片边缘更具有连续性,感觉更真实。
如下生成网络model,和CGAN的一模一样。
noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, 100)(label)) model_input = multiply([noise, label_embedding]) img = model(model_input) |
如下鉴别网络model,传入还是img,但是输出包含两个部分:
1 validity,即鉴别图片是不是伪造的结果。
2 label,使用softmax激活,输出10维的结果即属于哪个数字。
img = Input(shape=self.img_shape) # Extract feature representation features = model(img) # Determine validity and label of the image validity = Dense(1, activation="sigmoid")(features) label = Dense(self.num_classes, activation="softmax")(features) return Model(img, [validity, label]) |
所以在鉴别网络或生成网络训练的时候,提供了target img_labels和sampled_labels。
# Train the discriminator d_loss_real = self.discriminator.train_on_batch(imgs, [valid, img_labels]) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, [fake, sampled_labels]) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- # Train the generator g_loss = self.combined.train_on_batch([noise, sampled_labels], [valid, sampled_labels]) |
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。