def position_embedding()
The position embedding function.
Why it is added: inspired by ConvS2S (Gehring et al., 2017), we add position embeddings to the spatial-temporal network series so that the model can take both spatial and temporal information into account, strengthening its ability to model spatial-temporal correlations.
The spatial-temporal network series $X_G \in R^{N\times C\times T}$
Learnable temporal embedding matrix $T_{emb} \in R^{C\times T}$
Learnable spatial embedding matrix $S_{emb} \in R^{N\times C}$
Both embedding matrices are added to the spatial-temporal network series via broadcasting, giving the new series
$X_{G+t_{emb}+s_{emb}} = X_G + T_{emb} + S_{emb} \in R^{N\times C\times T}$
def position_embedding(data,
input_length, num_of_vertices, embedding_size,
temporal=True, spatial=True,
init=mx.init.Xavier(magnitude=0.0003), prefix=""):
    '''
    Parameters
    ----------
    data: mx.sym.var, shape is (B, T, N, C) (batch_size, time, vertices, features)
    input_length: int, length of the input time series, T
    num_of_vertices: int, N, number of vertices
    embedding_size: int, C, embedding size
    temporal, spatial: bool, whether to equip this type of embedding
    init: mx.initializer.Initializer, the initializer to use
    prefix: str, name prefix
    Returns
    ----------
    data: output shape is (B, T, N, C) (batch_size, time, vertices, features)
    '''
    # first set the temporal and spatial embedding variables to None
temporal_emb = None
spatial_emb = None
If the temporal embedding is needed, create the corresponding temporal embedding matrix via symbolic programming:
    # temporal embedding: create the corresponding variable (placeholder)
if temporal:
        # shape is (1, T, 1, C): the (C x T) embedding from the formula, laid out for broadcasting
temporal_emb = mx.sym.var(
"{}_t_emb".format(prefix), # 含前缀的变量
shape=(1, input_length, 1, embedding_size), # 占位符 的 形式
init=init
)
If the spatial embedding is needed:
    # spatial embedding: create the corresponding variable (placeholder)
if spatial:
# shape is (1, 1, N, C)
spatial_emb = mx.sym.var(
"{}_v_emb".format(prefix),
shape=(1, 1, num_of_vertices, embedding_size),
init=init
)
Broadcast-add the embeddings onto the input data:
    # add the embeddings to data with broadcasting
if temporal_emb is not None:
        data = mx.sym.broadcast_add(data, temporal_emb)  # element-wise addition with broadcasting
if spatial_emb is not None:
data = mx.sym.broadcast_add(data, spatial_emb)
return data
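A minimal usage sketch (the batch size 32 and vertex count 307 below are illustrative assumptions, not values from the text): build the symbol and let MXNet's shape inference confirm that the broadcast additions preserve the input shape.

import mxnet as mx

data = mx.sym.var('data')  # (B, T, N, C)
out = position_embedding(data, input_length=12, num_of_vertices=307,
                         embedding_size=64, prefix='demo')
# infer_shape confirms the broadcast additions keep (B, T, N, C)
_, out_shapes, _ = out.infer_shape(data=(32, 12, 307, 64))
print(out_shapes)  # [(32, 12, 307, 64)]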
def gcn_operation()
The graph convolution operation.
The graph convolution operation aggregates the features of each node with those of its neighbors. We define a graph convolution in the vertex domain to aggregate localized spatial-temporal features in the spatial-temporal network; its input is the graph signal matrix of the localized spatial-temporal graph. In our graph convolution, each node aggregates the features of itself and of its neighbors at the adjacent time steps. The aggregation function is a linear combination whose weights equal the weights of the edges between the node and its neighbors. A fully connected layer with an activation function then transforms the node features into a new space. This graph convolution can be formulated as follows:
$GCN(h^{(l-1)}) = h^{(l)} = \sigma(A'h^{(l-1)}W + b) \in R^{3N\times C'}$
where
$A' \in R^{3N\times 3N}$ is the adjacency matrix of the localized spatial-temporal graph,
$h^{(l-1)} \in R^{3N\times C}$ is the input of layer $l$,
$W \in R^{C\times C'}$ and $b \in R^{C'}$ are learnable parameters.
When GLU is used as the activation, the layer becomes
$h^{(l)} = (A'h^{(l-1)}W_1 + b_1) \otimes \mathrm{sigmoid}(A'h^{(l-1)}W_2 + b_2)$
where
$W_1, W_2 \in R^{C\times C'}$,
$b_1, b_2 \in R^{C'}$,
$\mathrm{sigmoid}(x) = \frac{1}{1+e^{-x}}$
and $\otimes$ denotes the element-wise (Hadamard) product.
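A small NumPy sketch of this gating (all shapes here are illustrative): the first linear map provides the values and the sigmoid of the second acts as a gate.

import numpy as np

rng = np.random.default_rng(0)
h = rng.standard_normal((6, 4))          # A'h^{(l-1)}, with 3N = 6, C = 4
W1, W2 = rng.standard_normal((2, 4, 5))  # two C -> C' maps, C' = 5
b1, b2 = rng.standard_normal((2, 5))

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
h_out = (h @ W1 + b1) * sigmoid(h @ W2 + b2)  # element-wise product
print(h_out.shape)  # (6, 5)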
def gcn_operation(data, adj,
num_of_filter, num_of_features, num_of_vertices,
activation, prefix=""):
    '''
    graph convolutional operation, a simple GCN we defined in the paper
    Parameters
    ----------
    data: mx.sym.var, shape is (3N, B, C) (3N, batch_size, features)
    adj: mx.sym.var, shape is (3N, 3N)
    num_of_filter: int, C', number of filters
    num_of_features: int, C, number of features
    num_of_vertices: int, N, number of vertices
    activation: str, {'GLU', 'relu'}, the activation function
    prefix: str
    Returns
    ----------
    output shape is (3N, B, C') (3N, batch_size, number of filters)
    '''
    assert activation in {'GLU', 'relu'}  # the activation must be GLU or relu, otherwise raise
The next line computes $A'h^{(l-1)}$:
    # shape is (3N, B, C)
    data = mx.sym.dot(adj, data)  # product of the adjacency matrix and the data
If 'GLU' is used:
apply two fully connected layers, (W1, b1) and (W2, b2), to the data above. Since both are fully connected layers of the same shape, we can use a single layer with 2x the filters and then split its output in two; the result then follows from the GLU formula (see the NumPy sketch after this function).
if activation == 'GLU':
# shape is (3N, B, 2C')
data = mx.sym.FullyConnected(
data,
flatten=False,
num_hidden=2 * num_of_filter
)
# shape is (3N, B, C'), (3N, B, C')
        lhs, rhs = mx.sym.split(data, num_outputs=2, axis=2)  # split the array into sub-arrays along the given axis
# shape is (3N, B, C')
return lhs * mx.sym.sigmoid(rhs)
If 'relu' is used:
simply apply a relu activation to the fully connected layer.
elif activation == 'relu':
# shape is (3N, B, C')
return mx.sym.Activation(
mx.sym.FullyConnected(
data,
flatten=False,
num_hidden=num_of_filter
), activation
)
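The NumPy sketch below (illustrative shapes) verifies the trick used above: a single fused weight matrix of width 2C' followed by a split reproduces the two separate projections W1 and W2.

import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((6, 4))          # (3N, C) with 3N = 6, C = 4
W1, W2 = rng.standard_normal((2, 4, 5))  # two separate C -> C' maps
W = np.concatenate([W1, W2], axis=1)     # fused C -> 2C' map

lhs, rhs = np.split(x @ W, 2, axis=1)    # split the 2C' output in half
assert np.allclose(lhs, x @ W1) and np.allclose(rhs, x @ W2)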
def stsgcm()
The spatial-temporal synchronous graph convolutional module (STSGCM), shown in Fig. 4 of the paper.
Note that the data of each time window is processed locally here.
As shown in the figure, the result after the two graph convolution layers is fed into the AGG layer.
The AGG layer consists of two parts: aggregating and cropping.
Aggregation operation: max pooling.
It applies an element-wise max operation to the outputs of all graph convolutions in the STSGCM. The max operation requires all outputs to have the same size, so the number of kernels of every graph convolution in a module must be equal. The max aggregation can be formulated as:
$h_{AGG} = \max(h^{(1)}, h^{(2)}, ..., h^{(L)}) \in R^{3N\times C_{out}}$
where $C_{out}$ is the number of kernels of the graph convolution layers.
Cropping operation:
The cropping operation (Fig. 4(c)) removes all features of the nodes at the previous and the next time steps, keeping only the nodes of the middle moment. The reason is that the graph convolution has already aggregated the information from the previous and next time steps, so although two time steps are cropped, each remaining node still contains the localized spatial-temporal correlations. If we stacked multiple STSGCMs and kept the features of all adjacent time steps, a large amount of redundant information would exist in the model and seriously harm its performance.
def stsgcm(data, adj,
filters, num_of_features, num_of_vertices,
activation, prefix=""):
    '''
    STSGCM, multiple stacked gcn layers with cropping and max operation
    Parameters
    ----------
    data: mx.sym.var, shape is (3N, B, C)
    adj: mx.sym.var, shape is (3N, 3N)
    filters: list[int], list of C', the filters
    num_of_features: int, C, number of features
    num_of_vertices: int, N, number of vertices
    activation: str, {'GLU', 'relu'}, the activation function
    prefix: str
    Returns
    ----------
    output shape is (N, B, C')
    '''
Note that filters here is a list. The Experiment Settings section states that an STSGCM contains 3 graph convolution operations with 64, 64, and 64 filters respectively,
so filters = [64, 64, 64].
need_concat = []
    for i in range(len(filters)):  # one graph convolution per entry in filters
        data = gcn_operation(  # call the gcn_operation function defined above
data, adj,
filters[i], num_of_features, num_of_vertices,
activation=activation,
prefix="{}_gcn_{}".format(prefix, i)
)
need_concat.append(data)
num_of_features = filters[i]
Walkthrough (filters = [64, 64, 64]):
First iteration:
data: (3N, B, C)
adjacency matrix adj: (3N, 3N)
number of filters: filters[0] = 64
num_of_features and num_of_vertices are passed but not used inside gcn_operation
output data: (3N, B, filters[0]), appended as need_concat[0]
Second iteration:
data: (3N, B, filters[0]); number of filters: filters[1] = 64
output data: (3N, B, filters[1]), appended as need_concat[1]
Third iteration:
data: (3N, B, filters[1]); number of filters: filters[2] = 64
output data: (3N, B, filters[2]), appended as need_concat[2]
Result:
need_concat = [data1, data2, data3], each of shape (3N, B, 64)
Cropping operation:
for each element (3N, B, C') of need_concat, keep only the middle vertex block (rows N to 2N), i.e. (3N, B, C') -> (N, B, C')
    # each element of need_concat has shape (3N, B, C')
    # the following crops out the middle part: (3N, B, C') --> (N, B, C')
# shape of each element is (1, N, B, C')
need_concat = [
mx.sym.expand_dims(
mx.sym.slice(
i,
begin=(num_of_vertices, None, None),
end=(2 * num_of_vertices, None, None)
), 0
) for i in need_concat
]
Function note: mx.sym.slice crops the array along each axis according to begin and end (and an optional step); None leaves that axis untouched.
Result:
need_concat = [data1, data2, data3], each of shape (1, N, B, 64)
Max-pooling operation:
# shape is (N, B, C')
return mx.sym.max(mx.sym.concat(*need_concat, dim=0), axis=0)
Function notes:
mx.sym.concat(*need_concat, dim=0) concatenates the three (1, N, B, 64) arrays along axis 0 into (3, N, B, 64).
mx.sym.max(..., axis=0) takes the maximum along axis 0, giving (N, B, 64).
Result:
the returned data has shape (N, B, 64).
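A NumPy sketch of the crop -> stack -> max pipeline (tiny illustrative sizes: N = 2, B = 1, C' = 3, and L = 3 graph convolution outputs):

import numpy as np

N, B, C = 2, 1, 3
outs = [np.random.rand(3 * N, B, C) for _ in range(3)]  # L gcn outputs
cropped = [o[N:2 * N][None] for o in outs]              # each (1, N, B, C')
h_agg = np.concatenate(cropped, axis=0).max(axis=0)     # (N, B, C')
print(h_agg.shape)  # (2, 1, 3)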
def stsgcl()
The structure of the spatial-temporal synchronous graph convolutional layer (STSGCL) is shown in Fig. 2 of the paper.
To capture the long-range spatial-temporal correlations of the whole network series, we use a sliding window to cut out different periods. Because of the heterogeneity of spatial-temporal data, it is better to use multiple STSGCMs to model different periods rather than to share one across all of them: multiple STSGCMs allow each one to focus on modeling the localized spatial-temporal correlations within its local graph. We deploy a set of spatial-temporal synchronous graph convolutional layers to extract long-range spatial-temporal features, as shown in Fig. 2.
Input: the embedded spatial-temporal network series $X_{G+t_{emb}+s_{emb}} \in R^{T\times N\times C}$
STSGC layer: contains $T-2$ STSGCMs
Result: $M = [M_1, M_2, ..., M_{T-2}] \in R^{(T-2)\times N\times C_{out}}$
We use 1 hour of data to predict the next hour, so T = 12.
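A quick sketch of the sliding-window arithmetic: with a window of 3 consecutive time steps, T = 12 inputs yield T - 2 = 10 windows, which is why one layer contains T - 2 STSGCMs.

T = 12
windows = [(i, i + 3) for i in range(T - 2)]
print(len(windows))             # 10
print(windows[0], windows[-1])  # (0, 3) (9, 12)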
def stsgcl(data, adj,
T, num_of_vertices, num_of_features, filters,
module_type, activation, temporal_emb=True, spatial_emb=True,
prefix=""):
'''
STSGCL
Parameters
----------
data: mx.sym.var, shape is (B, T, N, C)
adj: mx.sym.var, shape is (3N, 3N)
T: int, length of time series, T= 12
num_of_vertices: int, N
num_of_features: int, C
filters: list[int], list of C'
module_type: str, {'sharing', 'individual'}
activation: str, {'GLU', 'relu'}
temporal_emb, spatial_emb: bool
prefix: str
Returns
----------
output shape is (B, T-2, N, C')
'''
    # make sure the given module_type is in the set, otherwise raise
assert module_type in {'sharing', 'individual'}
If the modules are individual, call the sthgcn_layer_individual function:
if module_type == 'individual':
return sthgcn_layer_individual(
data, adj,
T, num_of_vertices, num_of_features, filters,
activation, temporal_emb, spatial_emb, prefix
)
If the module is shared, call the sthgcn_layer_sharing function:
else:
return sthgcn_layer_sharing(
data, adj,
T, num_of_vertices, num_of_features, filters,
activation, temporal_emb, spatial_emb, prefix
)
def sthgcn_layer_individual()
def sthgcn_layer_individual(data, adj,
T, num_of_vertices, num_of_features, filters,
activation, temporal_emb=True, spatial_emb=True,
prefix=""):
'''
STSGCL, multiple individual STSGCMs
Parameters
----------
data: mx.sym.var, shape is (B, T, N, C)
adj: mx.sym.var, shape is (3N, 3N)
T: int, length of time series, T
num_of_vertices: int, N
num_of_features: int, C
filters: list[int], list of C'
activation: str, {'GLU', 'relu'}
temporal_emb, spatial_emb: bool
prefix: str
Returns
----------
output shape is (B, T-2, N, C')
'''
The prepared input data: (B, T, N, C)
Apply the position embedding:
# shape is (B, T, N, C)
data = position_embedding(data, T, num_of_vertices, num_of_features,
temporal_emb, spatial_emb,
prefix="{}_emb".format(prefix))
Result: data stays (B, T, N, C)
Run T-2 STSGCMs:
    need_concat = []  # used to store the T-2 results
for i in range(T - 2):
        # each STSGCM needs 3 time steps; slice them out to get (B, 3, N, C)
# shape is (B, 3, N, C)
t = mx.sym.slice(data, begin=(None, i, None, None),
end=(None, i + 3, None, None))
        # reshape towards the STSGCM input layout
# shape is (B, 3N, C)
t = mx.sym.reshape(t, (-1, 3 * num_of_vertices, num_of_features))
        # transpose towards the STSGCM input layout
# shape is (3N, B, C)
t = mx.sym.transpose(t, (1, 0, 2))
        # call the stsgcm function
# shape is (N, B, C')
t = stsgcm(
t, adj, filters, num_of_features, num_of_vertices,
activation=activation,
prefix="{}_stsgcm_{}".format(prefix, i)
)
        # the returned (N, B, C') is transposed to (B, N, C')
# shape is (B, N, C')
t = mx.sym.swapaxes(t, 0, 1)
        '''
        mx.sym.swapaxes(data=None, dim1=_Null, dim2=_Null, ...) interchanges two axes of an array.
        E.g. if x.shape = (2, 3), then mx.sym.swapaxes(x, 0, 1) has shape (3, 2).
        '''
        # expand (B, N, C') to (B, 1, N, C') and append it to need_concat
# shape is (B, 1, N, C')
need_concat.append(mx.sym.expand_dims(t, axis=1))
    # concatenate the T-2 results into (B, T-2, N, C')
# shape is (B, T-2, N, C')
return mx.sym.concat(*need_concat, dim=1)
Result: data has shape (B, T-2, N, C')
def sthgcn_layer_sharing()
def sthgcn_layer_sharing(data, adj,
T, num_of_vertices, num_of_features, filters,
activation, temporal_emb=True, spatial_emb=True,
prefix=""):
'''
    STSGCL, a single shared STSGCM applied to all windows
Parameters
----------
data: mx.sym.var, shape is (B, T, N, C)
adj: mx.sym.var, shape is (3N, 3N)
T: int, length of time series, T
num_of_vertices: int, N
num_of_features: int, C
filters: list[int], list of C'
activation: str, {'GLU', 'relu'}
temporal_emb, spatial_emb: bool
prefix: str
Returns
----------
output shape is (B, T-2, N, C')
'''
The prepared input data: (B, T, N, C)
Apply the position embedding:
# shape is (B, T, N, C)
data = position_embedding(data, T, num_of_vertices, num_of_features,
temporal_emb, spatial_emb,
prefix="{}_emb".format(prefix))
Result: data stays (B, T, N, C)
Build the T-2 network series:
need_concat = []
for i in range(T - 2):
# shape is (B, 3, N, C)
t = mx.sym.slice(data, begin=(None, i, None, None),
end=(None, i + 3, None, None))
        # reshape towards the STSGCM input layout
# shape is (B, 3N, C)
t = mx.sym.reshape(t, (-1, 3 * num_of_vertices, num_of_features))
# shape is (3N, B, C)
t = mx.sym.swapaxes(t, 0, 1)
        # append the prepared slice to need_concat
need_concat.append(t)
    # concatenate the slices in need_concat along the batch axis
    # shape is (3N, (T-2)*B, C)
t = mx.sym.concat(*need_concat, dim=1)
Result:
t has shape (3N, (T-2)*B, C)
All windows share a single STSGCM:
    # run the shared stsgcm over the batched windows
# shape is (N, (T-2)*B, C')
t = stsgcm(
t, adj, filters, num_of_features, num_of_vertices,
activation=activation,
prefix="{}_stsgcm".format(prefix)
)
Result:
t has shape (N, (T-2)*B, C'), with C' = 64
Reshape the data back:
    # shape is (N, T - 2, B, C')
    t = t.reshape((num_of_vertices, T - 2, -1, filters[-1]))
    # shape is (B, T - 2, N, C')
return mx.sym.swapaxes(t, 0, 2)
Result:
the returned data has shape (B, T-2, N, C').
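The reshape above is safe because the earlier concat placed the T-2 windows in order along the batch axis; a NumPy sketch (N = 1, B = 2, T - 2 = 3, C' = 1) to convince yourself:

import numpy as np

blocks = [np.full((1, 2, 1), t) for t in range(3)]  # 3 windows of batch 2
merged = np.concatenate(blocks, axis=1)             # ((T-2)*B on axis 1) = (1, 6, 1)
restored = merged.reshape(1, 3, 2, 1)               # (N, T-2, B, C')
assert (restored[0, :, 0, 0] == np.arange(3)).all()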
def output_layer()
The output layer.
An output layer is designed to transform the output of the last STSGCL into the expected prediction.
The output of the last STSGCL: $X \in R^{T\times N\times C_0}$
Transpose and flatten the data:
$X \in R^{N\times TC_0}$
Two fully connected layers:
$\tilde{y}^{(i)} = \mathrm{ReLU}(X W_1^{(i)} + b_1^{(i)}) W_2^{(i)} + b_2^{(i)} \in R^{N}$
where $\tilde{y}^{(i)}$ is the prediction at time step $i$,
$W_1^{(i)} \in R^{TC_0\times C_1}$, $b_1^{(i)} \in R^{C_1}$, $W_2^{(i)} \in R^{C_1\times 1}$, $b_2^{(i)} \in R$ are learnable parameters,
and $C_1$ is the number of features of the first fully connected layer of the output layer.
Stacking $T'$ such two-layer heads gives
$\tilde{Y} = [\tilde{y}^{(1)}, \tilde{y}^{(2)}, ..., \tilde{y}^{(T')}] \in R^{N\times T'}$
The final output of STSGCN is $\tilde{Y}$.
def output_layer(data, num_of_vertices, input_length, num_of_features,
num_of_filters=128, predict_length=12):
'''
Parameters
----------
data: mx.sym.var, shape is (B, T, N, C)
num_of_vertices: int, N
input_length: int, length of time series, T
num_of_features: int, C
num_of_filters: int, C'
predict_length: int, length of predicted time series, T'
Returns
----------
output shape is (B, T', N)
'''
Note that
num_of_filters=128 is the number of units of the first fully connected layer,
and predict_length=12 is the number of units of the second fully connected layer.
Input data:
the format is (batch_size, time length, number of vertices N, C' = 64)
    # data shape is (B, N, T, C): put T and C next to each other
data = mx.sym.swapaxes(data, 1, 2)
    # (B, N, T * C): flatten the data to 3-D
data = mx.sym.reshape(
data, (-1, num_of_vertices, input_length * num_of_features)
)
Result: data has shape (B, N, T*C)
The two fully connected layers:
# (B, N, C')
data = mx.sym.Activation(
mx.sym.FullyConnected(
data,
flatten=False,
num_hidden=num_of_filters
), 'relu'
)
# (B, N, T')
data = mx.sym.FullyConnected(
data,
flatten=False,
num_hidden=predict_length
)
Result:
# (B, T', N)
data = mx.sym.swapaxes(data, 1, 2)
return data
data: (batch_size, prediction length, number of vertices)
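A NumPy sketch of the (B, T, N, C) -> (B, N, T*C) rearrangement performed at the top of output_layer (illustrative sizes):

import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # (B, T, N, C)
y = x.swapaxes(1, 2).reshape(2, 4, 3 * 5)         # (B, N, T*C)
print(y.shape)  # (2, 4, 15)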
def huber_loss()
The Huber loss function.
$L(Y, \tilde{Y}) = \begin{cases} \frac{1}{2}(Y - \tilde{Y})^2, & |Y - \tilde{Y}| \le \delta \\ \delta|Y - \tilde{Y}| - \frac{1}{2}\delta^2, & \text{otherwise} \end{cases}$
where $\delta$ is the threshold (the rho parameter in the code).
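A NumPy check of how this formula relates to the where-based computation in the code below: the code computes the Huber loss scaled by 1/rho, which coincides with the formula exactly when rho = 1.

import numpy as np

rho = 1.0
d = np.linspace(-3, 3, 7)  # Y - Y~
huber = np.where(np.abs(d) <= rho, 0.5 * d ** 2,
                 rho * np.abs(d) - 0.5 * rho ** 2)
code_form = np.where(np.abs(d) > rho, np.abs(d) - 0.5 * rho,
                     (0.5 / rho) * d ** 2)
assert np.allclose(huber / rho, code_form)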
def huber_loss(data, label, rho=1):
'''
Parameters
----------
data: mx.sym.var, shape is (B, T', N)
label: mx.sym.var, shape is (B, T', N)
rho: float
Returns
----------
loss: mx.sym
'''
    # |Y - Y~|
    loss = mx.sym.abs(data - label)
    # piecewise Huber form, scaled by 1/rho (identical to the formula when rho = 1)
    loss = mx.sym.where(loss > rho, loss - 0.5 * rho,
                        (0.5 / rho) * mx.sym.square(loss))
    loss = mx.sym.MakeLoss(loss)  # wrap the symbol as a training loss
return loss
def weighted_loss()
A weighted loss built on the Huber loss.
def weighted_loss(data, label, input_length, rho=1):
'''
weighted loss build on huber loss
Parameters
----------
data: mx.sym.var, shape is (B, T', N)
label: mx.sym.var, shape is (B, T', N)
input_length: int, T'
rho: float
Returns
----------
agg_loss: mx.sym
'''
    # build the weight tensor
    # shape is (1, T', 1)
weight = mx.sym.expand_dims(
mx.sym.expand_dims(
mx.sym.flip(mx.sym.arange(1, input_length + 1), axis=0),
axis=0
), axis=-1
)
    '''
    mx.sym.flip(data=None, axis=_Null, ...) reverses the order of elements
    along the given axis while preserving the array shape.
    E.g. for x = [[0,1,2,3],[4,5,6,7]]:
    flipping x along axis=0 gives [[4,5,6,7],[0,1,2,3]]
    flipping x along axis=1 gives [[3,2,1,0],[7,6,5,4]]
    Here the weights are [T', T'-1, ..., 1], so nearer prediction steps get larger weights.
    '''
agg_loss = mx.sym.broadcast_mul(
huber_loss(data, label, rho),
weight
)
    '''
    broadcast_mul(lhs=None, rhs=None, ...)
    returns the element-wise product of the input arrays with broadcasting.
    '''
return agg_loss
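A small sketch of the resulting weights (T' = 4 for illustration): flip(arange(1, T'+1)) gives [T', ..., 1], so nearer prediction steps are weighted more heavily.

import numpy as np

T_prime = 4
weight = np.arange(1, T_prime + 1)[::-1].reshape(1, T_prime, 1)
print(weight.ravel())  # [4 3 2 1]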
def stsgcn()
def stsgcn(data, adj, label,
input_length, num_of_vertices, num_of_features,
filter_list, module_type, activation,
use_mask=True, mask_init_value=None,
temporal_emb=True, spatial_emb=True,
prefix="", rho=1, predict_length=12):
'''
data shape is (B, T, N, C)
adj shape is (3N, 3N)
label shape is (B, T, N)
'''
if use_mask:
        # raise if no initializer is given for the mask matrix
if mask_init_value is None:
raise ValueError("mask init value is None!")
mask = mx.sym.var("{}_mask".format(prefix),
shape=(3 * num_of_vertices, 3 * num_of_vertices),
init=mask_init_value)
adj = mask * adj
for idx, filters in enumerate(filter_list):
        # call the stsgcl function; one STSGCL per element of filter_list
data = stsgcl(data, adj, input_length, num_of_vertices,
num_of_features, filters, module_type,
activation=activation,
temporal_emb=temporal_emb,
spatial_emb=spatial_emb,
prefix="{}_stsgcl_{}".format(prefix, idx))
        # the next iteration's input_length (time length) shrinks by 2
input_length -= 2
        # the next round's num_of_features is the number of output filters of this round
num_of_features = filters[-1]
    # each output_layer call below produces (B, 1, N)
    need_concat = []
    # loop over the prediction horizon
for i in range(predict_length):
        # append the output layer result of each step to need_concat
need_concat.append(
output_layer(
data, num_of_vertices, input_length, num_of_features,
num_of_filters=128, predict_length=1
)
)
    # concatenate along axis=1: data becomes (B, predict_length, N)
data = mx.sym.concat(*need_concat, dim=1)
    # compute the loss between the ground truth and the prediction
loss = huber_loss(data, label, rho=rho)
return mx.sym.Group([loss, mx.sym.BlockGrad(data, name='pred')])
    # mx.sym.BlockGrad stops the gradient from flowing back through the prediction output
    # mx.sym.Group aggregates multiple symbols into one
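A hedged end-to-end sketch of wiring the network together. The configuration below (4 STSGCLs of [64, 64, 64] filters each, N = 307 vertices, T = 12, C = 1) is an assumption for illustration, not a value taken from the text:

import mxnet as mx

data = mx.sym.var('data')    # (B, 12, 307, 1)
adj = mx.sym.var('adj')      # (3N, 3N) = (921, 921), the localized graph
label = mx.sym.var('label')  # (B, 12, 307)
net = stsgcn(data, adj, label,
             input_length=12, num_of_vertices=307, num_of_features=1,
             filter_list=[[64, 64, 64]] * 4,
             module_type='individual', activation='GLU',
             use_mask=True, mask_init_value=mx.init.Constant(value=1.0),
             temporal_emb=True, spatial_emb=True,
             prefix='stsgcn', rho=1, predict_length=12)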