The principle of ACGAN is similar to that of the conditional GAN (CGAN). For both CGAN and ACGAN, the generator's inputs are a latent vector and its class label, and the output is a fake image belonging to that class. For CGAN, the discriminator's input is an image (real or fake) together with its label, and its output is the probability that the image is real. For ACGAN, the discriminator's input is an image alone, and its outputs are both the probability that the image is real and a probability distribution over its class.
In essence, CGAN feeds the label to the network as side information, whereas ACGAN reconstructs the auxiliary information with an auxiliary decoder network. The idea behind ACGAN is that forcing the network to perform an additional task improves performance on the original task. Here the auxiliary task is image classification, and the original task is generating fake images.
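The difference is easiest to see at the discriminator's outputs. The sketch below is a minimal, hypothetical illustration (the dense layer sizes are placeholders, not taken from the code later in this post): a CGAN discriminator takes the image and the label and emits a single real/fake probability, while an ACGAN discriminator takes only the image and emits both the real/fake probability and a softmax over classes.

from tensorflow import keras

# Hypothetical CGAN discriminator head: image + label in, one probability out.
img_in = keras.layers.Input(shape=(28, 28, 1))
lbl_in = keras.layers.Input(shape=(10,))
x = keras.layers.Flatten()(img_in)
x = keras.layers.concatenate([x, lbl_in])
x = keras.layers.Dense(128, activation='relu')(x)
cgan_out = keras.layers.Dense(1, activation='sigmoid')(x)        # P(real)
cgan_d = keras.Model([img_in, lbl_in], cgan_out, name='cgan_discriminator')

# Hypothetical ACGAN discriminator head: image in, two outputs.
img_only = keras.layers.Input(shape=(28, 28, 1))
h = keras.layers.Flatten()(img_only)
h = keras.layers.Dense(128, activation='relu')(h)
src_out = keras.layers.Dense(1, activation='sigmoid')(h)         # P(real)
cls_out = keras.layers.Dense(10, activation='softmax')(h)        # P(class | image)
acgan_d = keras.Model(img_only, [src_out, cls_out], name='acgan_discriminator')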
Discriminator objective function:
\mathcal{L}^{(D)} = -\mathbb{E}_{x\sim p_{data}}\log D(x) - \mathbb{E}_{z}\log\left[1 - D(G(z\mid y))\right] - \mathbb{E}_{x\sim p_{data}}\log p(c\mid x) - \mathbb{E}_{z}\log p(c\mid G(z\mid y))
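The first two terms are the usual GAN source (real/fake) loss and the last two are the auxiliary classification loss. In the Keras code below, this is implemented by giving the two-output discriminator a pair of losses: binary cross-entropy for the source output and categorical cross-entropy for the class output. The sketch here just spells out that mapping with standalone loss objects; the tensors are made-up placeholder values, not real data.

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy()
cce = tf.keras.losses.CategoricalCrossentropy()

# Placeholder discriminator outputs for a batch of 2 real and 2 fake images.
d_source = tf.constant([[0.9], [0.8], [0.3], [0.2]])       # D(x) and D(G(z|y))
d_class = tf.constant([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.6, 0.3, 0.1],
                       [0.2, 0.2, 0.6]])                   # p(c|.)
source_targets = tf.constant([[1.], [1.], [0.], [0.]])     # real = 1, fake = 0
class_targets = tf.constant([[1., 0., 0.],
                             [0., 1., 0.],
                             [1., 0., 0.],
                             [0., 0., 1.]])

# L(D) = source (real/fake) loss + auxiliary classification loss
loss_D = bce(source_targets, d_source) + cce(class_targets, d_class)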
Generator objective function:
\mathcal{L}^{(G)} = -\mathbb{E}_{z}\log D(G(z\mid y)) - \mathbb{E}_{z}\log p(c\mid G(z\mid y))
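The generator objective uses the same two cross-entropy terms, but evaluated only on generated images, with "real" (1) as the source target and the requested class as the class target. In the training code below this is realized by the stacked adversarial model: adversarial.train_on_batch([noise, fake_labels], [y, fake_labels]) with y set to all ones. Continuing the placeholder sketch above (repeated here so it runs on its own):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy()
cce = tf.keras.losses.CategoricalCrossentropy()

# Placeholder discriminator outputs on a batch of 2 generated images.
g_source = tf.constant([[0.3], [0.2]])                 # D(G(z|y))
g_class = tf.constant([[0.6, 0.3, 0.1],
                       [0.2, 0.2, 0.6]])               # p(c|G(z|y))
wanted_real = tf.ones_like(g_source)                   # generator wants "real"
wanted_class = tf.constant([[1., 0., 0.],
                            [0., 0., 1.]])             # the requested labels y

loss_G = bce(wanted_real, g_source) + cce(wanted_class, g_class)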
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
import math
from PIL import Image
def generator(inputs, image_size, activation='sigmoid', labels=None):
    """Generator network

    Arguments:
        inputs (Layer): input layer (latent vector)
        image_size (int): target image size
        activation (string): output-layer activation
        labels (tensor): class labels

    Returns:
        Model: generator network
    """
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]
    inputs = [inputs, labels]
    x = keras.layers.concatenate(inputs, axis=1)
    x = keras.layers.Dense(image_resize * image_resize * layer_filters[0])(x)
    x = keras.layers.Reshape((image_resize, image_resize, layer_filters[0]))(x)
    for filters in layer_filters:
        # only the first two Conv2DTranspose layers upsample (strides = 2)
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('relu')(x)
        x = keras.layers.Conv2DTranspose(filters=filters,
                                         kernel_size=kernel_size,
                                         strides=strides,
                                         padding='same')(x)
    if activation is not None:
        x = keras.layers.Activation(activation)(x)
    return keras.Model(inputs, x, name='generator')
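As a quick sanity check, the builder can be instantiated on its own. With image_size=28, a latent size of 100 and 10 classes (the MNIST settings used later), the resulting model maps [noise, one-hot label] to a 28x28x1 image. A minimal, hypothetical usage sketch, assuming the imports at the top of this post:

z = keras.layers.Input(shape=(100,), name='z')
y = keras.layers.Input(shape=(10,), name='y')
gen = generator(z, 28, labels=y)
print(gen.output_shape)              # expected: (None, 28, 28, 1)

# sample one fake image of the digit 3 (untrained weights, so it is noise)
noise = np.random.uniform(-1., 1., size=(1, 100))
label = np.eye(10)[[3]]
fake = gen.predict([noise, label])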
def discriminator(inputs, activation='sigmoid', num_labels=None):
    """Discriminator network

    Arguments:
        inputs (Layer): input layer (image)
        activation (string): output-layer activation
        num_labels (int): number of classes

    Returns:
        Model: discriminator network
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
    x = inputs
    for filters in layer_filters:
        # the last conv layer keeps the spatial size (strides = 1)
        if filters == layer_filters[-1]:
            strides = 1
        else:
            strides = 2
        x = keras.layers.LeakyReLU(0.2)(x)
        x = keras.layers.Conv2D(filters=filters,
                                kernel_size=kernel_size,
                                strides=strides,
                                padding='same')(x)
    x = keras.layers.Flatten()(x)
    outputs = keras.layers.Dense(1)(x)
    if activation is not None:
        outputs = keras.layers.Activation(activation)(outputs)
    if num_labels:
        # ACGAN has a second output that predicts the image's class
        layer = keras.layers.Dense(layer_filters[-2])(x)
        labels = keras.layers.Dense(num_labels)(layer)
        labels = keras.layers.Activation('softmax', name='label')(labels)
        outputs = [outputs, labels]
    return keras.Model(inputs, outputs, name='discriminator')
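Instantiated with num_labels=10, the discriminator returns two tensors per image: the source probability and the class distribution. A small, hypothetical check (again assuming the imports at the top of this post):

img = keras.layers.Input(shape=(28, 28, 1), name='img')
dis = discriminator(img, num_labels=10)
print(dis.output_shape)              # expected: [(None, 1), (None, 10)]

probs = dis.predict(np.random.rand(1, 28, 28, 1))
src_prob, class_prob = probs         # P(real) and P(class | image)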
def build_and_train_models():
    """Build and train the ACGAN."""
    # load and preprocess the data
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255.
    num_labels = len(np.unique(y_train))
    y_train = keras.utils.to_categorical(y_train)

    # hyperparameters
    model_name = 'acgan-mnist'
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels,)

    # discriminator (local names dis_model/gen_model avoid shadowing the
    # builder functions discriminator() and generator() defined above)
    inputs = keras.layers.Input(shape=input_shape, name='discriminator_input')
    dis_model = discriminator(inputs, num_labels=num_labels)
    optimizer = keras.optimizers.RMSprop(lr=lr, decay=decay)
    # two losses: binary cross-entropy for real/fake, categorical for the class
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    dis_model.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    dis_model.summary()

    # generator
    input_shape = (latent_size,)
    inputs = keras.layers.Input(shape=input_shape, name='z_input')
    labels = keras.layers.Input(shape=label_shape, name='labels')
    gen_model = generator(inputs, image_size, labels=labels)
    gen_model.summary()

    # adversarial (stacked) model: generator followed by a frozen discriminator
    optimizer = keras.optimizers.RMSprop(lr=lr * 0.5, decay=decay * 0.5)
    dis_model.trainable = False
    adversarial = keras.Model([inputs, labels],
                              dis_model(gen_model([inputs, labels])),
                              name=model_name)
    adversarial.compile(loss=loss, optimizer=optimizer, metrics=['acc'])
    adversarial.summary()

    models = (gen_model, dis_model, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)
def train(models, data, params):
    """Train the discriminator and adversarial networks

    Arguments:
        models (list): generator, discriminator, adversarial
        data (list): x_train, y_train
        params (list): network parameters
    """
    generator, discriminator, adversarial = models
    x_train, y_train = data
    batch_size, latent_size, train_steps, num_labels, model_name = params
    save_interval = 500
    # fixed noise and labels used for the periodically plotted sample grid
    noise_input = np.random.uniform(-1., 1., size=[16, latent_size])
    noise_label = np.eye(num_labels)[np.arange(0, 16) % num_labels]
    train_size = x_train.shape[0]
    print(model_name,
          'Labels for generated images: ',
          np.argmax(noise_label, axis=1))
    for i in range(train_steps):
        # train the discriminator for one batch
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        real_images = x_train[rand_indexes]
        real_labels = y_train[rand_indexes]
        # generate fake images
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        fake_images = generator.predict([noise, fake_labels])
        # build the input batch: real images followed by fake images
        x = np.concatenate((real_images, fake_images))
        # class labels for the auxiliary output
        labels = np.concatenate((real_labels, fake_labels))
        # source labels: 1 for real, 0 for fake
        y = np.ones([2 * batch_size, 1])
        y[batch_size:, :] = 0.0
        # train the discriminator
        metrics = discriminator.train_on_batch(x, [y, labels])
        fmt = '%d: [disc loss: %f, srcloss: %f,'
        fmt += ' lblloss: %f, srcacc: %f, lblacc: %f]'
        log = fmt % (i, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])

        # train the adversarial network for one batch
        noise = np.random.uniform(-1., 1., size=(batch_size, latent_size))
        fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
        # the generator wants the discriminator to label its images as real (1)
        y = np.ones([batch_size, 1])
        metrics = adversarial.train_on_batch([noise, fake_labels], [y, fake_labels])
        fmt = "%s [advr loss: %f, srcloss: %f,"
        fmt += " lblloss: %f, srcacc: %f, lblacc: %f]"
        log = fmt % (log, metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])
        print(log)
        if (i + 1) % save_interval == 0:
            # plot a grid of generated images
            plot_images(generator,
                        noise_input=noise_input,
                        noise_label=noise_label,
                        show=False,
                        step=(i + 1),
                        model_name=model_name)
    generator.save(model_name + ".h5")
def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """Generate fake images and plot them as a grid

    Arguments:
        generator (Model): generator model
        noise_input (ndarray): latent vectors
        show (bool): whether to display the figure
        step (int): training step (used in the file name)
        model_name (string): model name (used as the output folder)
    """
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    rows = int(math.sqrt(noise_input.shape[0]))
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes
    images = generator.predict(noise_input)
    plt.figure(figsize=(2.2, 2.2))
    num_images = images.shape[0]
    image_size = images.shape[1]
    for i in range(num_images):
        plt.subplot(rows, rows, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')
# Run
if __name__ == '__main__':
build_and_train_models()
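After training, train() saves the generator as acgan-mnist.h5 (see generator.save above). A minimal, hypothetical sketch for reloading it and generating images of a chosen class (the digit 7 here is just an example; it assumes the imports and plot_images defined above):

# reload the trained generator and ask it for a specific class
gen = keras.models.load_model("acgan-mnist.h5")
digit = 7                                       # the class we want to generate
noise = np.random.uniform(-1., 1., size=(16, 100))
labels = np.tile(np.eye(10)[digit], (16, 1))    # 16 one-hot labels for that digit
fakes = gen.predict([noise, labels])            # 16 fake images of the digit 7
plot_images(gen, noise_input=noise, noise_label=labels, show=True)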
Generated sample grid at step=1000:
Generated sample grid at step=15000: