赞
踩
(这里主要讲pytorch的实现,具体RNN细节请参考
链接: link)
RNN是用来处理带有时间序列的信息的,每一个时间段采用一个Cell来存储这个时间段之前的所有信息,即h0。
最一开始需要我们初始化建立一个h0表示在输入数据前起始的Cell状态,然后该h0与第一个时间上的信息x0经过运算相加形成第一个带有记忆信息的Cell,即h1。
在pytorch中很好地集成了RNN类,可以通过nn.RNN来实例化。
torch.nn.RNN(*args, **kwargs)
上面为实例化RNN时的参数,这里我们只关心前三个
input_size:表示输入的单个x的维度,比如单词的embedding的维度
hidden_size:表示隐藏层h的维度;比如隐藏层维度是10,每个单词的embedding维度是20,通过隐藏层后把单词的embedding变成10维表示
num_layers:表示RNN有多少层,这里的层数与后面的使用RNN时需要初始化的h0的大小有关
Wxh的维度为:[hidden_size, feature_size] --> X@Wxh.T(内部操作)
Whh的维度为:[hidden_size, hidden_size] – > Hi@Whh.T(内部操作)
rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=1)
print(rnn)
x = torch.randn(10, 3, 100)
out, h = rnn(x, torch.zeros(1, 3, 20))
print(out.shape, h.shape)
RNN的2个输入分别是数据input和初始化的隐藏层状态h0
特别注意input的大小是:[seq_len, batch_size, input_size],这与CV中把batch_size放在第一维不同,主要是为了方便以每个时间步为单位来计算。
h0的大小为:[num_layers * bidirectional, batch_size, hidden_szie]
rnn的输出为output 和 h_n
out_put记录的是在每个时间步上,最后一层的隐藏层状态,其大小为:[seq_len, batch_size, hidden_size * bidirectional],因为如果是双向rnn的话,会把来回两个方向的hidden_size拼接,所以ouput最后一个维度大小是hidden_size * bidirectional
h_n记录的是最后一个时间步所有层(包括双向的两层)上的隐藏层状态信息,其大小是:[num_layers * bidirectional, batch_size, hidden_size]
rnn = nn.RNN(input_size=100, hidden_size=20, num_layers=4)
print(rnn)
x = torch.randn(10, 3, 100)
out, h = rnn(x, torch.zeros(4, 3, 20))
print(out.shape, h.shape)
上述的RNN是一整个流程,从第一个时间步到最后一个时间步都给我们集成好了。下面介绍的是一个Cell的RNN,也叫RNNCell,这个需要我们自己手动往后更新hi计算每个时间步。
torch.nn.RNNCell(input_size, hidden_size, bias=True, nonlinearity='tanh', device=None, dtype=None)
这里的input_size 和 hidden_size 与上面的RNN的意义一样
cell1 = nn.RNNCell(100, 20)
h1 = torch.zeros(3, 20)
for xt in x:
h1 = cell1(xt, h1)
print(h1.shape)
这里与RNN不同的是,cell1的输入xt是X在每个时间步上的数据,对时间步做一个for循环,xt大小是:[batch_size,
input_size];同样,这只表示一个cell,则h1的初始化也不用关心多少层是否双向的问题,其大小是:[batch_size,
hidden_size]
cell1的输出也只有一个h1,这个h1用来更新隐藏层用来手动向后计算,大小与初始化的一样*[batch_size, input_size]*
cell1 = nn.RNNCell(100, 30)
cell2 = nn.RNNCell(30, 20)
h1 = torch.zeros(3, 30)
h2 = torch.zeros(3, 20)
for xt in x:
h1 = cell1(xt, h1)
h2 = cell2(h1, h2)
print(h2.shape)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。