当前位置:   article > 正文

元学习Meta-Learning_元学习的情景训练

元学习的情景训练

1. Introduction

通常在机器学习里,我们会使用某个场景的大量数据来训练模型;然而当场景发生改变,模型就需要重新训练。但是对于人类而言,一个小朋友成长过程中会见过许多物体的照片,某一天,当Ta(第一次)仅仅看了几张狗的照片,就可以很好地对狗和其他物体进行区分。

元学习Meta Learning,含义为学会学习,即learn to learn,就是带着这种对人类这种“学习能力”的期望诞生的。Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,如:

  • 让Alphago迅速学会下象棋
  • 让一个猫咪图片分类器,迅速具有分类其他物体的能力

需要注意的是,虽然同样有“预训练”的意思在里面,但是元学习的内核区别于迁移学习(Transfer Learning),关于他们的区别,我会在下文进行阐述。
接下来,我们通过对比机器学习和元学习这两个概念的要素来加深对元学习这个概念的理解。
在这里插入图片描述
在机器学习中,训练单位是一条数据,通过数据来对模型进行优化;数据可以分为训练集、测试集和验证集。在元学习中,训练单位分层级了,第一层训练单位是任务,也就是说,元学习中要准备许多任务来进行学习,第二层训练单位才是每个任务对应的数据
二者的目的都是找一个Function,只是两个Function的功能不同,要做的事情不一样。机器学习中的Function直接作用于特征和标签,去寻找特征与标签之间的关联;而元学习中的Function是用于寻找新的f,新的f才会应用于具体的任务。有种不同阶导数的感觉。又有种老千层饼的感觉,你看到我在第二层,你把我想象成第一层,而其实我在第五层。。。

2. Meta Learning实施——以MAML为例

我们先对比机器学习的过程来进一步理解元学习。如下图所示,机器学习的一般过程如下:

  • 设计网络网络结构,如CNN、RNN等;
  • 选定某个分布来初始化参数;(以上其实决定了初始的f的长相,选择不同的网络结构或参数相当于定义了不同的f);
  • 喂训练数据,根据选定的Loss
  • Function计算Loss; 梯度下降,逐步更新 ;
  • 得到最终的f
    机器学习过程,引自李宏毅《深度学习》
    其中,红色方框里的“配置”都是由人为设计的,我们又叫做“超参数“。Meta Learning中希望把这些配置,如网络结构,参数初始化,优化器等由机器自行设计(注:此处区别于AutoML,迁移学习(Transfer Learning)和终身学习(Life Long Learning) ),使网络有更强的学习能力和表现。
    上文已经提到,【元学习中要准备许多任务来进行学习,而每个任务又有各自的训练集和测试集】。我们结合一个具体的任务,来介绍元学习和MAML的实施过程。
    有一个图像数据集叫Omniglot:https://github.com/brendenlake/omniglot。Omniglot包含1623个不同的火星文字符,每个字符包含20个手写的case。这个任务是判断每个手写的case属于哪一个火星文字符。
    如果我们要进行N-ways,K-shot(数据中包含N个字符类别,每个字符有K张图像)的一个图像分类任务。比如20-ways,1-shot分类的意思是说,要做一个20分类,但是每个分类下只有1张图像的任务。我们可以依据Omniglot构建很多N-ways,K-shot任务,这些任务将作为元学习的任务来源。构建的任务分为训练任务(Train Task),测试任务(Test Task)。特别地,每个任务包含自己的训练数据、测试数据,在元学习里,分别称为Support Set和Query Set
    MAML的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)。我们通过(许多)N-ways,K-shot的任务(训练任务)进行元学习的训练,使得模型学习到“先验知识”(初始化的参数)。这个“先验知识”在新的N-ways,K-shot任务上可以表现的更好。
    接下来介绍MAML的算法流程:
    MAML算法流程
    当然,在“预训练”阶段,也可以sample出1个batch的几个任务,那么在更新meta网络时,要使用sample出所有任务的梯度之和。
    注意:在MAML中,meta网络与子任务的网络结构必须完全相同。

这里面有几个小问题:

  1. MAML的执行过程与model pretraining & transfer learning的区别是什么?
  2. 为何在meta网络赋值给具体训练任务(如任务m)后,要先更训练任务的参数,再计算梯度,更新meta网络?
  3. 在更新训练任务的网络时,只走了一步,然后更新meta网络。为什么是一步,可以是多步吗?
    这三个问题是MAML中很核心的问题,大家可以先思考一下,我们将在后文进行解答。我们先看一下MAML的实现代码。
## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow

#################################################
# 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
# support set:5 * 1
# query set:5 * 15
# 训练取1个batch的任务:batch size:4
# 对训练任务进行训练时,更新5次:K = 5
#################################################

print(support_x) # (4, 5, 21168) 
print(query_x) # (4, 75, 21168)
print(support_y) # (4, 5, 5)
print(query_y) # (4, 75, 5)
print(meta_batchsz) # 4
print(K) # 5

model = MAML()
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')

class MAML:
    def __init__(self):
        pass
    def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
        """
        :param support_xb: [4, 5, 84*84*3] 
        :param support_yb: [4, 5, n-way]
        :param query_xb:  [4, 75, 84*84*3]
        :param query_yb: [4, 75, n-way]
        :param K:  训练任务的网络更新步数
        :param meta_batchsz: 任务数,4
        """

        self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数
        training = True if mode is 'train' else False      
        def meta_task(input):
            """
            :param support_x:   [setsz, 84*84*3] (5, 21168)
            :param support_y:   [setsz, n-way] (5, 5)
            :param query_x:     [querysz, 84*84*3] (75, 21168)
            :param query_y:     [querysz, n-way] (75, 5)
            :param training:    training or not, for batch_norm
            :return:
            """

            support_x, support_y, query_x, query_y = input
            query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果
 
            ## 第0次对网络进行更新
            support_pred = self.forward(support_x, self.weights, training) # 前向计算support set
            support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
            support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
                                                         tf.argmax(support_y, axis=1))
            grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度
            gvs = dict(zip(self.weights.keys(), grads))
            # 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads
            fast_weights = dict(zip(self.weights.keys(), \
                    [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))

            # 使用梯度更新后的参数对quert set进行前向计算
            query_pred = self.forward(query_x, fast_weights, training)
            query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
            query_preds.append(query_pred)
            query_losses.append(query_loss)
 
            # 第1到 K-1次对网络进行更新
            for _ in range(1, K):           
                loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
                                                               labels=support_y)
                grads = tf.gradients(loss, list(fast_weights.values()))
                gvs = dict(zip(fast_weights.keys(), grads))
                fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
                                         for key in fast_weights.keys()]))
                query_pred = self.forward(query_x, fast_weights, training)
                query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
                # 子网络更新K次,记录每一次queryset的结果
                query_preds.append(query_pred)
                query_losses.append(query_loss)

            for i in range(K):
                query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
                                                                tf.argmax(query_y, axis=1)))
            result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
            return result

        # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
        out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
        result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
                           dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
        support_pred_tasks, support_loss_tasks, support_acc_tasks, \
            query_preds_tasks, query_losses_tasks, query_accs_tasks = result

        if mode is 'train':
            self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
            self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
                                                    for j in range(K)]
            self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
            self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
                                                    for j in range(K)]

            # 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调
            optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
            gvs = optimizer.compute_gradients(self.query_losses[-1])
   # def ********
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104

接下来回答一下上面的三个问题:

问题1:MAML的执行过程与model pretraining & transfer learning的区别是什么?

我们将meta learning与model pretraining的loss函数写出来。
在这里插入图片描述
注意这两个loss函数的区别:

  • meta learning的L来源于训练任务上网络的参数更新过一次后(该网络更新过一次以后,网络的参数与meta网络的参数已经有一些区别),然后使用Query Set计算的loss;
  • model pretraining的L来源于同一个model的参数(只有一个),使用训练数据计算的loss和梯度对model进行更新;如果有多个训练任务,我们可以将这个参数在很多任务上进行预训练,训练的所有梯度都会直接更新到model的参数上。
    看一下二者的更新过程简图:
    在这里插入图片描述
  1. MAML是使用子任务的参数,第二次更新的gradient的方向来更新参数(所以左图,第一个蓝色箭头的方向与第二个绿色箭头的方向平行;左图第二个蓝色箭头的方向与第二个橘色箭头的方向平行)

  2. 而model pretraining是使用子任务第一步更新的gradient的方向来更新参数(子任务的梯度往哪个方向走,model的参数就往哪个方向走)。
    从sense上直观理解:

  3. model pretraining最小化当前的model(只有一个)在所有任务上的loss,所以model pretraining希望找到一个在所有任务(实际情况往往是大多数任务)上都表现较好的一个初始化参数,这个参数要在多数任务上当前表现较好

  4. meta learning最小化每一个子任务训练一步之后,第二次计算出的loss,用第二步的gradient更新meta网络,这代表了什么呢?子任务从【状态0】,到【状态1】,我们希望状态1的loss小,说明meta learning更care的是初始化参数未来的潜力。

如下图所示,model pretraining找到的参数 ϕ \phi ϕ,在两个任务上当前的表现比较好(当下好,但训练之后不保证好);
而MAML的参数 ϕ \phi ϕ 在两个子任务当前的表现可能都不是很好,但是如果在两个子任务上继续训练下去,可能会达到各自任务的局部最优(潜力好)。
在这里插入图片描述

问题2:为何在meta网络赋值给具体训练任务(如任务m)后,要先更训练任务的参数,再计算梯度,更新meta网络?

这个问题其实在问题1中已经进行了回答,更新一步之后,避免了meta learning陷入了和model pretraining一样的训练模式,更重要的是,可以使得meta模型更关注参数的“潜力”。

问题3:在更新训练任务的网络时,只走了一步,然后更新meta网络。为什么是一步,可以是多步吗?

李宏毅老师的课程中提到:

只更新一次,速度比较快;因为meta learning中,子任务有很多,都更新很多次,训练时间比较久。
MAML希望得到的初始化参数在新的任务中fine tuning的时候效果好。如果只更新一次,就可以在新任务上获取很好的表现。把这件事情当成目标,可以使得meta网络参数训练是很好(目标与需求一致)。
当初始化参数应用到具体的任务中时,也可以fine tuning很多次。
Few-shot learning往往数据较少。

3. Reptile

Reptile与MAML有点像,我们先看一下Reptile的训练简图:
在这里插入图片描述
Reptile的训练过程如下:在这里插入图片描述
Reptile,每次sample出1个训练任务
在这里插入图片描述
Reptile,每次sample出1个batch训练任务

在Reptile中:

训练任务的网络可以更新多次
reptile不再像MAML一样计算梯度(因此带来了工程性能的提升),而是直接用一个参数 ξ \xi ξ 乘以meta网络与训练任务的网络参数的差来更新meta网络参数
从效果上来看,Reptile效果与MAML基本持平

参考:https://zhuanlan.zhihu.com/p/136975128

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/679364
推荐阅读
相关标签
  

闽ICP备14008679号