赞
踩
class torch.nn.LSTM(*args, **kwargs)
参数列表:
输入数据格式:(默认输入格式,batch_first=False)
input(seq_len, batch, input_size)
h0(num_layers * num_directions, batch, hidden_size)
c0(num_layers * num_directions, batch, hidden_size)
输出数据格式:
output(seq_len, batch, hidden_size * num_directions)
h_n(num_layers * num_directions, batch, hidden_size)
c_n(num_layers * num_directions, batch, hidden_size)
Pytorch里的LSTM单元接受的输入都必须是3维的张量(Tensors).每一维代表的意思不能弄错。
第一维体现的是序列(sequence)结构,也就是序列的frame个数
第二维度体现的是batch_size,也就是一次性喂给网络的序列的个数
第三维度体现的是输入的元素特征(elements of input),也就是,每一个frame的feature
H0-Hn是什么意思呢?就是每个时刻中间神经元应该保存的这一时刻的根据输入和上一时刻的中间状态值应该产生的本时刻的状态值,
这个数据单元是起的作用就是记录这一时刻之前考虑到所有之前输入的状态值,形状应该是和特定时刻的输出一致
c0-cn就是开关,决定每个神经元的隐藏状态值是否会影响的下一时刻的神经元的处理,形状应该和h0-hn一致。
使用注意事项:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。