当前位置:   article > 正文

Convlstm时空预测(keras框架、实战)_convlstm3d多步输出

convlstm3d多步输出

本人博客:www.erickun.com

Convlstm新手实战

   毕设临近截止,故写一篇心得以供新手学习,理论在知乎上有很多介绍的不错的文章,这里强烈推荐微信公众号:AI蜗牛车,这位东南老哥写了时空预测系列文章,能够帮助了解时空领域模型的演变,同时也向他请教了一些训练技巧。

   我的本科毕设大概是这样的:先计算某个区域的风险,计算得到一段时间的风险矩阵,这里用的是自己的模型去计算的,数据如何生成,本文不做赘述,主要讲解如果通过每个时刻下的矩阵数据去预测未来的矩阵。

回顾理论基础

   在ConvLSTM中,网络用于捕获数据集中的时空依赖性。ConvLSTM和FC-LSTM之间的区别在于,ConvLSTM将LSTM的前馈方法从Hadamard乘积变为卷积,即input-to-gate和gate-to-gate两个方向的运算均做卷积,也就是之前W和h点乘改为卷积(*)。 ConvLSTM的主要公式如下所示:
i t = σ ( W x i ∗ x t + W h i ∗ h t − 1 + b i ) f t = σ ( W x f ∗ x t + W h f ∗ h t − 1 + b f ) o t = σ ( W x o ∗ x t + W h o ∗ h t − 1 + b o ) C t = f t ∘ C t − 1 + i t ∘ tanh ⁡ ( W x c ∗ x t + W h c ∗ h t − 1 + b c ) H t = o t ∘ tanh ⁡ ( c t )

it=σ(Wxixt+Whiht1+bi)ft=σ(Wxfxt+Whfht1+bf)ot=σ(Wxoxt+Whoht1+bo)Ct=ftCt1+ittanh(Wxcxt+Whcht1+bc)Ht=ottanh(ct)
it=σ(Wxixt+Whiht1+bi)ft=σ(Wxfxt+Whfht1+bf)ot=σ(Wxoxt+Whoht1+bo)Ct=ftCt1+ittanh(Wxcxt+Whcht1+bc)Ht=ottanh(ct)
详细可参考:【时空序列预测第二篇】Convolutional LSTM Network-paper reading

官方keras案例

   实战过的朋友应该了解,关于Convlstm,可参考的案例非常少,基本上就集中在keras的官方案例(电影帧预测——视频预测官方案例知乎解说
在这里插入图片描述
官方模型核心代码:

from keras.models import Sequential
from keras.layers.convolutional import Conv3D
from keras.layers.convolutional_recurrent import ConvLSTM2D
from keras.layers.normalization import BatchNormalization
import numpy as np
import pylab as plt
 
seq = Sequential()
seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   input_shape=(None, 40, 40, 1),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())

seq.add(Conv3D(filters=1, kernel_size=(3, 3, 3),
               activation='sigmoid',
               padding='same', data_format='channels_last'))
seq.compile(loss='binary_crossentropy', optimizer='adadelta')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

   模型结构可以如官方一样:用前20个预测后20个,这里先解释一下官方模型结构的维度:

    (如已熟悉,请跳过)对于新手来说,看上去似乎很复杂,其实弄清楚后会发现不过如此,请耐心听我讲完:先从第一个Convlstm说起,输入的是(None, 40, 40, 1),输出的维度(None,None,40,40,40),这里的输入维度(input_shape)其实是每个时刻下的输入,如下图:比如这里用20个预测后20个,那么整理的第一个样本就是0至19个矩阵,label(标签)就是20至39个矩阵,每一个矩阵维度为(40,40,1)最后的这个1为通道数,如果是图片,那就对应多通道了,那么整理的样本X就应该是(样本个数,20,40,40,1),对应标签Y就是(样本个数,20,40,40,1)这样每个样本和标签才能一一对应,由于reurn_sequencetrue,即每个时刻单元都有输出,也就是20个预测20个嘛,那么第一层的Convlstm输出的维度就是(None,None,40,40,40)这里第一个None是batchsize毫无疑问,第二个其实就是20,至于最后一个维度是40,和filter个数直接相关,(因为一个卷积核对样本做一次特征提取,40个就有40个特征提取)。

   接下来N层Convlstm均如此,最后为啥要接一个Conv3d,很好解释,因为你的label维度是(样本个数,20,40,40,1),这里的最后维度还得回归到1啊,所以Conv3d的filter这才设置为了1,以此类推,如果你的一个数据是三通道的图像,这里filter自然就是3了,一定要和label维度对应即可。

ConvLSTM参数介绍

  • filters: 卷积核的数目
  • kernel_size: 卷积核大小(1乘1的state-to-state kernel size很难抓住时空移动的特征,所以效果差很多,所以更大的size更能够获取时空的联系)
  • strides: (1,1)为卷积的步长,即卷积核向右和向下一次移动几格,默认步长为1
  • padding: 补0,为“valid”或 “same”。若要保证卷积核提取特征后前后维度一致,那就“same”
  • data_format: 即红绿蓝三个通道(channel)是在前面还是在后面,channels_last (默认) (width, height, channel)或 channels_first (channel, width, height) 之一, 输入中维度的顺序
  • activation: 激活函数,即下图中的RELU层,为预定义的激活函数名,如果不指定该参数,将不会使用任何激活函数(即使用线性激活函数:a(x)=x)

模型改造

   不过我由于数据量比较少,我把模型结构改造成了20个预测1个(样本数较少的童鞋可以参考),在convlstm最后一个层的reurn_sequence参数改为flase、Conv3d改2d即可。

   其实了解了reurn_sequence这个参数后,改造就顺理成章了,在最后一个Convlstm这里将reurn_sequence改为false,那么就只在最后一个单元有输出了,第二个None维度就没了,然后再把Conv3d改为2d即可,这样就要求整理数据集的时候,样本和标签分别整理成这样:(样本数,20,40,40,1) 和(样本数,40,40,1),也就是20个预测1个。

from keras.models import Sequential
from keras.layers.convolutional import Conv3D ,Conv2D
from keras.layers.convolutional_recurrent import ConvLSTM2D
from keras.layers.normalization import BatchNormalization
from keras_contrib.losses import DSSIMObjective

import numpy as np

seq = Sequential()
seq.add(ConvLSTM2D(filters=30, kernel_size=(3, 3),
                   input_shape=(None, 60, 93, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())
 
seq.add(ConvLSTM2D(filters=30, kernel_size=(3, 3),
                   padding='same', return_sequences=True))
seq.add(BatchNormalization())
 
seq.add(ConvLSTM2D(filters=30, kernel_size=(3, 3),
                   padding='same', return_sequences=False))
seq.add(BatchNormalization())
 
seq.add(Conv2D(filters=3, kernel_size=(3, 3),
               activation='sigmoid',
               padding='same', data_format='channels_last'))

seq.compile(loss= DSSIMObjective(kernel_size=3), optimizer='adadelta')
seq.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

模型经验及调参

   先看看结果图吧,随便抽一张示意一下,预测的点相对比较准确,但是模糊度还没解决掉,毕竟只训练了十几分钟,有这个效果也还算可以了:

    整个模型看上去不算复杂,但是实际效果比较差,有以下几个要稍微注意的地方:

  • 1.矩阵数据是否过于稀疏,如果0太多,建议先转成图片再做训练,否则效果会奇差无比,原因可能是求梯度的时候网络出了问题,直接崩了。
  • 2.如果输入是图片张量,需要提前做好归一化,我用的简单处理,直接元素除255.0,显示的时候再乘回来即可,可能有一丢丢颜色误差,但是不太影响。
  • 3.预测图片出现模糊大概有以下几个原因:

    (1)网络结构不够优(继续调就完事了),往往这种情况下,得到的预测点也不会太准确。

    (2)由于是多个时刻下的数据去预测一个,那么必然存在信息叠加(融合),这样导致的模糊是不可避免的,如果数据量很大,那么可以采用20帧预测20帧这样的结构,应该会有效减缓一点模糊程度。

    (3)重要: 损失函数若使用MSE则会默认模糊,如果换成SSIM(结构相似性)则会明显改观(亲测有效)

       在模糊处理方面,我也想尝试改进,但是还没有找到比较好的方式,蜗牛车老哥建议调小学习率,训练时间长一点,可以这样去尝试一下!反卷积也尝试了,但是效果不佳,后期准备使用TrajGRU来实战(预测解码模块采用了上采样层理论上应该会提高清晰度)。

   模型调参的过程其实是最无聊也最艰辛的,无非就是改改层结构,多一层少一层,改一下filter、batchsize个数,时空预测这种图像的预测和别的领域有一点不同,文本的只要acc、f1-score上去了就行,所以可以用grid search来自动化调参,但是图像预测还必须得肉眼去看效果,否则结果真可能是千差万别,loss看上去已经很低了但是效果很差的情况比比皆是,尝试多换几种loss来实验,后面也还可以尝试自定义loss看效果,整个调参过程确实是不断试错的过程,两个字:“炼丹”!

------------------------下图为2020.6.1更新,本科毕设最新效果,采用trick--------如有需要详细,可私信-------------------

------------------图错误
发现图好像画错了,
在这里插入图片描述

用20个预测下一个,应该是这样画图才对吧

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/219253
推荐阅读
相关标签
  

闽ICP备14008679号