赞
踩
接下来将按照顺序讲解每一个文件的作用
ab_mixmatch.py
这段代码定义了额外的标志来测试MixMatch方法实现的不同部分。MixMatch算法是一种半监督学习方法,利用标记和未标记的数据来训练模型。
- import functools
- import os
-
- from absl import app
- from absl import flags
- from easydict import EasyDict
- from libml import layers, utils, models
- from libml.data_pair import DATASETS
- from libml.layers import MixMode
- import tensorflow as tf
这段代码是一个 Python 代码文件的一部分,它使用了一些常用的 Python 库和自定义库来实现深度学习的数据处理和模型训练。
下面是这段代码的主要作用和功能:
1、导入必要的Python库:
functools
:用于高阶函数编程。os
:用于与操作系统进行交互,例如获取环境变量和文件路径等。absl
:一个用于 Python 应用程序的命令行参数解析器。easydict
:提供了一种更加方便的字典方式来访问字典对象中的元素。2.导入自定义库:
libml
:这是一个自定义的 Python 库,包含了一些用于深度学习的数据处理和模型训练的模块。在这段代码中,我们使用了 layers
、utils
和 models
模块。libml.data_pair
:这是一个自定义的 Python 模块,它包含了一些用于深度学习的数据处理的方法。3.定义一个 MixMode
枚举变量,用于表示数据集混合的模式。
4.使用 TensorFlow 2.x 版本的 API 构建深度学习模型。
FLAGS = flags.FLAGS
这一行代码定义了一个全局变量 FLAGS
,它是 absl.flags.FLAGS
对象的一个实例。这个实例用于存储和管理命令行参数,以便在 Python 应用程序中使用这些参数。
在使用 absl.flags
库时,首先需要创建一个 FLAGS
对象实例,然后可以使用 DEFINE_xxx()
方法来定义命令行参数。在程序中引用这些参数时,可以通过 FLAGS.xxx
的方式来访问它们的值。
- class AblationMixMatch(models.MultiModel):
-
- def augment(self, x, l, beta, **kwargs):
- assert 0, 'Do not call.'
-
- def guess_label(self, y, classifier, T, getter, **kwargs):
- del kwargs
- logits_y = [classifier(yi, training=True, getter=getter) for yi in y]
- logits_y = tf.concat(logits_y, 0)
- # Compute predicted probability distribution py.
- p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
- p_model_y = tf.reduce_mean(p_model_y, axis=0)
- # Compute the target distribution.
- p_target = tf.pow(p_model_y, 1. / T)
- p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
- return EasyDict(p_target=p_target, p_model=p_model_y)
-
- def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
- hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
- x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
- y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
- l_in = tf.placeholder(tf.int32, [None], 'labels')
- wd *= lr
- w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
- augment = MixMode(mixmode)
- classifier = functools.partial(self.classifier, **kwargs)
-
- classifier(x_in, training=True) # Instantiate network.
- ema = tf.train.ExponentialMovingAverage(decay=ema)
- ema_op = ema.apply(utils.model_vars())
- ema_getter = functools.partial(utils.getter_ema, ema)
-
- y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
- guess = self.guess_label(tf.split(y, nu), classifier,
- getter=ema_getter if use_ema_guess else None, **kwargs)
- ly = tf.stop_gradient(guess.p_target)
- lx = tf.one_hot(l_in, self.nclass)
- xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
- x, y = xy[0], xy[1:]
- labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
- del xy, labels_xy
-
- batches = layers.interleave([x] + y, batch)
- skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- logits = [classifier(batches[0], training=True)]
- post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
- for batchi in batches[1:]:
- logits.append(classifier(batchi, training=True))
- logits = layers.interleave(logits, batch)
- logits_x = logits[0]
- logits_y = tf.concat(logits[1:], 0)
-
- loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
- loss_xe = tf.reduce_mean(loss_xe)
- loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
- loss_l2u = tf.reduce_mean(loss_l2u)
- tf.summary.scalar('losses/xe', loss_xe)
- tf.summary.scalar('losses/l2u', loss_l2u)
-
- post_ops.append(ema_op)
- post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
-
- train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
- with tf.control_dependencies([train_op]):
- train_op = tf.group(*post_ops)
-
- # Tuning op: only retrain batch norm.
- skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- classifier(batches[0], training=True)
- train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- if v not in skip_ops])
-
- return EasyDict(
- x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
- classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
- classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
这段代码实现了一个名为AblationMixMatch的模型类,继承自MultiModel类,用于训练一个混合标签的半监督学习模型。其中的函数和参数的作用如下:
在model()函数中,首先定义了输入placeholder的形状,然后对分类器进行了实例化,同时对模型进行了初始化和平均操作。接下来,对数据进行了数据增强和混合,然后将增强后的数据分别送入分类器中进行训练,并计算交叉熵和L2损失。最后定义了训练和调参操作,以及输出分类器的原始输出和指数平均后的输出。
具体来说,就是
- def augment(self, x, l, beta, **kwargs):
- assert 0, 'Do not call.'
'运行
这个方法是一个占位符,代码中没有实际使用到。它被定义在 AblationMixMatch 类中作为一个抽象方法。如果这个方法被调用,代码会抛出一个异常,提示不应该直接调用它。这种设计方式通常是为了让子类必须实现这个方法,而不是使用父类的默认实现。在本例中,它的目的可能是为了强制子类实现一个数据增强的方法。
- def guess_label(self, y, classifier, T, getter, **kwargs):
- del kwargs
- logits_y = [classifier(yi, training=True, getter=getter) for yi in y]
- logits_y = tf.concat(logits_y, 0)
- # Compute predicted probability distribution py.
- p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
- p_model_y = tf.reduce_mean(p_model_y, axis=0)
- # Compute the target distribution.
- p_target = tf.pow(p_model_y, 1. / T)
- p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
- return EasyDict(p_target=p_target, p_model=p_model_y)
'运行
这段代码是实现了一个模型对给定的标签 y 进行预测,其使用了一个分类器对标签进行推断,同时还有一个温度参数 T,用于控制预测的概率分布平滑程度。
首先,对每个标签 y,使用分类器classifier得到一个输出 logits_y,将所有的logits_y在第0个维度上进原始的logits_y是一个列表,得到一个新的张量。假设原始的logits_y每个元素都是形状为[batch_size, num_classes]的张量,那么拼接后的张量形状为[(len(logits_y) * batch_size), num_classes],其中len(logits_y)表示logits_y列表的长度。(这一步得到的是所有logits_y的值,)
然后,使用 softmax 函数将其转换为概率分布 p_model_y。
接着,将所有标签的概率分布 p_model_y 求平均得到整个数据集的概率分布(这里的平均就是按列先求和再平均)。
最后,使用温度参数 T 对整个数据集的概率分布进行平滑,得到目标分布 p_target。该函数的返回值包含了目标分布 p_target 和整个数据集的概率分布 p_model_y。
是对张量p_target在第1个维度(即num_classes维度)上进行归一化,得到一个新的张量p_target。具体地说,如果原始的p_target是一个形状为[batch_size, num_classes]的张量,那么经过reduce_sum操作后,得到的是一个形状为[batch_size, 1]的张量,其中每个元素是原始张量在该维度上的和。接着,使用除法操作将原始张量中的每个元素除以对应的和,从而得到新张量p_target。
- def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
- hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
- x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
- y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
- l_in = tf.placeholder(tf.int32, [None], 'labels')
- wd *= lr
- w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
- augment = MixMode(mixmode)
- classifier = functools.partial(self.classifier, **kwargs)
-
- classifier(x_in, training=True) # Instantiate network.
- ema = tf.train.ExponentialMovingAverage(decay=ema)
- ema_op = ema.apply(utils.model_vars())
- ema_getter = functools.partial(utils.getter_ema, ema)
-
- y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
- guess = self.guess_label(tf.split(y, nu), classifier,
- getter=ema_getter if use_ema_guess else None, **kwargs)
- ly = tf.stop_gradient(guess.p_target)
- lx = tf.one_hot(l_in, self.nclass)
- xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
- x, y = xy[0], xy[1:]
- labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
- del xy, labels_xy
-
- batches = layers.interleave([x] + y, batch)
- skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- logits = [classifier(batches[0], training=True)]
- post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
- for batchi in batches[1:]:
- logits.append(classifier(batchi, training=True))
- logits = layers.interleave(logits, batch)
- logits_x = logits[0]
- logits_y = tf.concat(logits[1:], 0)
-
- loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
- loss_xe = tf.reduce_mean(loss_xe)
- loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
- loss_l2u = tf.reduce_mean(loss_l2u)
- tf.summary.scalar('losses/xe', loss_xe)
- tf.summary.scalar('losses/l2u', loss_l2u)
-
- post_ops.append(ema_op)
- post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
-
- train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
- with tf.control_dependencies([train_op]):
- train_op = tf.group(*post_ops)
-
- # Tuning op: only retrain batch norm.
- skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- classifier(batches[0], training=True)
- train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- if v not in skip_ops])
-
- return EasyDict(
- x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
- classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
- classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
'运行
该段代码是一个方法model,包含了训练模型的整个过程。该方法接受一些参数,如nu、w_match、warmup_kimg、batch、lr、wd、ema、beta、mixmode、use_ema_guess等,其中x_in表示输入图像,y_in表示与x_in对应的标签图像,l_in表示标签的类别。该方法的目标是在给定标签样本的情况下,使用半监督学习算法来训练分类模型。
该方法的大致流程如下:
定义输入placeholder:x_in、y_in、l_in。
对于给定的参数,进行预处理:计算wd * lr、w_match、augment、classifier等。
对输入图像x_in进行一次前向传播,以便实例化网络。同时,定义一个ExponentialMovingAverage对象ema,并应用于模型变量。
将标签图像y_in展开成一维张量,并根据guess_label方法和classifier对其进行预测。其中,guess_label方法会使用模型对标签图像进行猜测,并返回猜测后的标签,即p_target。使用tf.stop_gradient方法对p_target进行梯度截断,以防止误差反向传播。
对标签图像进行one-hot编码,得到labels_x,将x_in和y用MixMode方法进行数据增强,并将labels_x和p_target合并成labels_y。
将增强后的数据集拆分成batch,并使用分类器对每个batch进行前向传播,得到对应的logits。将logits_x和logits_y分别提取出来。
计算交叉熵损失loss_xe和l2正则化损失loss_l2u,并计算它们的平均值。
进行优化操作。首先,使用Adam优化器对loss_xe和w_match * loss_l2u进行优化。然后,将ema_op和model_vars()中所有名称带有kernel的变量进行指数滑动平均操作,再将它们乘以(1-wd)进行权重衰减。最后,将所有操作合并成train_op。
对于调参,只重新训练batch norm,即将所有除skip_ops外的其他更新操作合并为train_bn。
返回一个EasyDict对象,包含了x_in、y_in、l_in、train_op、tune_op、classify_raw、classify_op等。其中,classify_raw表示在没有应用ema时,分类模型对x_in进行前向传播的结果,classify_op表示在应用ema之后,分类模型对x_in进行前向传播的结果。
展开来讲:
- def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
- hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
- x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
- y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
- l_in = tf.placeholder(tf.int32, [None], 'labels')
- wd *= lr
- w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
- augment = MixMode(mixmode)
- classifier = functools.partial(self.classifier, **kwargs)
'运行
这是一个模型定义函数,输入参数包括nu(无标签数据集的数量)、w_match(匹配权重)、warmup_kimg(预热步长)、batch(批大小)、lr(学习率)、wd(权重衰减)、ema(指数滑动平均系数)、beta(数据增强的beta参数)、mixmode(数据增强模式)和其他可选参数。该函数返回一个分类器。
该函数首先根据数据集的高、宽和通道数创建一个输入占位符x_in和一个标签占位符y_in。
然后,将权重衰减乘以学习率,并将匹配权重乘以一个warmup_kimg参数,以在前几个迭代中逐渐增加该权重。
接着,使用给定的数据增强模式创建一个数据增强器augment。
最后,函数返回一个分类器,该分类器使用self.classifier函数作为主要分类器,其中的参数使用了kwargs,该函数是一个偏函数,其中已经部分确定了一些参数。
- classifier(x_in, training=True) # Instantiate network.
- ema = tf.train.ExponentialMovingAverage(decay=ema)
- ema_op = ema.apply(utils.model_vars())
- ema_getter = functools.partial(utils.getter_ema, ema)
这段代码用于实例化一个神经网络分类器(classifier),并使用指数移动平均(Exponential Moving Average)对其参数进行平滑处理。
首先,通过调用classifier(x_in, training=True)来实例化网络,其中x_in是输入数据,training=True表示在训练模式下运行网络。
然后,使用指数移动平均(Exponential Moving Average,简称EMA)对网络的参数进行平滑处理。具体来说,通过调用tf.train.ExponentialMovingAverage(decay=ema)来创建一个指数移动平均器,其中decay参数指定了平均的衰减率。
接着,通过调用ema.apply(utils.model_vars())来将指数移动平均器应用于网络的所有参数。这将为每个参数创建一个EMA副本,并更新其值。
最后,使用functools.partial将utils.getter_ema和EMA副本绑定在一起,创建一个ema_getter函数,用于在测试模式下获取网络参数的EMA副本。这将确保在测试模式下,网络参数将始终是平滑的EMA副本,而不是训练模式下的原始参数。
-
- y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
- guess = self.guess_label(tf.split(y, nu), classifier,
- getter=ema_getter if use_ema_guess else None, **kwargs)
- ly = tf.stop_gradient(guess.p_target)
- lx = tf.one_hot(l_in, self.nclass)
- xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
- x, y = xy[0], xy[1:]
- labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
- del xy, labels_xy
这段代码用于对输入数据进行一些操作,包括将输入标签(y_in)进行转置和重塑,使用分类器对标签进行预测,然后进行数据增强和标签拼接。
首先,使用tf.transpose将输入标签y_in进行转置,以便在后面进行reshape操作。然后,使用tf.reshape将转置后的标签y_in重塑为[-1] + hwc的形状,其中hwc表示标签y_in的高度、宽度和通道数。这将y_in从一个5D张量转换为一个2D张量。num_views
表示每个样本所包含的视角数量。
(详细解释)tf.transpose(y_in, [1,0,2,3,4]),将y_in的维度从(batch_size, num_views, height, width, channels)变为(num_views, batch_size, height, width, channels)。使用tf.reshape函数将转置后的标签y_in重塑为一个新的形状,即[-1] + hwc。将y_in从一个5D张量转换为一个2D张量,其中第一维度表示了所有样本和视角的总数。具体来说,将第二到第五维度平坦化,即(batch_size * num_views, height, width, channels)转换为(batch_size * num_views * height * width * channels)。这种重塑操作可以将标签变为一个长向量,方便后续操作。
接下来,使用分类器对标签进行预测。具体来说,使用self.guess_label函数对重塑后的标签y进行预测,其中guess.p_target是预测的概率,可以用于计算分类器的损失函数。如果use_ema_guess为True,则使用ema_getter获取分类器参数的EMA副本进行预测。
(详细解释)y形状为(batch_size * num_views * height * width * channels),tf.split(y, nu)
将 y
按照 nu
的值在第一维度上进行分割,得到一个包含 nu
个张量的列表。classifier
是用于分类的网络模型,它将每个标签数据映射为一个类别,并同时输出每个类别的置信度。getter
是一个函数,用于获取模型中的参数,这里使用 ema_getter
函数来获取使用指数移动平均法(Exponential Moving Average,EMA)计算的模型参数,以提高模型的鲁棒性。**kwargs
表示其他可选的参数,这些参数会传递给 guess_label
方法。
生成标签的独热编码和停止梯度的标签分布。具体来说,l_in
是一个形状为 (batch_size, )
的张量,表示输入的真实标签。self.nclass
是一个标量,表示标签的类别数量。因此tf.one_hot(l_in, self.nclass)
会将真实标签 l_in
编码为一个形状为 (batch_size, self.nclass)
的独热编码张量,其中每一行表示一个标签的独热编码。
guess
是通过 guess_label
方法生成的一组伪标签。在该方法中,伪标签的生成是通过预测标签的分布来实现的,即 guess.p_target
表示标签数据的估计分布。为了避免在训练时反向传播误差到伪标签,导致网络训练不稳定,这里使用 tf.stop_gradient
函数将 guess.p_target
停止梯度,生成一个形状与其相同的新张量 ly
。
最终,lx
和 ly
分别表示真实标签和伪标签的独热编码,它们会被用于训练网络。
然后,进行数据增强和标签拼接。具体来说,将输入数据x_in和预测的标签guess.p_target(使用tf.split对预测的标签进行分割)传递给augment函数,对它们进行数据增强(augmentation)。augment函数返回增强后的数据和标签。最后,将增强后的数据x和增强后的标签y拆分为单独的张量,并将输入标签l_in进行one-hot编码,得到labels_x和labels_y。最后,删除xy和labels_xy以释放内存。
(详细解释)augment
是一个数据增强的函数,它接受三个参数:data
、labels
和 params
,分别表示原始数据、标签和数据增强的参数。在这里,[x_in] + tf.split(y, nu)
表示将原始数据 x_in
和伪标签数据 y
按照视角数 nu
进行拆分,拼接成一个列表传递给 augment
函数。类似地,[lx] + [ly] * nu
表示将真实标签的独热编码 lx
和伪标签的独热编码 ly
按照视角数 nu
进行拆分,并使用列表推导式生成一个长度为 nu
的列表,最后将这两个列表拼接起来。params
参数中包含了两个值,都是标量 beta
。它们用于控制数据增强时两种操作的强度,具体操作是随机剪裁和随机翻转。augment
函数的返回值是一个元组,包含增强后的数据和标签。在这里,xy
和 labels_xy
分别表示增强后的数据和标签。其中,xy[0]
表示增强后的原始数据,xy[1:]
表示增强后的伪标签数据;labels_xy[0]
表示增强后的真实标签独热编码,labels_xy[1:]
表示增强后的伪标签独热编码。最后,通过将 xy
拆分为 x
和 y
,将 labels_xy
拆分为 labels_x
和 labels_y
,分别表示增强后的原始数据、增强后的伪标签数据、增强后的真实标签独热编码和增强后的伪标签独热编码,用于训练网络。最后通过 del xy, labels_xy
删除不再需要的变量,释放内存。
- batches = layers.interleave([x] + y, batch)
- skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- logits = [classifier(batches[0], training=True)]
- post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
- for batchi in batches[1:]:
- logits.append(classifier(batchi, training=True))
- logits = layers.interleave(logits, batch)
- logits_x = logits[0]
- logits_y = tf.concat(logits[1:], 0)
这段代码主要是为了计算训练数据和伪标签数据的logits(分类器输出的未经softmax处理的概率),并计算损失函数。
首先通过调用layers.interleave
函数将训练数据和伪标签数据交错分组,形成一个新的batch列表,其中第一个元素是训练数据,其余元素是伪标签数据。
然后通过循环遍历每个batch,调用分类器函数classifier
计算每个batch的logits。在计算logits的过程中,通过设置training=True
来启用训练模式,以便在BN层中记录训练过程中的均值和方差,并在测试过程中使用它们进行归一化。
计算完logits后,通过调用layers.interleave
函数将它们重新交错分组,然后将第一个元素赋给logits_x
变量,其余元素赋给logits_y
变量。
- loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
- loss_xe = tf.reduce_mean(loss_xe)
- loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
- loss_l2u = tf.reduce_mean(loss_l2u)
- tf.summary.scalar('losses/xe', loss_xe)
- tf.summary.scalar('losses/l2u', loss_l2u)
这段代码计算了两个损失函数。第一个是 softmax 交叉熵损失函数,用来计算有标签数据的分类误差,它被赋值给了变量 loss_xe
。第二个是 L2 损失函数,用于衡量无标签数据的预测结果与其平滑后的伪标签之间的差异,它被赋值给了变量 loss_l2u
。这两个损失函数分别使用了 TensorFlow 中的 tf.nn.softmax_cross_entropy_with_logits_v2()
和 tf.square()
函数进行计算,并用 tf.reduce_mean()
函数求取了它们的平均值。在这里,tf.summary.scalar()
函数被用来在 TensorBoard 中记录损失函数的值。
- post_ops.append(ema_op)
- post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
-
- train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
- with tf.control_dependencies([train_op]):
- train_op = tf.group(*post_ops)
-
- # Tuning op: only retrain batch norm.
- skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- classifier(batches[0], training=True)
- train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- if v not in skip_ops])
这段代码是在训练模型的过程中,定义了一些后处理操作(post_ops
),然后使用Adam优化器最小化交叉熵损失(loss_xe
)和L2正则化损失(loss_l2u
)的和。其中,L2正则化损失用于匹配有标签样本和无标签样本的特征分布,以实现半监督学习的目的。tf.summary.scalar
用于记录损失的变化情况。with tf.control_dependencies([train_op])
语句确保在进行后续操作之前,train_op
操作先被执行。另外,还定义了一个操作train_bn
,用于只重新训练BN层。
- return EasyDict(
- x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
- classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
- classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
这段代码返回了一个包含各种操作和张量的EasyDict对象。它包括输入张量x_in,y_in和l_in,两个分类器的softmax输出,训练操作train_op和调整操作train_bn,以及其他一些操作。这个对象的目的是使训练和测试代码更加简洁和易于理解。
- def main(argv):
- del argv # Unused.
- dataset = DATASETS[FLAGS.dataset]()
- log_width = utils.ilog2(dataset.width)
- model = AblationMixMatch(
- os.path.join(FLAGS.train_dir, dataset.name),
- dataset,
- lr=FLAGS.lr,
- wd=FLAGS.wd,
- arch=FLAGS.arch,
- batch=FLAGS.batch,
- nclass=dataset.nclass,
- ema=FLAGS.ema,
- beta=FLAGS.beta,
-
- use_ema_guess=FLAGS.use_ema_guess,
- T=FLAGS.T,
- mixmode=FLAGS.mixmode,
- nu=FLAGS.nu,
- w_match=FLAGS.w_match,
- warmup_kimg=FLAGS.warmup_kimg,
- scales=FLAGS.scales or (log_width - 2),
- filters=FLAGS.filters,
- repeat=FLAGS.repeat)
- model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
'运行
这段代码是一个调用AblationMixMatch类实例化一个模型,并训练的函数。首先会根据FLAGS中的数据集名称选择对应的数据集,然后通过AblationMixMatch类构造模型。FLAGS中包含了训练需要用到的超参数,例如学习率、权重衰减、卷积神经网络结构等。最后,调用模型的train方法,训练模型并输出训练结果。其中,FLAGS.train_kimg和FLAGS.report_kimg是指训练步数和结果输出步数,都需要左移10位,因为模型使用的是Mini-batch SGD,每一步的batch size是2的整数次幂
- if __name__ == '__main__':
- utils.setup_tf()
- flags.DEFINE_float('wd', 0.02, 'Weight decay.')
- flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
- flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')
-
- flags.DEFINE_bool('use_ema_guess', False, 'Whether to use EMA parameters when guessing labels.')
- flags.DEFINE_float('T', 0.5, 'Softmax sharpening temperature.')
- flags.DEFINE_enum('mixmode', 'xxy.yxy', MixMode.MODES, 'Mixup mode')
- flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.')
- flags.DEFINE_integer('warmup_kimg', 128, 'Warmup in kimg for the matching loss.')
- flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
- flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
- flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
- FLAGS.set_default('dataset', 'cifar10.3@250-5000')
- FLAGS.set_default('batch', 64)
- FLAGS.set_default('lr', 0.002)
- FLAGS.set_default('train_kimg', 1 << 16)
- app.run(main)
这段代码是一个 Python 脚本的主函数,会在运行时执行。该脚本提供了许多可选的命令行参数,用于指定不同的超参数设置。在此之后,脚本调用了 utils.setup_tf()
函数,该函数是一个工具函数,用于设置 TensorFlow 运行时的 GPU 环境等配置。最后,脚本调用了 app.run(main)
函数来运行 main
函数。main
函数主要是构建了一个 AblationMixMatch
模型,并调用 train
函数来训练模型。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。