赞
踩
写在前面:主要目的是想对自己对于元学习的内容和问题进行总结,同时为想要学习Meta-Learning的同学提供一下简单的入门。笔者挑选了经典的paper详读,看了李宏毅老师深度学习课程元学习部分,并附了MAML的代码。为了通俗易懂,我将数学推导和工程实践分开两篇文章进行介绍。如果看不懂,欢迎来捶我( )~~
如果大家觉得有帮助,可以帮忙点个赞或者收藏一下,这将是我继续分享的动力~
以下是本文的主要框架:
通常在机器学习里,我们会使用某个场景的大量数据来训练模型;然而当场景发生改变,模型就需要重新训练。但是对于人类而言,一个小朋友成长过程中会见过许多物体的照片,某一天,当Ta(第一次)仅仅看了几张狗的照片,就可以很好地对狗和其他物体进行区分。
元学习Meta Learning,含义为学会学习,即learn to learn,就是带着这种对人类这种“学习能力”的期望诞生的。Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,如:
需要注意的是,虽然同样有“预训练”的意思在里面,但是元学习的内核区别于迁移学习(Transfer Learning),关于他们的区别,我会在下文进行阐述。
接下来,我们通过对比机器学习和元学习这两个概念的要素来加深对元学习这个概念的理解。
在机器学习中,训练单位是一条数据,通过数据来对模型进行优化;数据可以分为训练集、测试集和验证集。在元学习中,训练单位分层级了,第一层训练单位是任务,也就是说,元学习中要准备许多任务来进行学习,第二层训练单位才是每个任务对应的数据。
二者的目的都是找一个Function,只是两个Function的功能不同,要做的事情不一样。机器学习中的Function直接作用于特征和标签,去寻找特征与标签之间的关联;而元学习中的Function是用于寻找新的f,新的f才会应用于具体的任务。有种不同阶导数的感觉。又有种老千层饼的感觉,你看到我在第二层,你把我想象成第一层,而其实我在第五层。。。
我们先对比机器学习的过程来进一步理解元学习。如下图所示,机器学习的一般过程如下:
机器学习过程,引自李宏毅《深度学习》
其中,红色方框里的“配置”都是由人为设计的,我们又叫做“超参数“。Meta Learning中希望把这些配置,如网络结构,参数初始化,优化器等由机器自行设计(注:此处区别于AutoML,迁移学习(Transfer Learning)和终身学习(Life Long Learning) ),使网络有更强的学习能力和表现。
上文已经提到,【元学习中要准备许多任务来进行学习,而每个任务又有各自的训练集和测试集】。我们结合一个具体的任务,来介绍元学习和MAML的实施过程。
有一个图像数据集叫Omniglot:https://github.com/brendenlake/omniglot。Omniglot包含1623个不同的火星文字符,每个字符包含20个手写的case。这个任务是判断每个手写的case属于哪一个火星文字符。
如果我们要进行N-ways,K-shot(数据中包含N个字符类别,每个字符有K张图像)的一个图像分类任务。比如20-ways,1-shot分类的意思是说,要做一个20分类,但是每个分类下只有1张图像的任务。我们可以依据Omniglot构建很多N-ways,K-shot任务,这些任务将作为元学习的任务来源。构建的任务分为训练任务(Train Task),测试任务(Test Task)。特别地,每个任务包含自己的训练数据、测试数据,在元学习里,分别称为Support Set和Query Set。
MAML的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)。我们通过(许多)N-ways,K-shot的任务(训练任务)进行元学习的训练,使得模型学习到“先验知识”(初始化的参数)。这个“先验知识”在新的N-ways,K-shot任务上可以表现的更好。
接下来介绍MAML的算法流程:
MAML算法流程
当然,在“预训练”阶段,也可以sample出1个batch的几个任务,那么在更新meta网络时,要使用sample出所有任务的梯度之和。
注意:在MAML中, meta网络与子任务的网络结构必须完全相同。
这里面有几个小问题:
这三个问题是MAML中很核心的问题,大家可以先思考一下,我们将在后文进行解答。我们先看一下MAML的实现代码。
- ## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow
-
- #################################################
- # 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
- # support set:5 * 1
- # query set:5 * 15
- # 训练取1个batch的任务:batch size:4
- # 对训练任务进行训练时,更新5次:K = 5
- #################################################
-
- print(support_x) # (4, 5, 21168)
- print(query_x) # (4, 75, 21168)
- print(support_y) # (4, 5, 5)
- print(query_y) # (4, 75, 5)
- print(meta_batchsz) # 4
- print(K) # 5
-
- model = MAML()
- model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
-
- class MAML:
- def __init__(self):
- pass
- def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
- """
- :param support_xb: [4, 5, 84*84*3]
- :param support_yb: [4, 5, n-way]
- :param query_xb: [4, 75, 84*84*3]
- :param query_yb: [4, 75, n-way]
- :param K: 训练任务的网络更新步数
- :param meta_batchsz: 任务数,4
- """
-
- self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数
- training = True if mode is 'train' else False
- def meta_task(input):
- """
- :param support_x: [setsz, 84*84*3] (5, 21168)
- :param support_y: [setsz, n-way] (5, 5)
- :param query_x: [querysz, 84*84*3] (75, 21168)
- :param query_y: [querysz, n-way] (75, 5)
- :param training: training or not, for batch_norm
- :return:
- """
-
- support_x, support_y, query_x, query_y = input
- query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果
-
- ## 第0次对网络进行更新
- support_pred = self.forward(support_x, self.weights, training) # 前向计算support set
- support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
- support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
- tf.argmax(support_y, axis=1))
- grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度
- gvs = dict(zip(self.weights.keys(), grads))
- # 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads
- fast_weights = dict(zip(self.weights.keys(), \
- [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
-
- # 使用梯度更新后的参数对quert set进行前向计算
- query_pred = self.forward(query_x, fast_weights, training)
- query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
- query_preds.append(query_pred)
- query_losses.append(query_loss)
-
- # 第1到 K-1次对网络进行更新
- for _ in range(1, K):
- loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
- labels=support_y)
- grads = tf.gradients(loss, list(fast_weights.values()))
- gvs = dict(zip(fast_weights.keys(), grads))
- fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
- for key in fast_weights.keys()]))
- query_pred = self.forward(query_x, fast_weights, training)
- query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
- # 子网络更新K次,记录每一次queryset的结果
- query_preds.append(query_pred)
- query_losses.append(query_loss)
-
- for i in range(K):
- query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
- tf.argmax(query_y, axis=1)))
- result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
- return result
-
- # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
- out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
- result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
- dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
- support_pred_tasks, support_loss_tasks, support_acc_tasks, \
- query_preds_tasks, query_losses_tasks, query_accs_tasks = result
-
- if mode is 'train':
- self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
- self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
- for j in range(K)]
- self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
- self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
- for j in range(K)]
-
- # 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调
- optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
- gvs = optimizer.compute_gradients(self.query_losses[-1])
- # def ********
接下来回答一下上面的三个问题:
问题1:MAML的执行过程与model pretraining & transfer learning的区别是什么?
我们将meta learning与model pretraining的loss函数写出来。
meta learning与model pretraining的loss函数
注意这两个loss函数的区别:
看一下二者的更新过程简图:
meta learning与model pretraining训练过程,引自李宏毅《深度学习》
从sense上直观理解:
一个关注当下,一个关注潜力。
引自李宏毅《深度学习》
这里有一个toy example可以表现MAML的执行过程与model pretraining & transfer learning的区别。
训练任务:给定N个函数,y = asinx + b(通过给a和b不同的取值可以得到很多sin函数),从每个函数中sample出K个点,用sample出的K个点来预估最初的函数,即求解a和b的值。
训练过程:用这N个训练任务sample出的数据点分别通过MAML与model pretraining训练网络,得到预训练的参数。
如下图,用橘黄色的sin函数作为测试任务,三角形的点是测试任务中sample出的样本点,在测试任务中,我们希望用sample出的样本点还原橘黄色的线。
Toy example,引自李宏毅《深度学习》
问题2:为何在meta网络赋值给具体训练任务(如任务m)后,要先更训练任务的参数,再计算梯度,更新meta网络?
这个问题其实在问题1中已经进行了回答,更新一步之后,避免了meta learning陷入了和model pretraining一样的训练模式,更重要的是,可以使得meta模型更关注参数的“潜力”。
问题3:在更新训练任务的网络时,只走了一步,然后更新meta网络。为什么是一步,可以是多步吗?
李宏毅老师的课程中提到:
那么MAML中的训练任务的网络可以更新多次后,再更新meta网络吗?
我觉得可以。直观上感觉,更新次数决定了子任务对于meta网络的影响程度,我觉得这个步数可以作为一个参数来调。
另外,即将介绍的下一个网络——Reptile,也是对训练任务网络进行多次更新的。
Reptile与MAML有点像,我们先看一下Reptile的训练简图:
Reptile训练过程,引自李宏毅《深度学习》
Reptile的训练过程如下:
Reptile,每次sample出1个训练任务
Reptile,每次sample出1个batch训练任务
在Reptile中:
元学习入门部分的文章基本就分享到这里了~
分享一个关于元学习的搞笑的图。。。
老千层饼,你永远都不知道你咬下去的这一口有多少层。。
接下来可能会分享一篇MAML的数学推导,以及想把当前工作里的model pretraining模型切到meta learning看一下效果。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。