当前位置:   article > 正文

半监督学习之MixMatch(代码解读 ablation)

mixmatch

接下来将按照顺序讲解每一个文件的作用

ablation 

ab_mixmatch.py

这段代码定义了额外的标志来测试MixMatch方法实现的不同部分。MixMatch算法是一种半监督学习方法,利用标记和未标记的数据来训练模型。

  1. import functools
  2. import os
  3. from absl import app
  4. from absl import flags
  5. from easydict import EasyDict
  6. from libml import layers, utils, models
  7. from libml.data_pair import DATASETS
  8. from libml.layers import MixMode
  9. import tensorflow as tf

这段代码是一个 Python 代码文件的一部分,它使用了一些常用的 Python 库和自定义库来实现深度学习的数据处理和模型训练

下面是这段代码的主要作用和功能:

1、导入必要的Python库:

  • functools:用于高阶函数编程。
  • os:用于与操作系统进行交互,例如获取环境变量和文件路径等。
  • absl:一个用于 Python 应用程序的命令行参数解析器。
  • easydict:提供了一种更加方便的字典方式来访问字典对象中的元素。

2.导入自定义库:

  • libml:这是一个自定义的 Python 库,包含了一些用于深度学习的数据处理和模型训练的模块。在这段代码中,我们使用了 layersutilsmodels 模块。
  • libml.data_pair:这是一个自定义的 Python 模块,它包含了一些用于深度学习的数据处理的方法。

3.定义一个 MixMode 枚举变量,用于表示数据集混合的模式。

4.使用 TensorFlow 2.x 版本的 API 构建深度学习模型。

FLAGS = flags.FLAGS

这一行代码定义了一个全局变量 FLAGS,它是 absl.flags.FLAGS 对象的一个实例。这个实例用于存储和管理命令行参数,以便在 Python 应用程序中使用这些参数。

在使用 absl.flags 库时,首先需要创建一个 FLAGS 对象实例,然后可以使用 DEFINE_xxx() 方法来定义命令行参数。在程序中引用这些参数时,可以通过 FLAGS.xxx 的方式来访问它们的值。

  1. class AblationMixMatch(models.MultiModel):
  2. def augment(self, x, l, beta, **kwargs):
  3. assert 0, 'Do not call.'
  4. def guess_label(self, y, classifier, T, getter, **kwargs):
  5. del kwargs
  6. logits_y = [classifier(yi, training=True, getter=getter) for yi in y]
  7. logits_y = tf.concat(logits_y, 0)
  8. # Compute predicted probability distribution py.
  9. p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
  10. p_model_y = tf.reduce_mean(p_model_y, axis=0)
  11. # Compute the target distribution.
  12. p_target = tf.pow(p_model_y, 1. / T)
  13. p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
  14. return EasyDict(p_target=p_target, p_model=p_model_y)
  15. def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
  16. hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
  17. x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
  18. y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
  19. l_in = tf.placeholder(tf.int32, [None], 'labels')
  20. wd *= lr
  21. w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
  22. augment = MixMode(mixmode)
  23. classifier = functools.partial(self.classifier, **kwargs)
  24. classifier(x_in, training=True) # Instantiate network.
  25. ema = tf.train.ExponentialMovingAverage(decay=ema)
  26. ema_op = ema.apply(utils.model_vars())
  27. ema_getter = functools.partial(utils.getter_ema, ema)
  28. y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
  29. guess = self.guess_label(tf.split(y, nu), classifier,
  30. getter=ema_getter if use_ema_guess else None, **kwargs)
  31. ly = tf.stop_gradient(guess.p_target)
  32. lx = tf.one_hot(l_in, self.nclass)
  33. xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
  34. x, y = xy[0], xy[1:]
  35. labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
  36. del xy, labels_xy
  37. batches = layers.interleave([x] + y, batch)
  38. skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  39. logits = [classifier(batches[0], training=True)]
  40. post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
  41. for batchi in batches[1:]:
  42. logits.append(classifier(batchi, training=True))
  43. logits = layers.interleave(logits, batch)
  44. logits_x = logits[0]
  45. logits_y = tf.concat(logits[1:], 0)
  46. loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
  47. loss_xe = tf.reduce_mean(loss_xe)
  48. loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
  49. loss_l2u = tf.reduce_mean(loss_l2u)
  50. tf.summary.scalar('losses/xe', loss_xe)
  51. tf.summary.scalar('losses/l2u', loss_l2u)
  52. post_ops.append(ema_op)
  53. post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
  54. train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
  55. with tf.control_dependencies([train_op]):
  56. train_op = tf.group(*post_ops)
  57. # Tuning op: only retrain batch norm.
  58. skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  59. classifier(batches[0], training=True)
  60. train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  61. if v not in skip_ops])
  62. return EasyDict(
  63. x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
  64. classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
  65. classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))

这段代码实现了一个名为AblationMixMatch的模型类,继承自MultiModel类,用于训练一个混合标签的半监督学习模型。其中的函数和参数的作用如下:

  • augment()函数:用于数据增强,但是这个函数的实现是抛出异常,不会被调用,因此可以被认为是空函数。
  • guess_label()函数:用于猜测标签,将无标签数据集的输出与训练好的分类器结合,计算猜测的标签。
  • model()函数:定义了整个模型的结构和训练过程,接受一系列的超参数,包括nu(无标签数据集的数量)、w_match(混合损失的权重)、warmup_kimg(warmup的步长)、batch(batch size)、lr(学习率)、wd(权重衰减系数)、ema(指数平均的系数)、beta(数据增强的系数)等。

在model()函数中,首先定义了输入placeholder的形状,然后对分类器进行了实例化,同时对模型进行了初始化和平均操作。接下来,对数据进行了数据增强和混合,然后将增强后的数据分别送入分类器中进行训练,并计算交叉熵和L2损失。最后定义了训练和调参操作,以及输出分类器的原始输出和指数平均后的输出。

具体来说,就是

  1. def augment(self, x, l, beta, **kwargs):
  2. assert 0, 'Do not call.'
'
运行

这个方法是一个占位符,代码中没有实际使用到。它被定义在 AblationMixMatch 类中作为一个抽象方法。如果这个方法被调用,代码会抛出一个异常,提示不应该直接调用它。这种设计方式通常是为了让子类必须实现这个方法,而不是使用父类的默认实现。在本例中,它的目的可能是为了强制子类实现一个数据增强的方法。

  1. def guess_label(self, y, classifier, T, getter, **kwargs):
  2. del kwargs
  3. logits_y = [classifier(yi, training=True, getter=getter) for yi in y]
  4. logits_y = tf.concat(logits_y, 0)
  5. # Compute predicted probability distribution py.
  6. p_model_y = tf.reshape(tf.nn.softmax(logits_y), [len(y), -1, self.nclass])
  7. p_model_y = tf.reduce_mean(p_model_y, axis=0)
  8. # Compute the target distribution.
  9. p_target = tf.pow(p_model_y, 1. / T)
  10. p_target /= tf.reduce_sum(p_target, axis=1, keep_dims=True)
  11. return EasyDict(p_target=p_target, p_model=p_model_y)
'
运行

这段代码是实现了一个模型对给定的标签 y 进行预测,其使用了一个分类器对标签进行推断,同时还有一个温度参数 T,用于控制预测的概率分布平滑程度。

首先,对每个标签 y,使用分类器classifier得到一个输出 logits_y,将所有的logits_y在第0个维度上进原始的logits_y是一个列表,得到一个新的张量。假设原始的logits_y每个元素都是形状为[batch_size, num_classes]的张量,那么拼接后的张量形状为[(len(logits_y) * batch_size), num_classes],其中len(logits_y)表示logits_y列表的长度。(这一步得到的是所有logits_y的值,)

然后,使用 softmax 函数将其转换为概率分布 p_model_y

接着,将所有标签的概率分布 p_model_y 求平均得到整个数据集的概率分布这里的平均就是按列先求和再平均)。

最后,使用温度参数 T 对整个数据集的概率分布进行平滑,得到目标分布 p_target。该函数的返回值包含了目标分布 p_target 和整个数据集的概率分布 p_model_y。

  • 对张量p_model_y中的每个元素取T次方根(即1/T次幂),得到一个新的张量p_target。如果原始的p_model_y是一个形状为[num_classes]的张量,那么经过tf.pow操作后,得到的新张量p_target形状仍为[num_classes],其中每个元素的值是原始张量对应元素的T次方根。
  • 是对张量p_target在第1个维度(即num_classes维度)上进行归一化,得到一个新的张量p_target。具体地说,如果原始的p_target是一个形状为[batch_size, num_classes]的张量,那么经过reduce_sum操作后,得到的是一个形状为[batch_size, 1]的张量,其中每个元素是原始张量在该维度上的和。接着,使用除法操作将原始张量中的每个元素除以对应的和,从而得到新张量p_target。

  1. def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
  2. hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
  3. x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
  4. y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
  5. l_in = tf.placeholder(tf.int32, [None], 'labels')
  6. wd *= lr
  7. w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
  8. augment = MixMode(mixmode)
  9. classifier = functools.partial(self.classifier, **kwargs)
  10. classifier(x_in, training=True) # Instantiate network.
  11. ema = tf.train.ExponentialMovingAverage(decay=ema)
  12. ema_op = ema.apply(utils.model_vars())
  13. ema_getter = functools.partial(utils.getter_ema, ema)
  14. y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
  15. guess = self.guess_label(tf.split(y, nu), classifier,
  16. getter=ema_getter if use_ema_guess else None, **kwargs)
  17. ly = tf.stop_gradient(guess.p_target)
  18. lx = tf.one_hot(l_in, self.nclass)
  19. xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
  20. x, y = xy[0], xy[1:]
  21. labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
  22. del xy, labels_xy
  23. batches = layers.interleave([x] + y, batch)
  24. skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  25. logits = [classifier(batches[0], training=True)]
  26. post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
  27. for batchi in batches[1:]:
  28. logits.append(classifier(batchi, training=True))
  29. logits = layers.interleave(logits, batch)
  30. logits_x = logits[0]
  31. logits_y = tf.concat(logits[1:], 0)
  32. loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
  33. loss_xe = tf.reduce_mean(loss_xe)
  34. loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
  35. loss_l2u = tf.reduce_mean(loss_l2u)
  36. tf.summary.scalar('losses/xe', loss_xe)
  37. tf.summary.scalar('losses/l2u', loss_l2u)
  38. post_ops.append(ema_op)
  39. post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
  40. train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
  41. with tf.control_dependencies([train_op]):
  42. train_op = tf.group(*post_ops)
  43. # Tuning op: only retrain batch norm.
  44. skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  45. classifier(batches[0], training=True)
  46. train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  47. if v not in skip_ops])
  48. return EasyDict(
  49. x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
  50. classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
  51. classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))
'
运行

该段代码是一个方法model,包含了训练模型的整个过程。该方法接受一些参数,如nu、w_match、warmup_kimg、batch、lr、wd、ema、beta、mixmode、use_ema_guess等,其中x_in表示输入图像,y_in表示与x_in对应的标签图像,l_in表示标签的类别。该方法的目标是在给定标签样本的情况下,使用半监督学习算法来训练分类模型。

该方法的大致流程如下:

  1. 定义输入placeholder:x_in、y_in、l_in。

  2. 对于给定的参数,进行预处理:计算wd * lr、w_match、augment、classifier等。

  3. 对输入图像x_in进行一次前向传播,以便实例化网络。同时,定义一个ExponentialMovingAverage对象ema,并应用于模型变量。

  4. 将标签图像y_in展开成一维张量,并根据guess_label方法和classifier对其进行预测。其中,guess_label方法会使用模型对标签图像进行猜测,并返回猜测后的标签,即p_target。使用tf.stop_gradient方法对p_target进行梯度截断,以防止误差反向传播。

  5. 对标签图像进行one-hot编码,得到labels_x,将x_in和y用MixMode方法进行数据增强,并将labels_x和p_target合并成labels_y。

  6. 将增强后的数据集拆分成batch,并使用分类器对每个batch进行前向传播,得到对应的logits。将logits_x和logits_y分别提取出来。

  7. 计算交叉熵损失loss_xe和l2正则化损失loss_l2u,并计算它们的平均值。

  8. 进行优化操作。首先,使用Adam优化器对loss_xe和w_match * loss_l2u进行优化。然后,将ema_op和model_vars()中所有名称带有kernel的变量进行指数滑动平均操作,再将它们乘以(1-wd)进行权重衰减。最后,将所有操作合并成train_op。

  9. 对于调参,只重新训练batch norm,即将所有除skip_ops外的其他更新操作合并为train_bn。

  10. 返回一个EasyDict对象,包含了x_in、y_in、l_in、train_op、tune_op、classify_raw、classify_op等。其中,classify_raw表示在没有应用ema时,分类模型对x_in进行前向传播的结果,classify_op表示在应用ema之后,分类模型对x_in进行前向传播的结果。

展开来讲:

  1. def model(self, nu, w_match, warmup_kimg, batch, lr, wd, ema, beta, mixmode, use_ema_guess, **kwargs):
  2. hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
  3. x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')
  4. y_in = tf.placeholder(tf.float32, [None, nu] + hwc, 'y')
  5. l_in = tf.placeholder(tf.int32, [None], 'labels')
  6. wd *= lr
  7. w_match *= tf.clip_by_value(tf.cast(self.step, tf.float32) / (warmup_kimg << 10), 0, 1)
  8. augment = MixMode(mixmode)
  9. classifier = functools.partial(self.classifier, **kwargs)
'
运行

这是一个模型定义函数,输入参数包括nu(无标签数据集的数量)、w_match(匹配权重)、warmup_kimg(预热步长)、batch(批大小)、lr(学习率)、wd(权重衰减)、ema(指数滑动平均系数)、beta(数据增强的beta参数)、mixmode(数据增强模式)和其他可选参数。该函数返回一个分类器。

该函数首先根据数据集的高、宽和通道数创建一个输入占位符x_in和一个标签占位符y_in。

然后,将权重衰减乘以学习率,并将匹配权重乘以一个warmup_kimg参数,以在前几个迭代中逐渐增加该权重。

接着,使用给定的数据增强模式创建一个数据增强器augment。

最后,函数返回一个分类器,该分类器使用self.classifier函数作为主要分类器,其中的参数使用了kwargs,该函数是一个偏函数,其中已经部分确定了一些参数。

  1. classifier(x_in, training=True) # Instantiate network.
  2. ema = tf.train.ExponentialMovingAverage(decay=ema)
  3. ema_op = ema.apply(utils.model_vars())
  4. ema_getter = functools.partial(utils.getter_ema, ema)

这段代码用于实例化一个神经网络分类器(classifier),并使用指数移动平均(Exponential Moving Average)对其参数进行平滑处理。

首先,通过调用classifier(x_in, training=True)来实例化网络,其中x_in是输入数据,training=True表示在训练模式下运行网络。

然后,使用指数移动平均(Exponential Moving Average,简称EMA)对网络的参数进行平滑处理。具体来说,通过调用tf.train.ExponentialMovingAverage(decay=ema)来创建一个指数移动平均器,其中decay参数指定了平均的衰减率

接着,通过调用ema.apply(utils.model_vars())来将指数移动平均器应用于网络的所有参数。这将为每个参数创建一个EMA副本,并更新其值。

最后,使用functools.partial将utils.getter_ema和EMA副本绑定在一起,创建一个ema_getter函数,用于在测试模式下获取网络参数的EMA副本。这将确保在测试模式下,网络参数将始终是平滑的EMA副本,而不是训练模式下的原始参数。

  1. y = tf.reshape(tf.transpose(y_in, [1, 0, 2, 3, 4]), [-1] + hwc)
  2. guess = self.guess_label(tf.split(y, nu), classifier,
  3. getter=ema_getter if use_ema_guess else None, **kwargs)
  4. ly = tf.stop_gradient(guess.p_target)
  5. lx = tf.one_hot(l_in, self.nclass)
  6. xy, labels_xy = augment([x_in] + tf.split(y, nu), [lx] + [ly] * nu, [beta, beta])
  7. x, y = xy[0], xy[1:]
  8. labels_x, labels_y = labels_xy[0], tf.concat(labels_xy[1:], 0)
  9. del xy, labels_xy

这段代码用于对输入数据进行一些操作,包括将输入标签(y_in)进行转置和重塑,使用分类器对标签进行预测,然后进行数据增强和标签拼接。

首先,使用tf.transpose将输入标签y_in进行转置,以便在后面进行reshape操作。然后,使用tf.reshape将转置后的标签y_in重塑为[-1] + hwc的形状,其中hwc表示标签y_in的高度、宽度和通道数。这将y_in从一个5D张量转换为一个2D张量num_views 表示每个样本所包含的视角数量。

(详细解释)tf.transpose(y_in, [1,0,2,3,4]),将y_in的维度从(batch_size, num_views, height, width, channels)变为(num_views, batch_size, height, width, channels)。使用tf.reshape函数将转置后的标签y_in重塑为一个新的形状,即[-1] + hwc。将y_in从一个5D张量转换为一个2D张量,其中第一维度表示了所有样本和视角的总数。具体来说,将第二到第五维度平坦化,即(batch_size * num_views, height, width, channels)转换为(batch_size * num_views * height * width * channels)。这种重塑操作可以将标签变为一个长向量,方便后续操作。

接下来,使用分类器对标签进行预测。具体来说,使用self.guess_label函数对重塑后的标签y进行预测,其中guess.p_target是预测的概率,可以用于计算分类器的损失函数。如果use_ema_guess为True,则使用ema_getter获取分类器参数的EMA副本进行预测。

(详细解释)y形状为(batch_size * num_views * height * width * channels),tf.split(y, nu)y 按照 nu 的值在第一维度上进行分割,得到一个包含 nu 个张量的列表。classifier 是用于分类的网络模型,它将每个标签数据映射为一个类别,并同时输出每个类别的置信度。getter 是一个函数,用于获取模型中的参数,这里使用 ema_getter 函数来获取使用指数移动平均法(Exponential Moving Average,EMA)计算的模型参数,以提高模型的鲁棒性。**kwargs 表示其他可选的参数,这些参数会传递给 guess_label 方法。

生成标签的独热编码和停止梯度的标签分布。具体来说,l_in 是一个形状为 (batch_size, ) 的张量,表示输入的真实标签。self.nclass 是一个标量,表示标签的类别数量。因此tf.one_hot(l_in, self.nclass) 会将真实标签 l_in 编码为一个形状为 (batch_size, self.nclass) 的独热编码张量,其中每一行表示一个标签的独热编码。

guess 是通过 guess_label 方法生成的一组伪标签。在该方法中,伪标签的生成是通过预测标签的分布来实现的,即 guess.p_target 表示标签数据的估计分布。为了避免在训练时反向传播误差到伪标签,导致网络训练不稳定,这里使用 tf.stop_gradient 函数将 guess.p_target 停止梯度,生成一个形状与其相同的新张量 ly

最终,lxly 分别表示真实标签和伪标签的独热编码,它们会被用于训练网络。

然后,进行数据增强和标签拼接。具体来说,将输入数据x_in和预测的标签guess.p_target(使用tf.split对预测的标签进行分割)传递给augment函数,对它们进行数据增强(augmentation)。augment函数返回增强后的数据和标签。最后,将增强后的数据x和增强后的标签y拆分为单独的张量,并将输入标签l_in进行one-hot编码,得到labels_x和labels_y。最后,删除xy和labels_xy以释放内存。

(详细解释)augment 是一个数据增强的函数,它接受三个参数:datalabelsparams,分别表示原始数据、标签和数据增强的参数。在这里,[x_in] + tf.split(y, nu) 表示将原始数据 x_in 和伪标签数据 y 按照视角数 nu 进行拆分,拼接成一个列表传递给 augment 函数。类似地,[lx] + [ly] * nu 表示将真实标签的独热编码 lx 和伪标签的独热编码 ly 按照视角数 nu 进行拆分,并使用列表推导式生成一个长度为 nu 的列表,最后将这两个列表拼接起来。params 参数中包含了两个值,都是标量 beta。它们用于控制数据增强时两种操作的强度,具体操作是随机剪裁和随机翻转。augment 函数的返回值是一个元组,包含增强后的数据和标签。在这里,xylabels_xy 分别表示增强后的数据和标签。其中,xy[0] 表示增强后的原始数据,xy[1:] 表示增强后的伪标签数据;labels_xy[0] 表示增强后的真实标签独热编码,labels_xy[1:] 表示增强后的伪标签独热编码。最后,通过将 xy 拆分为 xy,将 labels_xy 拆分为 labels_xlabels_y,分别表示增强后的原始数据、增强后的伪标签数据、增强后的真实标签独热编码和增强后的伪标签独热编码,用于训练网络。最后通过 del xy, labels_xy 删除不再需要的变量,释放内存。

  1. batches = layers.interleave([x] + y, batch)
  2. skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  3. logits = [classifier(batches[0], training=True)]
  4. post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
  5. for batchi in batches[1:]:
  6. logits.append(classifier(batchi, training=True))
  7. logits = layers.interleave(logits, batch)
  8. logits_x = logits[0]
  9. logits_y = tf.concat(logits[1:], 0)

这段代码主要是为了计算训练数据和伪标签数据的logits(分类器输出的未经softmax处理的概率),并计算损失函数。

首先通过调用layers.interleave函数将训练数据和伪标签数据交错分组,形成一个新的batch列表,其中第一个元素是训练数据,其余元素是伪标签数据。

然后通过循环遍历每个batch,调用分类器函数classifier计算每个batch的logits。在计算logits的过程中,通过设置training=True来启用训练模式,以便在BN层中记录训练过程中的均值和方差,并在测试过程中使用它们进行归一化。

计算完logits后,通过调用layers.interleave函数将它们重新交错分组,然后将第一个元素赋给logits_x变量,其余元素赋给logits_y变量。

  1. loss_xe = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_x, logits=logits_x)
  2. loss_xe = tf.reduce_mean(loss_xe)
  3. loss_l2u = tf.square(labels_y - tf.nn.softmax(logits_y))
  4. loss_l2u = tf.reduce_mean(loss_l2u)
  5. tf.summary.scalar('losses/xe', loss_xe)
  6. tf.summary.scalar('losses/l2u', loss_l2u)

这段代码计算了两个损失函数。第一个是 softmax 交叉熵损失函数,用来计算有标签数据的分类误差,它被赋值给了变量 loss_xe。第二个是 L2 损失函数,用于衡量无标签数据的预测结果与其平滑后的伪标签之间的差异,它被赋值给了变量 loss_l2u。这两个损失函数分别使用了 TensorFlow 中的 tf.nn.softmax_cross_entropy_with_logits_v2()tf.square() 函数进行计算,并用 tf.reduce_mean() 函数求取了它们的平均值。在这里,tf.summary.scalar() 函数被用来在 TensorBoard 中记录损失函数的值。

  1. post_ops.append(ema_op)
  2. post_ops.extend([tf.assign(v, v * (1 - wd)) for v in utils.model_vars('classify') if 'kernel' in v.name])
  3. train_op = tf.train.AdamOptimizer(lr).minimize(loss_xe + w_match * loss_l2u, colocate_gradients_with_ops=True)
  4. with tf.control_dependencies([train_op]):
  5. train_op = tf.group(*post_ops)
  6. # Tuning op: only retrain batch norm.
  7. skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  8. classifier(batches[0], training=True)
  9. train_bn = tf.group(*[v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  10. if v not in skip_ops])

这段代码是在训练模型的过程中,定义了一些后处理操作(post_ops),然后使用Adam优化器最小化交叉熵损失(loss_xe)和L2正则化损失(loss_l2u)的和。其中,L2正则化损失用于匹配有标签样本和无标签样本的特征分布,以实现半监督学习的目的。tf.summary.scalar用于记录损失的变化情况。with tf.control_dependencies([train_op])语句确保在进行后续操作之前,train_op操作先被执行。另外,还定义了一个操作train_bn,用于只重新训练BN层。

  1. return EasyDict(
  2. x=x_in, y=y_in, label=l_in, train_op=train_op, tune_op=train_bn,
  3. classify_raw=tf.nn.softmax(classifier(x_in, training=False)),
  4. classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))

这段代码返回了一个包含各种操作和张量的EasyDict对象。它包括输入张量x_in,y_in和l_in,两个分类器的softmax输出,训练操作train_op和调整操作train_bn,以及其他一些操作。这个对象的目的是使训练和测试代码更加简洁和易于理解。

  1. def main(argv):
  2. del argv # Unused.
  3. dataset = DATASETS[FLAGS.dataset]()
  4. log_width = utils.ilog2(dataset.width)
  5. model = AblationMixMatch(
  6. os.path.join(FLAGS.train_dir, dataset.name),
  7. dataset,
  8. lr=FLAGS.lr,
  9. wd=FLAGS.wd,
  10. arch=FLAGS.arch,
  11. batch=FLAGS.batch,
  12. nclass=dataset.nclass,
  13. ema=FLAGS.ema,
  14. beta=FLAGS.beta,
  15. use_ema_guess=FLAGS.use_ema_guess,
  16. T=FLAGS.T,
  17. mixmode=FLAGS.mixmode,
  18. nu=FLAGS.nu,
  19. w_match=FLAGS.w_match,
  20. warmup_kimg=FLAGS.warmup_kimg,
  21. scales=FLAGS.scales or (log_width - 2),
  22. filters=FLAGS.filters,
  23. repeat=FLAGS.repeat)
  24. model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
'
运行

这段代码是一个调用AblationMixMatch类实例化一个模型,并训练的函数。首先会根据FLAGS中的数据集名称选择对应的数据集,然后通过AblationMixMatch类构造模型。FLAGS中包含了训练需要用到的超参数,例如学习率、权重衰减、卷积神经网络结构等。最后,调用模型的train方法,训练模型并输出训练结果。其中,FLAGS.train_kimg和FLAGS.report_kimg是指训练步数和结果输出步数,都需要左移10位,因为模型使用的是Mini-batch SGD,每一步的batch size是2的整数次幂

  1. if __name__ == '__main__':
  2. utils.setup_tf()
  3. flags.DEFINE_float('wd', 0.02, 'Weight decay.')
  4. flags.DEFINE_float('ema', 0.999, 'Exponential moving average of params.')
  5. flags.DEFINE_float('beta', 0.5, 'Mixup beta distribution.')
  6. flags.DEFINE_bool('use_ema_guess', False, 'Whether to use EMA parameters when guessing labels.')
  7. flags.DEFINE_float('T', 0.5, 'Softmax sharpening temperature.')
  8. flags.DEFINE_enum('mixmode', 'xxy.yxy', MixMode.MODES, 'Mixup mode')
  9. flags.DEFINE_float('w_match', 100, 'Weight for distribution matching loss.')
  10. flags.DEFINE_integer('warmup_kimg', 128, 'Warmup in kimg for the matching loss.')
  11. flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
  12. flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
  13. flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
  14. FLAGS.set_default('dataset', 'cifar10.3@250-5000')
  15. FLAGS.set_default('batch', 64)
  16. FLAGS.set_default('lr', 0.002)
  17. FLAGS.set_default('train_kimg', 1 << 16)
  18. app.run(main)

这段代码是一个 Python 脚本的主函数,会在运行时执行。该脚本提供了许多可选的命令行参数,用于指定不同的超参数设置。在此之后,脚本调用了 utils.setup_tf() 函数,该函数是一个工具函数,用于设置 TensorFlow 运行时的 GPU 环境等配置。最后,脚本调用了 app.run(main) 函数来运行 main 函数。main 函数主要是构建了一个 AblationMixMatch 模型,并调用 train 函数来训练模型。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/944233
推荐阅读
相关标签
  

闽ICP备14008679号