【Function Module】The complete code is in the attachment; the dataset can also be provided if needed.
from collections import OrderedDict

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops


class EmbeddingImagenet(nn.Cell):
    def __init__(self, emb_size, cifar_flag=False):
        super(EmbeddingImagenet, self).__init__()
        # set size
        self.hidden = 64
        self.last_hidden = self.hidden * 25 if not cifar_flag else self.hidden * 4
        self.emb_size = emb_size
        self.out_dim = emb_size
        # set layers: four conv blocks, each conv -> batchnorm -> maxpool -> leaky relu
        self.conv_1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=self.hidden,
                      kernel_size=3, padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.hidden),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2))
        self.conv_2 = nn.SequentialCell(
            nn.Conv2d(in_channels=self.hidden, out_channels=int(self.hidden * 1.5),
                      kernel_size=3, padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=int(self.hidden * 1.5)),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2))
        self.conv_3 = nn.SequentialCell(
            nn.Conv2d(in_channels=int(self.hidden * 1.5), out_channels=self.hidden * 2,
                      kernel_size=3, padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.hidden * 2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2),
            nn.Dropout(0.6))
        self.conv_4 = nn.SequentialCell(
            nn.Conv2d(in_channels=self.hidden * 2, out_channels=self.hidden * 4,
                      kernel_size=3, padding=1, pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.hidden * 4),  # 16 * 64 * (5 * 5)
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.LeakyReLU(alpha=0.2),
            nn.Dropout(0.5))
        # self.layer_last = nn.SequentialCell(
        #     nn.Dense(in_channels=self.last_hidden * 4, out_channels=self.emb_size, has_bias=True),
        #     nn.BatchNorm1d(self.emb_size))
        self.layer_last = nn.Dense(in_channels=self.last_hidden * 4,
                                   out_channels=self.emb_size, has_bias=True)
        # self.bn = nn.BatchNorm1d(self.emb_size)

    def construct(self, input_data):
        x = self.conv_1(input_data)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        # note: despite the "feat:" label, this prints the raw input, not the conv features
        print("feat:", input_data[0])
        # flatten the conv feature map before the fully connected layer
        x = self.layer_last(x.view(x.shape[0], -1))
        print("last--------------------------------:", x[0])
        return x
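For reference, here is a minimal shape check for the embedding above; it is a sketch, not part of the original attachment. emb_size=128 matches the 128-dimensional feature vectors printed below, and the 84x84 input size is an assumption chosen so the flattened conv output (256 channels x 5 x 5 = 6400) matches the Dense layer's in_channels of last_hidden * 4.

import numpy as np

net = EmbeddingImagenet(emb_size=128)
# assumed input size: (batch, channels, height, width) = (4, 3, 84, 84)
images = ms.Tensor(np.random.rand(4, 3, 84, 84), ms.float32)
feat = net(images)
print(feat.shape)  # expected: (4, 128)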
class NodeUpdateNetwork(nn.Cell):
    def __init__(self, in_features, num_features, ratio=[2, 1], dropout=0.0):
        super(NodeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.dropout = dropout
        self.eye = ops.Eye()
        self.bmm = ops.BatchMatMul()
        self.cat = ops.Concat(-1)
        self.split = ops.Split(1, 2)
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()
        self.squeeze = ops.Squeeze()
        self.transpose = ops.Transpose()
        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)
            if self.dropout > 0 and l == (len(self.num_features_list) - 1):
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1 - self.dropout)
        self.network = nn.SequentialCell(layer_list)

    def construct(self, node_feat, edge_feat):
        # get size
        num_tasks = node_feat.shape[0]
        num_data = node_feat.shape[1]
        # get eye matrix (batch_size x 2 x node_size x node_size)
        diag_mask = 1.0 - self.repeat(
            self.unsqueeze(self.unsqueeze(self.eye(num_data, num_data, ms.float32), 0), 0),
            (num_tasks, 2, 1, 1))
        # set the diagonal to zero and normalize (the original paper uses L1 normalization)
        # edge_feat = edge_feat * diag_mask
        # edge_feat = edge_feat / ops.clip_by_value(
        #     ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1),
        #     Tensor(0, ms.float32), Tensor(num_data, ms.float32))
        edge_feat = ops.L2Normalize(-1)(edge_feat * diag_mask)
        # compute attention and aggregate
        aggr_feat = self.bmm(self.squeeze(ops.Concat(2)(self.split(edge_feat))), node_feat)
        node_feat = self.cat([node_feat, self.cat(ops.Split(1, 2)(aggr_feat))]).swapaxes(1, 2)
        node_feat = self.network(self.unsqueeze(node_feat, -1)).swapaxes(1, 2).squeeze()
        return node_feat


class EdgeUpdateNetwork(nn.Cell):
    def __init__(self, in_features, num_features, ratio=[2, 2, 1, 1],
                 separate_dissimilarity=False, dropout=0.0):
        super(EdgeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.separate_dissimilarity = separate_dissimilarity
        self.dropout = dropout
        self.eye = ops.Eye()
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()
        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            # set layer
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)
            if self.dropout > 0:
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1 - self.dropout)
        layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                           out_channels=1, kernel_size=1)
        self.sim_network = nn.SequentialCell(layer_list)

    def construct(self, node_feat, edge_feat):
        # compute the pairwise distance abs(x_i - x_j)
        x_i = ops.ExpandDims()(node_feat, 2)
        x_j = x_i.swapaxes(1, 2)
        # x_ij = (x_i - x_j) ** 2
        x_ij = ops.Abs()(x_i - x_j)
        x_ij = ops.Transpose()(x_ij, (0, 3, 2, 1))
        sim_val = ops.Sigmoid()(self.sim_network(x_ij))
        dsim_val = 1.0 - sim_val
        diag_mask = 1.0 - self.repeat(
            self.unsqueeze(self.unsqueeze(
                self.eye(node_feat.shape[1], node_feat.shape[1], ms.float32), 0), 0),
            (node_feat.shape[0], 2, 1, 1))
        edge_feat = edge_feat * diag_mask
        merge_sum = ops.ReduceSum(keep_dims=True)(edge_feat, -1)
        # set the diagonal to zero and normalize
        # edge_feat = ops.Concat(1)([sim_val, dsim_val]) * edge_feat
        # edge_feat = edge_feat / ops.clip_by_value(
        #     ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1),
        #     Tensor(0, ms.float32), Tensor(num_data, ms.float32))
        # edge_feat = edge_feat * merge_sum
        edge_feat = ops.L2Normalize(-1)(ops.Concat(1)([sim_val, dsim_val]) * edge_feat) * merge_sum
        # force self-loops: identity on the similarity channel, zeros on the dissimilarity channel
        force_edge_feat = self.repeat(
            self.unsqueeze(ops.Concat(0)([
                self.unsqueeze(self.eye(node_feat.shape[1], node_feat.shape[1], ms.float32), 0),
                self.unsqueeze(ops.Zeros()((node_feat.shape[1], node_feat.shape[1]), ms.float32), 0)]), 0),
            (node_feat.shape[0], 1, 1, 1))
        edge_feat = edge_feat + force_edge_feat
        edge_feat = edge_feat + 1e-6
        edge_feat = edge_feat / self.repeat(
            self.unsqueeze(ops.ReduceSum()(edge_feat, 1), 1), (1, 2, 1, 1))
        return edge_feat


class GraphNetwork(nn.Cell):
    def __init__(self, in_features, node_features, edge_features, num_layers, dropout=0.0):
        super(GraphNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.node_features = node_features
        self.edge_features = edge_features
        self.num_layers = num_layers
        self.dropout = dropout
        self.layers = nn.CellList()
        # for each layer
        for l in range(self.num_layers):
            # set edge to node
            edge2node_net = NodeUpdateNetwork(
                in_features=self.in_features if l == 0 else self.node_features,
                num_features=self.node_features,
                dropout=self.dropout if l < self.num_layers - 1 else 0.0)
            # set node to edge
            node2edge_net = EdgeUpdateNetwork(
                in_features=self.node_features,
                num_features=self.edge_features,
                separate_dissimilarity=False,
                dropout=self.dropout if l < self.num_layers - 1 else 0.0)
            self.layers.append(nn.CellList([edge2node_net, node2edge_net]))

    # forward
    def construct(self, node_feat, edge_feat):
        # for each layer
        edge_feat_list = []
        for l in range(self.num_layers):
            # (1) edge to node
            node_feat = self.layers[l][0](node_feat, edge_feat)
            # (2) node to edge
            edge_feat = self.layers[l][1](node_feat, edge_feat)
            # save edge feature
            edge_feat_list.append(edge_feat)
        return edge_feat_list
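To show how the pieces fit together, here is a hedged end-to-end sketch of the pipeline described in the problem section below. The task sizes and the node/edge feature widths of 96 are illustrative assumptions, not values taken from the attachment: the CNN embedding output becomes the node features, and GraphNetwork alternately updates nodes and edges.

import numpy as np

emb_net = EmbeddingImagenet(emb_size=128)
gnn = GraphNetwork(in_features=128, node_features=96, edge_features=96, num_layers=3)

num_tasks, num_data = 2, 10  # illustrative only
images = ms.Tensor(np.random.rand(num_tasks * num_data, 3, 84, 84), ms.float32)
node_feat = emb_net(images).view(num_tasks, num_data, -1)   # (tasks, nodes, 128)
# edge_feat carries a similarity and a dissimilarity channel, uniformly initialized
edge_feat = ms.Tensor(np.full((num_tasks, 2, num_data, num_data), 1.0 / num_data), ms.float32)
edge_feat_list = gnn(node_feat, edge_feat)   # one (tasks, 2, N, N) tensor per GNN layer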
【Steps to Reproduce & Problem】
The main function of our code is to extract image features with four convolutional layers plus one fully connected layer, then treat each image's feature vector as a node of a graph and run a GNN over them (see the end-to-end sketch after the GraphNetwork code above; the full code is in the attachment).
1. After training for many batches, the extracted features (after the four convolutional layers and the fully connected layer) contain extremely large values, and a few batches later they become NaN; the features before the fully connected layer are still normal at that point (see the probe sketch below).
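To localize where the values first blow up, a minimal probe sketch (my own suggestion, not part of the original code) is to print the largest absolute activation after each block inside EmbeddingImagenet.construct, so the first layer whose outputs grow without bound becomes visible:

def report(tag, x):
    # print the largest |value| in the tensor so overflow shows up before it turns into NaN
    print(tag, ops.ReduceMax()(ops.Abs()(x)))

# inside EmbeddingImagenet.construct:
#     x = self.conv_4(x);  report("conv_4", x)
#     x = self.layer_last(x.view(x.shape[0], -1));  report("dense", x)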
【Screenshot Information】
Below are the features printed by the code:
last--------------------------------: [ 1.918492 -0.8280923 2.0575197 0.3089749 -1.0514854 0.5368729
0.14135109 1.5270222 -1.4794292 -1.4336827 1.0335447 -0.7093582
-0.41919574 -0.5667086 -0.3535831 1.5567536 0.5002996 -1.4093596
0.9674009 -0.18156137 0.14888959 0.6358457 1.406878 -0.03820777
-0.24577822 -0.25783274 0.5756687 -1.4558431 -1.1002262 0.68062806
-1.6467474 0.88712454 0.3551372 -1.3449378 -1.7011788 -0.8629771
-0.92482185 0.9867192 -1.5548937 1.340383 -2.299356 -0.3421743
1.3239275 -1.3792732 -0.31955895 -0.58364254 -3.7381008 -1.2121737
-0.75104207 -0.7562581 0.04980466 0.45131734 -1.2448095 -0.33418307
0.86268485 -1.3601649 1.2753168 2.469506 -1.7358601 -2.9104383
-0.07392117 -0.73263663 0.11657254 -0.05724781 0.34374043 -0.31884825
0.13456154 2.3561432 -0.18908082 0.5410311 1.7249999 0.9508886
-0.30631644 1.6836481 1.1513023 -0.33672807 -0.889638 -0.76715356
-0.7316199 1.597606 -1.6586273 0.4502733 0.5224928 -3.5851111
-2.906651 -1.5284328 0.83426046 1.354644 -1.4453334 2.0504599
-1.3200179 -0.50427496 0.97681373 0.30048305 0.17170379 0.8179815
-0.92994857 1.333491 -1.2931286 -0.3569969 2.7953048 -3.352736
1.878619 2.018083 -1.1191074 -1.1341975 1.4532931 -0.66957355
2.3269157 -0.4198427 0.7148121 0.5458231 -1.3050007 -0.34666243
2.519589 0.804219 0.91191477 1.3088121 0.6767241 2.1667008
0.24471135 1.2600335 -1.8683847 2.5641935 -0.9636249 -1.0340385
-0.32570755 -1.7694132 ]
1 0.7806913
feat: [[[-1.6726604 -1.6897851 -1.7069099 ... 0.43368444 0.46793392
0.41655967]
[-1.7069099 -1.7069099 -1.7069099 ... 0.5364329 0.5193082
0.4850587 ]
[-1.7240347 -1.7240347 -1.7069099 ... 0.60493195 0.5535577
0.4850587 ]
...
[-0.6622999 -0.8335474 -0.8677969 ... -0.02868402 0.00556549
-0.02868402]
[-0.6622999 -0.69654936 -0.69654936 ... -0.11430778 -0.11430778
-0.14855729]
[-0.95342064 -0.8335474 -0.78217316 ... -0.26843056 -0.30268008
-0.31980482]]
[[-1.7556022 -1.7731092 -1.7906162 ... -0.617647 -0.582633
-0.635154 ]
[-1.7906162 -1.7906162 -1.7906162 ... -0.512605 -0.512605
-0.565126 ]
[-1.8081232 -1.8081232 -1.7906162 ... -0.460084 -0.495098
-0.565126 ]
...
[-0.28501397 -0.37254897 -0.40756297 ... -1.0028011 -0.9677871
-1.0203081 ]
[-0.26750696 -0.33753496 -0.32002798 ... -1.12535 -1.1428571
-1.160364 ]
[-0.53011197 -0.53011197 -0.44257697 ... -1.2829131 -1.317927
-1.317927 ]]
[[-1.68244 -1.6998693 -1.7172985 ... -1.490719 -1.4558606
-1.490719 ]
[-1.7172985 -1.7172985 -1.7172985 ... -1.4732897 -1.4384314
-1.4732897 ]
[-1.7347276 -1.7347276 -1.7172985 ... -1.4558606 -1.4732897
-1.5255773 ]
...
[-1.3338562 -1.4210021 -1.4210021 ... -1.6127234 -1.5430065
-1.5604358 ]
[-1.2815686 -1.3512855 -1.3338562 ... -1.6127234 -1.5952941
-1.6127234 ]
[-1.5081482 -1.4732897 -1.4210021 ... -1.5778649 -1.6127234
-1.6301525 ]]]
last--------------------------------: [-9.7715964e+37 -1.3229437e+37 -1.5262715e+38 -2.5811514e+38
3.2964988e+38 -7.1266450e+37 -7.2963347e+37 -3.0699307e+38
-1.6108344e+38 5.8011444e+37 -3.9925391e+37 -9.5891957e+37
-1.7783365e+38 2.2280316e+38 -4.4186918e+37 3.4825655e+37
5.8457292e+37 7.2160006e+37 1.4259578e+38 9.4037617e+37
7.4650717e+37 1.8146209e+37 -2.5143476e+38 2.4387442e+38
-7.5397363e+37 1.4157064e+38 -1.1084308e+38 1.9522180e+38
2.5864164e+37 -8.5381704e+37 3.3140050e+36 -1.2379668e+38
-3.3449897e+37 1.6203643e+38 1.4627435e+38 6.6909600e+37
6.0661751e+37 -1.2335753e+38 1.3377397e+38 -3.7530971e+37
3.5314601e+37 -1.4393099e+37 -inf -6.0411279e+37
-7.0721061e+37 1.5951782e+38 9.0163464e+37 1.3680580e+37
-1.2254094e+37 1.0919689e+38 -1.5229139e+37 -3.4862508e+36
-8.9739065e+37 2.8713203e+38 9.4768839e+37 7.8658815e+37
-2.6619306e+38 -7.8224467e+37 6.8780734e+37 inf
-9.8889302e+37 -1.9009123e+38 -1.4562352e+38 -4.5324568e+37
-2.6728082e+38 1.0300855e+38 -5.7767852e+37 1.3662499e+37
-4.0048543e+37 -3.1911765e+37 -1.9702732e+38 -6.5395945e+37
1.0223747e+38 -2.8775531e+38 -1.1156091e+38 -1.8772822e+38
1.2472896e+38 1.2465860e+38 -6.7286062e+37 -8.9167649e+37
-2.8327554e+37 -2.7379526e+37 -1.5994879e+37 1.1577176e+38
1.1864721e+38 1.7089999e+38 -1.5323652e+37 -1.5374746e+38
1.2187025e+38 -8.9546139e+37 1.7550813e+38 -5.7048014e+37
-8.5996788e+37 -5.2310546e+36 -1.4450948e+37 -1.9950120e+37
4.2429252e+37 -1.4849557e+38 1.0697206e+38 -7.6313524e+37
-inf 1.7437526e+38 -1.0569269e+38 -1.5577321e+38
-7.8117285e+37 6.4801082e+37 -3.3032475e+37 -6.4655517e+36
-2.3770844e+38 1.0880277e+38 3.6430118e+37 -6.9370110e+37
8.5146681e+37 1.1550550e+38 -2.5614073e+38 -2.1489826e+38
-8.3233807e+37 2.7233982e+37 -1.3777926e+38 -9.6201629e+37
-2.1125345e+38 -1.4252791e+36 3.6633845e+37 2.6106833e+37
9.6643025e+37 -1.4538810e+37 -1.3660478e+38 1.9220696e+38]
1. Use a warmup schedule for the learning rate, with the maximum learning rate set to 0.01;
2. Apply gradient clipping as a safeguard;
3. Check whether the final output is normalized; the values are probably not confined to the 0-1 range.
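For concreteness, here is a minimal sketch of the three suggestions above, in the same MindSpore style as the code in the question. net, loss_fn, and all step counts are placeholder assumptions, not values from the original project.

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

# 1) warmup: ramp the learning rate up to a maximum of 0.01
lr_schedule = nn.warmup_lr(learning_rate=0.01, total_step=10000,
                           step_per_epoch=100, warmup_epoch=5)

# 2) gradient clipping: cap the global gradient norm before applying the update
class TrainOneStepWithClip(nn.TrainOneStepCell):
    def __init__(self, network, optimizer, clip_norm=1.0):
        super(TrainOneStepWithClip, self).__init__(network, optimizer)
        self.clip_norm = clip_norm
        self.grad_op = ops.GradOperation(get_by_list=True)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        grads = self.grad_op(self.network, self.weights)(*inputs)
        grads = ops.clip_by_global_norm(grads, self.clip_norm)
        loss = ops.depend(loss, self.optimizer(grads))
        return loss

# usage (net and loss_fn are placeholders):
# optimizer = nn.Adam(net.trainable_params(), learning_rate=lr_schedule)
# train_step = TrainOneStepWithClip(nn.WithLossCell(net, loss_fn), optimizer)

# 3) bound the embedding scale, e.g. L2-normalize the Dense output
#    inside EmbeddingImagenet.construct:
#        x = ops.L2Normalize(axis=-1)(x)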