赞
踩
本人博客:www.erickun.com
毕设临近截止,故写一篇心得以供新手学习,理论在知乎上有很多介绍的不错的文章,这里强烈推荐微信公众号:AI蜗牛车,这位东南老哥写了时空预测系列文章,能够帮助了解时空领域模型的演变,同时也向他请教了一些训练技巧。
我的本科毕设大概是这样的:先计算某个区域的风险,计算得到一段时间的风险矩阵,这里用的是自己的模型去计算的,数据如何生成,本文不做赘述,主要讲解如果通过每个时刻下的矩阵数据去预测未来的矩阵。
在ConvLSTM中,网络用于捕获数据集中的时空依赖性。ConvLSTM和FC-LSTM之间的区别在于,ConvLSTM将LSTM的前馈方法从Hadamard乘积变为卷积,即input-to-gate和gate-to-gate两个方向的运算均做卷积,也就是之前W和h点乘改为卷积(*)。 ConvLSTM的主要公式如下所示:
i
t
=
σ
(
W
x
i
∗
x
t
+
W
h
i
∗
h
t
−
1
+
b
i
)
f
t
=
σ
(
W
x
f
∗
x
t
+
W
h
f
∗
h
t
−
1
+
b
f
)
o
t
=
σ
(
W
x
o
∗
x
t
+
W
h
o
∗
h
t
−
1
+
b
o
)
C
t
=
f
t
∘
C
t
−
1
+
i
t
∘
tanh
(
W
x
c
∗
x
t
+
W
h
c
∗
h
t
−
1
+
b
c
)
H
t
=
o
t
∘
tanh
(
c
t
)
详细可参考:【时空序列预测第二篇】Convolutional LSTM Network-paper reading
实战过的朋友应该了解,关于Convlstm,可参考的案例非常少,基本上就集中在keras的官方案例(电影帧预测——视频预测官方案例)知乎解说
官方模型核心代码:
from keras.models import Sequential from keras.layers.convolutional import Conv3D from keras.layers.convolutional_recurrent import ConvLSTM2D from keras.layers.normalization import BatchNormalization import numpy as np import pylab as plt seq = Sequential() seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3), input_shape=(None, 40, 40, 1), padding='same', return_sequences=True)) seq.add(BatchNormalization()) seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3), padding='same', return_sequences=True)) seq.add(BatchNormalization()) seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3), padding='same', return_sequences=True)) seq.add(BatchNormalization()) seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3), padding='same', return_sequences=True)) seq.add(BatchNormalization()) seq.add(Conv3D(filters=1, kernel_size=(3, 3, 3), activation='sigmoid', padding='same', data_format='channels_last')) seq.compile(loss='binary_crossentropy', optimizer='adadelta')
模型结构可以如官方一样:用前20个预测后20个,这里先解释一下官方模型结构的维度:
(如已熟悉,请跳过)对于新手来说,看上去似乎很复杂,其实弄清楚后会发现不过如此,请耐心听我讲完:先从第一个Convlstm说起,输入的是(None, 40, 40, 1),输出的维度(None,None,40,40,40),这里的输入维度(input_shape
)其实是每个时刻下的输入,如下图:比如这里用20个预测后20个,那么整理的第一个样本就是0至19个矩阵,label(标签)就是20至39个矩阵,每一个矩阵维度为(40,40,1)最后的这个1为通道数,如果是图片,那就对应多通道了,那么整理的样本X就应该是(样本个数,20,40,40,1),对应标签Y就是(样本个数,20,40,40,1)这样每个样本和标签才能一一对应,由于reurn_sequence
为true
,即每个时刻单元都有输出,也就是20个预测20个嘛,那么第一层的Convlstm输出的维度就是(None,None,40,40,40)这里第一个None是batchsize毫无疑问,第二个其实就是20,至于最后一个维度是40,和filter
个数直接相关,(因为一个卷积核对样本做一次特征提取,40个就有40个特征提取)。
接下来N层Convlstm均如此,最后为啥要接一个Conv3d,很好解释,因为你的label维度是(样本个数,20,40,40,1),这里的最后维度还得回归到1啊,所以Conv3d的filter
这才设置为了1,以此类推,如果你的一个数据是三通道的图像,这里filter自然就是3了,一定要和label维度对应即可。
不过我由于数据量比较少,我把模型结构改造成了20个预测1个(样本数较少的童鞋可以参考),在convlstm最后一个层的reurn_sequence
参数改为flase
、Conv3d改2d即可。
其实了解了reurn_sequence
这个参数后,改造就顺理成章了,在最后一个Convlstm这里将reurn_sequence
改为false
,那么就只在最后一个单元有输出了,第二个None维度就没了,然后再把Conv3d改为2d即可,这样就要求整理数据集的时候,样本和标签分别整理成这样:(样本数,20,40,40,1) 和(样本数,40,40,1),也就是20个预测1个。
from keras.models import Sequential from keras.layers.convolutional import Conv3D ,Conv2D from keras.layers.convolutional_recurrent import ConvLSTM2D from keras.layers.normalization import BatchNormalization from keras_contrib.losses import DSSIMObjective import numpy as np seq = Sequential() seq.add(ConvLSTM2D(filters=30, kernel_size=(3, 3), input_shape=(None, 60, 93, 3), padding='same', return_sequences=True)) seq.add(BatchNormalization()) seq.add(ConvLSTM2D(filters=30, kernel_size=(3, 3), padding='same', return_sequences=True)) seq.add(BatchNormalization()) seq.add(ConvLSTM2D(filters=30, kernel_size=(3, 3), padding='same', return_sequences=False)) seq.add(BatchNormalization()) seq.add(Conv2D(filters=3, kernel_size=(3, 3), activation='sigmoid', padding='same', data_format='channels_last')) seq.compile(loss= DSSIMObjective(kernel_size=3), optimizer='adadelta') seq.summary()
先看看结果图吧,随便抽一张示意一下,预测的点相对比较准确,但是模糊度还没解决掉,毕竟只训练了十几分钟,有这个效果也还算可以了:
整个模型看上去不算复杂,但是实际效果比较差,有以下几个要稍微注意的地方:
模型调参的过程其实是最无聊也最艰辛的,无非就是改改层结构,多一层少一层,改一下filter、batchsize个数,时空预测这种图像的预测和别的领域有一点不同,文本的只要acc、f1-score上去了就行,所以可以用grid search来自动化调参,但是图像预测还必须得肉眼去看效果,否则结果真可能是千差万别,loss看上去已经很低了但是效果很差的情况比比皆是,尝试多换几种loss来实验,后面也还可以尝试自定义loss看效果,整个调参过程确实是不断试错的过程,两个字:“炼丹”!
------------------------下图为2020.6.1更新,本科毕设最新效果,采用trick--------如有需要详细,可私信-------------------
------------------图错误
发现图好像画错了,
用20个预测下一个,应该是这样画图才对吧
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。