
A PyTorch ConvLSTM Implementation


Plenty of ConvLSTM implementations for PyTorch are floating around, but PyTorch does not yet seem to ship an official version. The one below is fairly general; I have annotated it and am posting it here for future reference.

import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2  # "same" padding, so (h, w) is preserved
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,  # the i, f, o, g gates are computed in one convolution, then split
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state  # every timestep carries two state tensors: h and c

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate input and hidden state along the channel axis

        combined_conv = self.conv(combined)  # one convolution for all four gates, split afterwards
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g       # cell state update
        h_next = o * torch.tanh(c_next)  # hidden state update

        return h_next, c_next  # the two state tensors of the current timestep

    def init_hidden(self, batch_size, image_size):
        """
        Zero-initialize the state tensors for the first timestep.
        """
        height, width = image_size
        init_h = torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
        init_c = torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
        return (init_h, init_c)


class ConvLSTM(nn.Module):
    """
    Parameters
    ----------
    input_dim: Number of channels of the input tensor
    hidden_dim: Number of channels of the h and c state tensors; may be a per-layer list
    kernel_size: Size of the convolutional kernels; either one size shared by all layers or a per-layer list
    num_layers: Number of stacked ConvLSTM layers; must equal len(hidden_dim)
    batch_first: Whether or not dimension 0 is the batch
    bias: Bias or no bias in the convolutions
    return_all_layers: Return the h outputs of all layers, or only of the last one

    Note: Will do same padding, so (H, W) is preserved through the network.

    Input
    -----
    A 5-D tensor of size [B, T, C, H, W] or [T, B, C, H, W]

    Output
    ------
    A tuple of two lists of length num_layers (or length 1 if return_all_layers is False):
        0 - layer_output_list: the h outputs of each layer, each of size [B, T, hidden_dim, H, W]
        1 - last_state_list: the last-timestep state [h, c] of each layer,
            where h and c each have size [B, hidden_dim, H, W]

    Example
    -------
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # input channels of the current layer: the raw input for layer 0,
            # otherwise the hidden channels of the previous layer
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)  # register the stacked cells as submodules

    def forward(self, input_tensor, hidden_state=None):
        """
        Parameters
        ----------
        input_tensor: 5-D tensor of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state: None (todo: implement stateful operation)

        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        # TODO: implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward(), b, h and w can be read from the input
            b, _, _, h, w = input_tensor.size()
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)  # sequence length, taken from the input tensor
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):  # layer by layer
            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):  # timestep by timestep
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)  # h output of layer layer_idx at timestep t

            layer_output = torch.stack(output_inner, dim=1)  # stack the h outputs of all timesteps
            cur_layer_input = layer_output  # becomes the input of layer layer_idx + 1

            layer_output_list.append(layer_output)  # all-timestep h outputs of the current layer
            last_state_list.append([h, c])  # [h, c] of the last timestep of the current layer

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        """
        Zero-initialize the first-timestep states of every layer.
        """
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """
        Check that kernel_size is a tuple or a list of tuples.
        """
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """
        Broadcast a single parameter to a per-layer list.
        """
        if not isinstance(param, list):
            param = [param] * num_layers
        return param


if __name__ == "__main__":
    data = torch.randn((5, 6, 3, 30, 30))
    model = ConvLSTM(input_dim=3,
                     hidden_dim=[64, 64, 128],
                     kernel_size=[(3, 3), (5, 5), (7, 7)],
                     num_layers=3,
                     batch_first=True,
                     bias=True,
                     return_all_layers=True)
    layer_output_list, last_state_list = model(data)

    last_layer_output = layer_output_list[-1]
    last_layer_last_h, last_layer_last_c = last_state_list[-1]
    print(last_layer_output[:, -1, ...] == last_layer_last_h)
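For reference, the forward pass of ConvLSTMCell above corresponds to the usual peephole-free ConvLSTM update: convolving the channel-wise concatenation [x_t, h_{t-1}] with a 4 * hidden_dim kernel and splitting the result is equivalent to separate input-to-hidden and hidden-to-hidden convolutions per gate. Writing * for convolution and ∘ for the element-wise product:

$$
\begin{aligned}
i_t &= \sigma(W_{xi} * x_t + W_{hi} * h_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} * x_t + W_{hf} * h_{t-1} + b_f) \\
o_t &= \sigma(W_{xo} * x_t + W_{ho} * h_{t-1} + b_o) \\
g_t &= \tanh(W_{xg} * x_t + W_{hg} * h_{t-1} + b_g) \\
c_t &= f_t \circ c_{t-1} + i_t \circ g_t \\
h_t &= o_t \circ \tanh(c_t)
\end{aligned}
$$

Because of the "same" padding computed in __init__, h and c keep the spatial size (H, W) of the input at every step.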

 

Caveat: when building a ConvLSTM with the code above, make sure hidden_dim, kernel_size, and num_layers are consistent across the layers, i.e. that they satisfy the check in the code:

len(kernel_size) == len(hidden_dim) == num_layers

If hidden_dim=64, kernel_size=(3,3), num_layers=3: a 3-layer ConvLSTM is built in which every layer has 64 hidden channels and a (3,3) kernel.

If hidden_dim=[64,128,256], kernel_size=(3,3), num_layers=3: a 3-layer ConvLSTM is built whose layers have 64, 128, and 256 hidden channels respectively, all with a (3,3) kernel.

If hidden_dim=[64,128,256], kernel_size=[(3,3),(5,5),(7,7)], num_layers=3: a 3-layer ConvLSTM is built whose layers have 64, 128, and 256 hidden channels and kernels of (3,3), (5,5), and (7,7) respectively, as the sketch below shows.
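As a quick sanity check, the three cases can be instantiated directly with the ConvLSTM class from the listing above; a minimal sketch (the variable names and toy input shape are mine, chosen only for illustration):

import torch

# Case 1: scalars are broadcast, so all 3 layers get 64 hidden channels and a (3, 3) kernel
net_a = ConvLSTM(input_dim=3, hidden_dim=64, kernel_size=(3, 3),
                 num_layers=3, batch_first=True)

# Case 2: per-layer hidden channels, one shared kernel size
net_b = ConvLSTM(input_dim=3, hidden_dim=[64, 128, 256], kernel_size=(3, 3),
                 num_layers=3, batch_first=True)

# Case 3: per-layer hidden channels and per-layer kernel sizes
net_c = ConvLSTM(input_dim=3, hidden_dim=[64, 128, 256],
                 kernel_size=[(3, 3), (5, 5), (7, 7)],
                 num_layers=3, batch_first=True)

x = torch.randn(2, 4, 3, 16, 16)  # (B, T, C, H, W)
layer_outputs, last_states = net_c(x)
print(layer_outputs[-1].shape)    # torch.Size([2, 4, 256, 16, 16]): 256 channels from the last layer, (H, W) preserved

Note that kernel_size must be a tuple (or a list of tuples); passing a plain int such as kernel_size=3 is rejected by _check_kernel_size_consistency.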
