
Graph Neural Network Pre-training (3) - Context Prediction + Supervised Learning Code

context prediction

The previous two posts covered the graph neural network pre-training method of Weihua Hu*, Bowen Liu* et al., together with the implementation code for context-prediction pre-training.

Context prediction learns representations of the atom/bond information inside a graph; it does not capture graph-level information.

This part is graph-level supervised learning, whose goal is to inject graph-level information into the graph representation vector h_G. After graph-level supervised training, the resulting model can be applied directly to downstream tasks.
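As a minimal illustration of what "one vector per graph" means here (plain PyTorch, not the post's code): a graph representation is obtained by pooling node embeddings, e.g. averaging over the nodes of each graph in a batch, where `batch[i]` gives the graph id of node `i` in PyG's convention:

```python
import torch

def mean_pool(node_emb: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Average node embeddings per graph; batch[i] is the graph id of node i (PyG convention)."""
    num_graphs = int(batch.max()) + 1
    out = torch.zeros(num_graphs, node_emb.size(1))
    out.index_add_(0, batch, node_emb)                    # sum node embeddings per graph
    counts = torch.bincount(batch, minlength=num_graphs)  # number of nodes per graph
    return out / counts.unsqueeze(1).clamp(min=1)

# two graphs: nodes 0-1 belong to graph 0, node 2 to graph 1
emb = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 1])
hG = mean_pool(emb, batch)  # → [[2., 3.], [5., 6.]]
```

This is what `global_mean_pool` does in the model below, batched over many graphs at once.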

The paper's approach: append a simple linear model to the node-level pre-trained model and use it for graph-level supervised training.

The network structure is shown in the figure below:

In the paper, the authors' graph-level supervised learning is multi-task learning on the chembl_filtered dataset. After this stage, fine-tuning is usually added as well, i.e. training on a specific task such as BBBP.

However, due to version issues the chembl_filtered dataset cannot be loaded here, so datasets such as ESOL and Lipophilicity are used directly for both supervised pre-training and fine-tuning.

The code follows.

1. Importing the required packages

```python
# import the required packages
import pandas as pd
from tqdm import tqdm
import numpy as np
import os
import math
import random
import torch
import torch.optim as optim
torch.manual_seed(0)
np.random.seed(0)
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn.inits import glorot, zeros
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt

# run device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

Data loading is largely the same as in pre-training, except that each example now carries a label, stored in data.y:

data.y = torch.tensor([label])

The molecule preprocessing and its result are identical to pre-training; again we define a MoleculeDataset. The file-reading part changes slightly, because the supervised data comes as CSV:

```python
input_path = self.raw_paths[0]
input_df = pd.read_csv(input_path, sep=',', dtype='str')
smiles_list = list(input_df['smiles'])
smiles_id_list = list(input_df.index.values)
y_list = list(input_df['exp'].values)
```

The data-loading code:

```python
# PyG dataset
class MoleculeDataset(InMemoryDataset):
    '''
    Load the dataset as a PyG InMemoryDataset.
    '''
    def __init__(self, root, dataset='zinc250k',
                 transform=None, pre_transform=None, pre_filter=None):
        self.dataset = dataset
        self.root = root
        # the super().__init__ call must come after the attributes above
        super(MoleculeDataset, self).__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths[0])
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property  # list of raw files
    def raw_file_names(self):
        file_name_list = os.listdir(self.raw_dir)
        return file_name_list

    @property  # if these files exist, processing is skipped
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def process(self):
        data_smiles_list = []
        data_list = []
        input_path = self.raw_paths[0]
        input_df = pd.read_csv(input_path, sep=',', dtype='str')
        smiles_list = list(input_df['smiles'])
        smiles_id_list = list(input_df.index.values)
        y_list = list(input_df['exp'].values)
        for i in range(len(smiles_list)):
            if i % 1000 == 0:
                print(str(i) + '...')
            s = smiles_list[i]
            label = float(y_list[i])
            # each example contains a single species
            try:
                rdkit_mol = AllChem.MolFromSmiles(s, sanitize=True)
                if rdkit_mol is not None:  # ignore invalid mol objects
                    data = mol_to_graph_data_obj_simple(rdkit_mol)
                    if 119 in list(data.x[:, 0]):
                        print(s)
                    if 4 in list(data.edge_attr[:, 0]):
                        print(s)
                    # manually add mol id and label
                    id = int(smiles_id_list[i])
                    data.id = torch.tensor([id])
                    data.y = torch.tensor([label])
                    data_list.append(data)
                    data_smiles_list.append(smiles_list[i])
            except Exception:
                continue
        # filter
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        # transform
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]
        # write data_smiles_list into the processed directory
        data_smiles_series = pd.Series(data_smiles_list)
        data_smiles_series.to_csv(os.path.join(self.processed_dir, 'smiles.csv'),
                                  index=False, header=False)
        # InMemoryDataset.collate turns a list of torch_geometric.data.Data
        # objects into the internal storage, saved under processed_paths[0]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self):
        return '{}()'.format(self.dataset)
```
```python
# build a PyG data object from SMILES, identical to the pre-training stage
allowable_features = {
    'possible_atomic_num_list': list(range(1, 119)),
    'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs': [  # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}

def mol_to_graph_data_obj_simple(mol):
    # atoms
    num_atom_features = 2  # atom type, chirality tag
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = [allowable_features['possible_atomic_num_list'].index(
            atom.GetAtomicNum())] + [allowable_features[
            'possible_chirality_list'].index(atom.GetChiralTag())]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)
    # bonds
    num_bond_features = 2  # bond type, bond direction
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [allowable_features['possible_bonds'].index(
                bond.GetBondType())] + [allowable_features[
                'possible_bond_dirs'].index(bond.GetBondDir())]
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)
        # data.edge_index: graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        # data.edge_attr: edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data
```
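To make the COO edge layout concrete, here is a hand-built toy example (illustrative only; the index values are hypothetical picks from `allowable_features`, e.g. atomic number 6 sits at index 5 of `possible_atomic_num_list`): one undirected bond is stored as two directed edges.

```python
import torch

# hypothetical 2-atom graph: per-atom features [atom-type index, chirality index]
x = torch.tensor([[5, 0],    # atomic number 6 -> index 5, chirality unspecified
                  [7, 0]],   # atomic number 8 -> index 7
                 dtype=torch.long)
# one undirected bond -> two directed edges (0->1 and 1->0), COO format [2, num_edges]
edge_index = torch.tensor([[0, 1],
                           [1, 0]], dtype=torch.long)
# per-edge features [bond-type index, bond-dir index], duplicated for both directions
edge_attr = torch.tensor([[0, 0],
                          [0, 0]], dtype=torch.long)

assert edge_index.shape == (2, 2)
assert edge_attr.shape[0] == edge_index.shape[1]
```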

2. The model

The GIN layer and GIN model are the same ones used during pre-training; the architecture here must match the pre-trained model exactly. Since the context-prediction stage used a GIN model, we simply load the earlier parameters, so that code is skipped here.

After the pre-trained GIN model we attach a linear head, forming the full molecular-property-prediction model GNN_graphpred:

```python
# the pre-trained GIN model plus a linear head form the prediction model
class GNN_graphpred(torch.nn.Module):
    '''
    A GNN with the same structure as in pre-training, followed by a simple linear head.
    '''
    def __init__(self, pre_model, pre_model_files, graph_pred_linear,
                 drop_ratio=0.05, graph_pooling="mean", if_pretrain=True):
        super(GNN_graphpred, self).__init__()
        self.drop_layer = torch.nn.Dropout(p=drop_ratio)
        self.gnn = pre_model
        self.pre_model_files = pre_model_files
        # different kinds of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            # note: this branch comes from the original repository and assumes
            # self.JK, self.num_layer and emb_dim are defined; it is unused here
            if self.JK == "concat":
                self.pool = GlobalAttention(gate_nn=torch.nn.Linear((self.num_layer + 1) * emb_dim, 1))
            else:
                self.pool = GlobalAttention(gate_nn=torch.nn.Linear(emb_dim, 1))
        else:
            raise ValueError("Invalid graph pooling type.")
        self.graph_pred_linear = graph_pred_linear  # linear head
        # load the pre-trained parameters
        if if_pretrain:
            self.from_pretrained()

    def from_pretrained(self):
        '''
        Load the pre-trained parameters.
        '''
        self.gnn = torch.load(self.pre_model_files)
        # eval() only disables dropout/batch-norm in the pre-trained part;
        # to truly freeze it, its parameters would also need requires_grad_(False)
        self.gnn = self.gnn.eval()

    def forward(self, data):
        batch = data.batch
        node_representation = self.gnn(data)
        result = self.pool(node_representation, batch)
        result = self.drop_layer(result)
        result = self.graph_pred_linear(result)
        return result
```
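The linear head `Graph_pred_linear(512, 256, 1)` instantiated in the training script below is never defined in this post. A minimal sketch consistent with that call signature (in_dim, hidden_dim, out_dim), assuming a plain two-layer MLP, might look like:

```python
import torch

class Graph_pred_linear(torch.nn.Module):
    """Hypothetical two-layer MLP head matching the Graph_pred_linear(512, 256, 1) call below."""
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),  # graph embedding -> hidden
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, out_dim),  # hidden -> scalar property
        )

    def forward(self, x):
        return self.net(x)
```

Any head with this signature that maps a [batch, 512] graph embedding to a [batch, 1] prediction would slot into GNN_graphpred above.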

3. The training process

The per-epoch training function:

The test set is evaluated inside with torch.no_grad(): so that no autograd graph accumulates across iterations and GPU memory does not keep growing.
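A generic PyTorch illustration of the point (not specific to this model): inside `torch.no_grad()` no computation graph is recorded, so evaluation results carry no gradient history that would keep intermediate tensors alive:

```python
import torch

w = torch.randn(3, requires_grad=True)
x = torch.ones(3)

y_train = (w * x).sum()       # graph recorded: y_train can be backpropagated
with torch.no_grad():
    y_eval = (w * x).sum()    # no graph recorded: nothing extra held in memory

assert y_train.requires_grad
assert not y_eval.requires_grad
```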

```python
# per-epoch training function
def train(model, device, loader_train, loader_test, optimizer, criterion):
    loss_train, r2_train, corr_train = [], [], []
    loss_test, r2_test, corr_test = [], [], []
    model.train()
    for step, batch in enumerate(tqdm(loader_train, desc="Iteration")):
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(pred.shape).to(torch.float64)
        R2 = torch.sum((pred - torch.mean(y))**2) / torch.sum((y - torch.mean(y))**2)
        # whether y is non-null or not (labels equal to exactly 0 count as missing)
        is_valid = y**2 > 0
        # loss matrix; the (y+1)/2 rescaling is carried over from the classification pre-training code
        loss_mat = criterion(pred.double(), (y + 1) / 2)
        # loss matrix after removing null targets
        loss_mat = torch.where(is_valid, loss_mat,
                               torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype))
        optimizer.zero_grad()
        loss = torch.sum(loss_mat) / torch.sum(is_valid) - 0.3 * R2 + 0.3  # add an R2 term with weight 0.3
        loss.backward()
        optimizer.step()
        # R2 and correlation between predictions and ground truth
        pred = pred.detach().cpu().reshape(-1).numpy()
        y = y.detach().cpu().reshape(-1).numpy()
        r2 = np.sum((pred - np.mean(y))**2) / np.sum((y - np.mean(y))**2)
        corr = np.corrcoef(y, pred)[0, 1]
        loss_train.append(loss.detach().cpu().numpy())
        r2_train.append(r2)
        corr_train.append(corr)
    model.eval()  # switch off dropout while evaluating
    with torch.no_grad():
        for step, batch in enumerate(tqdm(loader_test, desc="Iteration")):
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(pred.shape).to(torch.float64)
            R2 = torch.sum((pred - torch.mean(y))**2) / torch.sum((y - torch.mean(y))**2)
            is_valid = y**2 > 0
            loss_mat = criterion(pred.double(), (y + 1) / 2)
            loss_mat = torch.where(is_valid, loss_mat,
                                   torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype))
            loss = torch.sum(loss_mat) / torch.sum(is_valid) - 0.3 * R2 + 0.3
            loss_test.append(loss.cpu().numpy())
            # R2 and correlation between predictions and ground truth
            pred = pred.cpu().reshape(-1).numpy()
            y = y.cpu().reshape(-1).numpy()
            r2_test.append(np.sum((pred - np.mean(y))**2) / np.sum((y - np.mean(y))**2))
            corr_test.append(np.corrcoef(y, pred)[0, 1])
    # average each metric over its own number of batches
    n_tr, n_te = len(loss_train), len(loss_test)
    return (sum(loss_train) / n_tr, sum(r2_train) / n_tr, sum(corr_train) / n_tr,
            sum(loss_test) / n_te, sum(r2_test) / n_te, sum(corr_test) / n_te)
```

Next we compare the model with pre-training against the same model trained from scratch:

First, sklearn's train_test_split randomly splits the supervised dataset (ESOL, Lipophilicity, etc.) into a training set and a test set for measuring model performance; then the runs with and without pre-training are compared.

```python
if __name__ == '__main__':
    # number of epochs
    epoches = 1000
    # split into training and test sets; note PyG's on-disk storage layout
    data = pd.read_csv('dataset/lipophilicity/raw/Lipophilicity.csv')
    data_train, data_test = train_test_split(data, test_size=0.25, random_state=88)
    data_train.to_csv('dataset/lipophilicity/raw/lipophilicity-train.csv', index=False)
    data_test.to_csv('dataset/lipophilicity/raw/lipophilicity-test.csv', index=False)
    # training set
    dataset_train = MoleculeDataset(root="dataset/lipophilicity", dataset='lipophilicity-train')
    loader_train = DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=8)
    # test set
    dataset_test = MoleculeDataset(root="dataset/lipophilicity", dataset='lipophilicity-test')
    loader_test = DataLoader(dataset_test, batch_size=64, shuffle=True, num_workers=8)

    '''
    With pre-training
    '''
    # instantiate the architecture first; hyper-parameters must match pre-training
    pre_model = GNN(7, 512)
    # linear head
    linear_model = Graph_pred_linear(512, 256, 1)
    # combine into the prediction model
    model = GNN_graphpred(pre_model=pre_model, pre_model_files='Context_Pretrain_Gat.pth',
                          graph_pred_linear=linear_model)
    model = model.to(device)
    # optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    criterion = torch.nn.MSELoss(reduction='none')  # elementwise loss: train() masks null targets itself
    # training loop
    log_loss, log_r2, log_corr = [], [], []
    log_loss_test, log_r2_test, log_corr_test = [], [], []
    for epoch in range(1, epoches):
        print("====epoch " + str(epoch))
        loss, r2, corr, loss_test, r2_test, corr_test = train(model, device, loader_train, loader_test, optimizer, criterion)
        log_loss.append(loss)
        log_r2.append(r2)
        log_corr.append(corr)
        log_loss_test.append(loss_test)
        log_r2_test.append(r2_test)
        log_corr_test.append(corr_test)
        print('loss:{:.4f}, r2:{:.4f}, corr:{:.4f}, loss_test:{:.4f}, r2_test:{:.4f}, corr_test:{:.4f}'.format(loss, r2, corr, loss_test, r2_test, corr_test))
    # save the whole model and its parameters
    torch.save(model, "context_pretrian_supervised.pth")
    torch.save(model.state_dict(), "context_pretrian_supervised_para.pth")
    # save the training logs
    np.save("Context_Supervised_log_train_loss.npy", log_loss)
    np.save("Context_Supervised_log_train_corr.npy", log_corr)
    np.save("Context_Supervised_log_train_r2.npy", log_r2)
    np.save("Context_Supervised_log_train_loss_test.npy", log_loss_test)
    np.save("Context_Supervised_log_train_corr_test.npy", log_corr_test)
    np.save("Context_Supervised_log_train_r2_test.npy", log_r2_test)
    # predictions on the test set
    y_all = []
    y_pred_all = []
    with torch.no_grad():
        for step, batch in enumerate(loader_test):
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(pred.shape).to(torch.float64)
            y_all += list(y.cpu().reshape(-1).numpy())
            y_pred_all += list(pred.cpu().reshape(-1).numpy())
    sns.regplot(x=y_all, y=y_pred_all, label='pretrain')
    plt.xlabel('y true')
    plt.ylabel('predicted')
    plt.legend()
    plt.savefig('Context_Supervised_Test_curve.png')  # save the figure
    plt.cla()
    plt.clf()

    '''
    Without pre-training
    '''
    pre_model = GNN(7, 512)  # same hyper-parameters as pre-training
    # linear head
    linear_model = Graph_pred_linear(512, 256, 1)
    # combine into a new model; if_pretrain=False skips loading the pre-trained weights
    model = GNN_graphpred(pre_model=pre_model, pre_model_files='Context_Pretrain_Gat.pth',
                          graph_pred_linear=linear_model, if_pretrain=False)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    criterion = torch.nn.MSELoss(reduction='none')
    un_log_loss, un_log_r2, un_log_corr = [], [], []
    un_log_loss_test, un_log_r2_test, un_log_corr_test = [], [], []
    for epoch in range(1, epoches):
        print("====epoch " + str(epoch))
        loss, r2, corr, loss_test, r2_test, corr_test = train(model, device, loader_train, loader_test, optimizer, criterion)
        un_log_loss.append(loss)
        un_log_r2.append(r2)
        un_log_corr.append(corr)
        un_log_loss_test.append(loss_test)
        un_log_r2_test.append(r2_test)
        un_log_corr_test.append(corr_test)
        print('loss:{:.4f}, r2:{:.4f}, corr:{:.4f}, loss_test:{:.4f}, r2_test:{:.4f}, corr_test:{:.4f}'.format(loss, r2, corr, loss_test, r2_test, corr_test))
    # predictions on the test set
    y_all = []
    y_pred_all = []
    with torch.no_grad():
        for step, batch in enumerate(loader_test):
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(pred.shape).to(torch.float64)
            y_all += list(y.cpu().reshape(-1).numpy())
            y_pred_all += list(pred.cpu().reshape(-1).numpy())
    sns.regplot(x=y_all, y=y_pred_all, label='unpretrain')
    plt.xlabel('y true')
    plt.ylabel('predicted')
    plt.legend()
    plt.savefig('Derectly_Supervised_Test_curve.png')  # save the figure
    plt.cla()
    plt.clf()

    '''
    Plot the comparison between the runs with and without pre-training
    '''
    plt.figure(figsize=(15, 6))
    plt.subplot(1, 3, 1)
    plt.plot(log_loss, label='loss')
    plt.plot(log_loss_test, label='loss_test')
    plt.plot(un_log_loss, label='unpretrain_loss')
    plt.plot(un_log_loss_test, label='unpretrain_loss_test')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.subplot(1, 3, 2)
    plt.plot(log_corr, label='corr')
    plt.plot(log_corr_test, label='corr_test')
    plt.plot(un_log_corr, label='unpretrain_corr')
    plt.plot(un_log_corr_test, label='unpretrain_corr_test')
    plt.xlabel('Epoch')
    plt.ylabel('Corr')
    plt.legend()
    plt.subplot(1, 3, 3)
    plt.plot(log_r2[1:], label='r2')
    plt.plot(log_r2_test[1:], label='r2_test')
    plt.plot(un_log_r2[1:], label='unpretrain_r2')
    plt.plot(un_log_r2_test[1:], label='unpretrain_r2_test')
    plt.ylim(0, 1)
    plt.xlabel('Epoch')
    plt.ylabel('R2')
    plt.legend()
    plt.savefig('Comversion_Train_process.png')
```

The result files are:

Comversion_Train_process.png: comparison curves for the loss, correlation coefficient, and R2;

Derectly_Supervised_Test_curve.png: the final fit curve when training from scratch, without pre-training;

Context_Supervised_Test_curve.png: the final fit curve of the pre-trained model;

*.pth: the saved models.

4. Results

The context-pre-trained GIN model on the ESOL dataset, 100 epochs:

The GIN model on the ESOL dataset without pre-training:

Whether one looks at the loss or the correlation coefficient, on the training set or the test set, the benefit of context pre-training is evident. It remains evident over the 500 epochs shown in the figure below.

The effect is equally clear on the Lipophilicity dataset, shown below:

It also holds over the 1000 epochs in the next figure, which suggests that pre-training reduces the number of training iterations needed and helps limit overfitting.

In the figure below, the left panel shows the test-set fit curve of the pre-trained model and the right panel that of the model without pre-training. Pre-training clearly brings a substantial improvement in model performance.

One remaining issue: the correlation coefficient is already very high (up to 0.99), yet R2 stays around 0.5. We therefore add a small R2 term to the loss function with a weight of 0.3; keeping this weight below 0.5 is recommended, since otherwise R2 starts out close to 1 and both R2 and corr fluctuate strongly during training. The results are as follows:
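The adjusted objective can be sketched as follows (a NumPy sketch mirroring the variance-ratio "R2" term used in the training function; the 0.3 weight and the matching constant offset are the choices described above):

```python
import numpy as np

def combined_loss(pred, y, r2_weight=0.3):
    """MSE minus a weighted variance-ratio 'R2' term, plus the same weight as an offset."""
    mse = np.mean((pred - y) ** 2)
    # variance ratio: 1 when predictions spread exactly like the targets
    r2 = np.sum((pred - y.mean()) ** 2) / np.sum((y - y.mean()) ** 2)
    return mse - r2_weight * r2 + r2_weight

y = np.array([0.0, 1.0, 2.0])
pred = np.array([0.0, 1.0, 2.0])  # perfect prediction: mse = 0, r2 = 1
loss = combined_loss(pred, y)     # 0 - 0.3 * 1 + 0.3 = 0.0
```

The offset simply shifts the minimum back towards zero so that the logged loss stays non-negative for a perfect fit.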

 

These results again show that pre-training helps; when the number of iterations is very large, however, the advantage becomes less pronounced and the performance gain is limited.

5. Source code download

Link: https://pan.baidu.com/s/14cxHjU2zwzkqPfwwfuSx0Q
Extraction code: 795y