For each element in the input sequence, each layer computes the following functions:
r_t=\sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{(t-1)} + b_{hr})
z_t=\sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{(t-1)} + b_{hz})
n_t=\tanh(W_{in}x_t + b_{in} + r_t*(W_{hn}h_{(t-1)} + b_{hn}))
h_t=(1-z_t)*n_t + z_t*h_{(t-1)}
The variables have the following meanings: x_t is the input at time t; h_t is the hidden state at time t, and h_{(t-1)} is the hidden state at time t-1 (or the initial hidden state at time 0); r_t, z_t, and n_t are the reset, update, and new gates, respectively; \sigma is the sigmoid function, and * denotes the elementwise (Hadamard) product.
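These equations can be checked directly against PyTorch. The sketch below recomputes a single `nn.GRUCell` step by hand, splitting the packed parameter matrices into the per-gate pieces (PyTorch documents `weight_ih` as the row-wise concatenation W_ir|W_iz|W_in, and likewise for `weight_hh` and the biases); the tensor shapes here are just illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

cell = nn.GRUCell(10, 20)        # input_size, hidden_size
x = torch.randn(3, 10)           # (batch, input_size)
h_prev = torch.randn(3, 20)      # (batch, hidden_size)

# Split packed parameters into the per-gate matrices from the equations above.
W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3, dim=0)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3, dim=0)

r = torch.sigmoid(x @ W_ir.T + b_ir + h_prev @ W_hr.T + b_hr)   # reset gate
z = torch.sigmoid(x @ W_iz.T + b_iz + h_prev @ W_hz.T + b_hz)   # update gate
n = torch.tanh(x @ W_in.T + b_in + r * (h_prev @ W_hn.T + b_hn))  # new gate
h = (1 - z) * n + z * h_prev                                    # new hidden state

h_ref = cell(x, h_prev)
print(torch.allclose(h, h_ref, atol=1e-5))
```

Note that b_hn sits inside the r_t multiplication, matching PyTorch's implementation; placing it outside gives a slightly different (and non-matching) value.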
import torch
import torch.nn as nn

rnn = nn.GRU(10, 20, 2)         # input_size, hidden_size, num_layers
input = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)      # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn(input, h0)
output.shape                    # torch.Size([5, 3, 20]) -> (seq_len, batch, hidden_size)
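The relationship between the two return values is worth making concrete: `output` holds the top layer's hidden state at every timestep, while `hn` holds the last timestep's hidden state for every layer, so the two coincide at the top layer's final step. A small check with the same shapes as above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

rnn = nn.GRU(10, 20, 2)         # input_size, hidden_size, num_layers
x = torch.randn(5, 3, 10)       # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)      # (num_layers * num_directions, batch, hidden_size)
output, hn = rnn(x, h0)

print(output.shape)  # torch.Size([5, 3, 20]) -- all timesteps, top layer
print(hn.shape)      # torch.Size([2, 3, 20]) -- last timestep, all layers
print(torch.allclose(output[-1], hn[-1]))  # True: top layer's final step
```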