赞
踩
在(一)中简单的对全文进行了翻译,实际上很多句子并不能靠翻译来理解,所以只能将自己的想法融入在译文里面。今天我们记录一下代码与论文。
以前我认为将代码逐行解析就是“懂了”,实际上能用自然语言将代码要做的事情讲出来,这才是真的明白论文与代码在表达什么。所以接下来我不会涉及到具体代码,我将以自然语言的方式叙述STAR算法在做一件怎样的事情,希望看博客的朋友从算法的角度思考问题。
1. 数据是怎样被处理的?
STAR模型架构以ETH、UCY两个数据集作为输入,其中ETH数据集包含两个子集,UCY数据集包含三个子集。我们用7个场景的数据进行训练,用1个场景的数据进行测试。原始数据为csv的形式,大小为,其中4代表有四行:第一行表示该场景下的帧数;第二行为行人ID;第三、四行分别为行人的坐标,代表说整个csv文件有列,也代表整个场景下行人轨迹点的坐标。这一份信息很重要,这是读者自己在看代码时理解的基础,我特意用黄色的背景显示,代表说这一段真的很重要。
紧接着原始数据会分场景进行处理,这一段代码也就是——行人轨迹预处理。那具体被怎样处理,我想这是读者应该思考的事情。换句话说,行人轨迹会被分门别类处理成对应场景、对应行人的轨迹点。那这个时候数据会被送入到模型中训练吗?代码中告诉我们,轨迹预处理过后的数据在送入模型之前还要经过旋转、切片等数据增强的操作。在这一系列处理之后,7个场景的数据总共被划分为223个batch,每个batch由20帧的序列组成。但在送入模型前我们只取其前19帧的数据,读者想一想这是为什么?
2. STAR算法
希望在看博客的朋友从算法的角度开始思考问题。下图是论文中给的网络结构,读者先问自己。整个模型在做一件怎样的事情呢?
首先我们要明白整个模型在做一件怎样的事情,模型在做的事情就是以行人历史轨迹的前8帧作为一个输入序列,输出行人12帧的目标序列。现在我们以一个人来说,我们的输入应该是一个大小为的序列,或者说是,其中1代表1个行人,8代表8帧,2代表行人的坐标;我们的输出应该是一个,其中12代表咱们需要预测的12帧。那这是1个人,如果我们有一群人呢?一群人我们的输入就应该是,代表这一群人。刚才有说我们的一个batch是由20帧组成,并且这20帧每一帧的长度都不相同,也就是说是可变长序列,但我们只取前19帧,那这19帧是一同输入给我们的模型进行训练吗?读者看过代码的话就知道,我们是一帧一帧送给STAR进行训练的。具体来说:当frame=0时,输入维度:,输出维度:;当frame=1时,输入维度:(注意2代表前2帧,而不是第二帧),输出维度:;... ;当frame=7时,输入维度:(同理,8代表前8帧,而不是第八帧),输出维度:;当frame=8时,输入维度:(为什么人数再变?因为有的人在某一帧的时候存在,但到了某一帧的时候不存在,所以人数一直在变),输出维度:;... ;当frame=18时,输入维度:,输出维度:。所以我发现,我们训练的时候在做这样一件事:输入第1帧预测第二帧的结果,输入前2帧预测第三帧的结果,...,输入前19帧预测第二十帧的结果。总的来说,我们是用前19帧来预测第20帧。所以整个模型就是用前多少帧来预测后一帧。那读者会由疑问说:论文里说的观测前8帧预测后12帧该怎么理解呢?我想读者理解了刚才我所说的模型在做一件怎样的事情后,立马就能反映过来了:我给你前8帧的历史行人轨迹作为一个序列,通过模型后得到第9帧的预测序列,再把预测得到的第9帧的序列,叠加到前8帧上,通过模型后得到第10帧的预测序列,再把预测得到的第10帧的序列叠加到第9帧上,送入模型后得到第11帧的预测序列....那这个过程就反复下去,直到得到第20帧的预测序列停止。可以看出我们的损失就是计算真实轨迹与预测轨迹之间的L2距离,所以用均方差损失能够完成该任务。
我并没有精细的去画图示意,不过准博士生的你也一定知道上面这张图我在表达什么。
在明白了STAR算法在做一件怎样的事情之后,我们来看一看中间过程维度是怎样变换的,以及一些小Trick的使用。
如果你对网络足够熟悉,那你应该看着网络结构就能说出各部分的维度变换,如果不能,那就不是真的懂。为了方便叙述,我们假定输入是,其中代表前帧,239代表我们同时考虑239个人,2代表二维坐标。在Encoder 1 的上分支,FC的输入维度是,输出维度是,代表说我将二维坐标进行了一个32维的映射。请注意,这里对输入到FC的数据是经过归一化处理的(读者思考是什么归一化?如何做的?),为什么要做这样一个处理呢,因为我们的上分支是空间Transformer,它的输入是经过处理的图卷积数据。Ok,那全连接层的输出是直接喂给我们的空间Transformer吗?不是,我们会怎么做呢?我们会进行一个切片的操作,具体来说:我们对全连接层的输出进行切片,只取数据的最后一个维度。所以输入到空间Transformer的维度是。要理解为什么要做切片这件事,其实可以这样思考:举例来说我以前7帧作为全连接层的输入,对于空间Transformer来说,它存在的目的是学习到行人的空间信息,所以数据输入到空间Transformer里时,只需要考虑最后一帧就ok,不需要再重复考虑前面已经考虑过的帧。希望我这么说能够帮你解开疑惑。好,那你会有疑问说空间Transform的输入真的就是吗?当然不是,因为我们还需要做一个维度的调整,因为Transformer的输入维度应该是:Batchsize x Seq_length x embed_dim,所以真正的输入维度应该是,经空间Transformer后,输出维度应该是,再经过维度调整变成。对于Encoder 1的下分支,FC的输入是原封未动的数据,大小为,输出维度是。由于时序Transformer存在的意义是提取行人间的时序关系,所以时序Transformer的输入就是全连接层的输出,经时序Transformer后,输出维度,再经过切片只取最后一帧的时序数据,维度变为,这样上下两路分支才可以进行Concatenate拼接。
Encoder 2的维度就相对比较简单,这里就不写了。在这里要注意的是把空间Transformer放在前面,时序Transformer放在后面这肯定是跑代码试出来的,包括后面所加的高斯噪声,这种技巧是逐渐积累起来的,你可以想说为什么在这里加高斯噪声效果就好,为什么不加在别的地方,为什么加的是高斯噪声而不是别的噪声?所以这种小Trick平时要多积累。其次是GM储存机制,其实就相当于一步缓存器,那你要思考说作者为什么能想到这种小Trick?
损失函数可以帮助我们了解网络朝哪方面进行优化。在STAR算法中用到的损失是MSE损失,计算的是预测值与真实值之间距离的平方和。我们优化的目标是使评价指标ADE、FDE尽可能小。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。