赞
踩
本文主要是有关convLSTM的pytorch实现代码的理解,原理请移步其他博客。
在pytorch中实现LSTM或者GRU等RNN一般需要重写cell,每个cell中包含某一个时序的计算,也就是以下:
在传统LSTM中,LSTM每次要调用t次cell,t就是时序的总长度,如果是n层LSTM就相当于一共调用了n*t次cell
- class ConvLSTMCell(nn.Module):
-
- def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
- """
- Initialize ConvLSTM cell.
- Parameters
- ----------
- input_size: (int, int)
- Height and width of input tensor as (height, width).
- input_dim: int
- Number of channels of input tensor.
- hidden_dim: int
- Number of channels of hidden state.
- kernel_size: (int, int)
- Size of the convolutional kernel.
- bias: bool
- Whether or not to
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。