赞
踩
通常在机器学习里,我们会使用某个场景的大量数据来训练模型;然而当场景发生改变,模型就需要重新训练。但是对于人类而言,一个小朋友成长过程中会见过许多物体的照片,某一天,当Ta(第一次)仅仅看了几张狗的照片,就可以很好地对狗和其他物体进行区分。
元学习Meta Learning,含义为学会学习,即learn to learn,就是带着这种对人类这种“学习能力”的期望诞生的。Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,如:
需要注意的是,虽然同样有“预训练”的意思在里面,但是元学习的内核区别于迁移学习(Transfer Learning),关于他们的区别,我会在下文进行阐述。
接下来,我们通过对比机器学习和元学习这两个概念的要素来加深对元学习这个概念的理解。
在机器学习中,训练单位是一条数据,通过数据来对模型进行优化;数据可以分为训练集、测试集和验证集。在元学习中,训练单位分层级了,第一层训练单位是任务,也就是说,元学习中要准备许多任务来进行学习,第二层训练单位才是每个任务对应的数据。
二者的目的都是找一个Function,只是两个Function的功能不同,要做的事情不一样。机器学习中的Function直接作用于特征和标签,去寻找特征与标签之间的关联;而元学习中的Function是用于寻找新的f,新的f才会应用于具体的任务。有种不同阶导数的感觉。又有种老千层饼的感觉,你看到我在第二层,你把我想象成第一层,而其实我在第五层。。。
我们先对比机器学习的过程来进一步理解元学习。如下图所示,机器学习的一般过程如下:
这里面有几个小问题:
## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow ################################################# # 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。 # support set:5 * 1 # query set:5 * 15 # 训练取1个batch的任务:batch size:4 # 对训练任务进行训练时,更新5次:K = 5 ################################################# print(support_x) # (4, 5, 21168) print(query_x) # (4, 75, 21168) print(support_y) # (4, 5, 5) print(query_y) # (4, 75, 5) print(meta_batchsz) # 4 print(K) # 5 model = MAML() model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train') class MAML: def __init__(self): pass def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'): """ :param support_xb: [4, 5, 84*84*3] :param support_yb: [4, 5, n-way] :param query_xb: [4, 75, 84*84*3] :param query_yb: [4, 75, n-way] :param K: 训练任务的网络更新步数 :param meta_batchsz: 任务数,4 """ self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数 training = True if mode is 'train' else False def meta_task(input): """ :param support_x: [setsz, 84*84*3] (5, 21168) :param support_y: [setsz, n-way] (5, 5) :param query_x: [querysz, 84*84*3] (75, 21168) :param query_y: [querysz, n-way] (75, 5) :param training: training or not, for batch_norm :return: """ support_x, support_y, query_x, query_y = input query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果 ## 第0次对网络进行更新 support_pred = self.forward(support_x, self.weights, training) # 前向计算support set support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1), tf.argmax(support_y, axis=1)) grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度 gvs = dict(zip(self.weights.keys(), grads)) # 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads fast_weights = dict(zip(self.weights.keys(), \ [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()])) # 使用梯度更新后的参数对quert set进行前向计算 query_pred = self.forward(query_x, fast_weights, training) query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y) query_preds.append(query_pred) query_losses.append(query_loss) # 第1到 K-1次对网络进行更新 for _ in range(1, K): loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training), labels=support_y) grads = tf.gradients(loss, list(fast_weights.values())) gvs = dict(zip(fast_weights.keys(), grads)) fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key] for key in fast_weights.keys()])) query_pred = self.forward(query_x, fast_weights, training) query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y) # 子网络更新K次,记录每一次queryset的结果 query_preds.append(query_pred) query_losses.append(query_loss) for i in range(K): query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1), tf.argmax(query_y, axis=1))) result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs] return result # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs] out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K] result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb), dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn') support_pred_tasks, support_loss_tasks, support_acc_tasks, \ query_preds_tasks, query_losses_tasks, query_accs_tasks = result if mode is 'train': self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz for j in range(K)] self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz for j in range(K)] # 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调 optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim') gvs = optimizer.compute_gradients(self.query_losses[-1]) # def ********
接下来回答一下上面的三个问题:
问题1:MAML的执行过程与model pretraining & transfer learning的区别是什么?
我们将meta learning与model pretraining的loss函数写出来。
注意这两个loss函数的区别:
MAML是使用子任务的参数,第二次更新的gradient的方向来更新参数(所以左图,第一个蓝色箭头的方向与第二个绿色箭头的方向平行;左图第二个蓝色箭头的方向与第二个橘色箭头的方向平行)
而model pretraining是使用子任务第一步更新的gradient的方向来更新参数(子任务的梯度往哪个方向走,model的参数就往哪个方向走)。
从sense上直观理解:
model pretraining最小化当前的model(只有一个)在所有任务上的loss,所以model pretraining希望找到一个在所有任务(实际情况往往是大多数任务)上都表现较好的一个初始化参数,这个参数要在多数任务上当前表现较好。
meta learning最小化每一个子任务训练一步之后,第二次计算出的loss,用第二步的gradient更新meta网络,这代表了什么呢?子任务从【状态0】,到【状态1】,我们希望状态1的loss小,说明meta learning更care的是初始化参数未来的潜力。
如下图所示,model pretraining找到的参数
ϕ
\phi
ϕ,在两个任务上当前的表现比较好(当下好,但训练之后不保证好);
而MAML的参数
ϕ
\phi
ϕ 在两个子任务当前的表现可能都不是很好,但是如果在两个子任务上继续训练下去,可能会达到各自任务的局部最优(潜力好)。
问题2:为何在meta网络赋值给具体训练任务(如任务m)后,要先更训练任务的参数,再计算梯度,更新meta网络?
这个问题其实在问题1中已经进行了回答,更新一步之后,避免了meta learning陷入了和model pretraining一样的训练模式,更重要的是,可以使得meta模型更关注参数的“潜力”。
问题3:在更新训练任务的网络时,只走了一步,然后更新meta网络。为什么是一步,可以是多步吗?
李宏毅老师的课程中提到:
只更新一次,速度比较快;因为meta learning中,子任务有很多,都更新很多次,训练时间比较久。
MAML希望得到的初始化参数在新的任务中fine tuning的时候效果好。如果只更新一次,就可以在新任务上获取很好的表现。把这件事情当成目标,可以使得meta网络参数训练是很好(目标与需求一致)。
当初始化参数应用到具体的任务中时,也可以fine tuning很多次。
Few-shot learning往往数据较少。
Reptile与MAML有点像,我们先看一下Reptile的训练简图:
Reptile的训练过程如下:
Reptile,每次sample出1个训练任务
Reptile,每次sample出1个batch训练任务
在Reptile中:
训练任务的网络可以更新多次
reptile不再像MAML一样计算梯度(因此带来了工程性能的提升),而是直接用一个参数
ξ
\xi
ξ 乘以meta网络与训练任务的网络参数的差来更新meta网络参数
从效果上来看,Reptile效果与MAML基本持平
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。